HMLP: High-performance Machine Learning Primitives
gnbx.hpp

#ifndef GNBX_HPP
#define GNBX_HPP

#include <assert.h>
#include <typeinfo>
#include <algorithm>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>

//#include <primitives/strassen.hpp>

#include <packing.hpp>
#include <semiring_mrxnr.hpp>
#include <fused_mrxnr.hpp>

using namespace std;


namespace hmlp
{
namespace gnbx
{

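/**
 *  @brief Macro kernel for the semiring rank-k update (the 3rd and 2nd
 *         loops). Each MR-by-NR micro tile of the packed panels is handed
 *         to semiringkernel; partial tiles on the fringe are staged
 *         through a local MR-by-NR buffer instead.
 */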
template<int KC, typename SEMIRINGKERNEL, typename TA, typename TB, typename TV>
void rank_k_macro_kernel
(
  Worker &Comm4th,
  int ic, int jc, int pc,
  int m, int n, int k,
  TA *packA,
  TB *packB,
  TV *V, int rs_v, int cs_v,
  SEMIRINGKERNEL semiringkernel
)
{
  const static int MR = SEMIRINGKERNEL::mr;
  const static int NR = SEMIRINGKERNEL::nr;
  const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
  const static int PACK_NR = SEMIRINGKERNEL::pack_nr;

  thread_communicator &ic_comm = *Comm4th.comm;

  /** Distribute columns (3rd loop) over gangs, rows (2nd loop) over threads. */
  auto Loop3rd = Comm4th.DistributeOver1DGangs( 0, n, NR );
  auto Pack3rd = Comm4th.DistributeOver1DGangs( 0, n, PACK_NR );
  auto Loop2nd = Comm4th.DistributeOver1DThreads( 0, m, MR );
  auto Pack2nd = Comm4th.DistributeOver1DThreads( 0, m, PACK_MR );

  for ( int j  = get<0>( Loop3rd ), jp  = get<0>( Pack3rd );
            j  < get<1>( Loop3rd );
            j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) )
  {
    struct aux_s<TA, TB, TV, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;
    aux.do_packC = 0;
    aux.jb       = std::min( n - j, NR );

    for ( int i  = get<0>( Loop2nd ), ip  = get<0>( Pack2nd );
              i  < get<1>( Loop2nd );
              i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) )
    {
      aux.ib = std::min( m - i, MR );
      if ( i + MR >= m )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        semiringkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          &V[ i * rs_v + j * cs_v ], rs_v, cs_v,
          &aux
        );
      }
      else // corner case
      {
        TV vtmp[ MR * NR ];

        if ( pc ) // initialize vtmp with the current partial sums in V
        {
          for ( auto jj = 0; jj < aux.jb; jj ++ )
            for ( auto ii = 0; ii < aux.ib; ii ++ )
              vtmp[ jj * MR + ii ] =
                V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
        }

        semiringkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          vtmp, 1, MR,
          &aux
        );

        for ( auto jj = 0; jj < aux.jb; jj ++ )
          for ( auto ii = 0; ii < aux.ib; ii ++ )
            V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ] = vtmp[ jj * MR + ii ];
      }
    } // end 2nd loop
  } // end 3rd loop
}; // end rank_k_macro_kernel

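/**
 *  @brief Macro kernel for the last rank-k update: the fused micro kernel
 *         consumes the packed panels, the intermediate buffer V, and the
 *         output C, applying the user-defined element-wise operation on
 *         the fly. Global tile coordinates are passed through aux_s so the
 *         kernel knows its position in C.
 */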
template<int KC, typename FUSEDKERNEL, typename TA, typename TB, typename TC, typename TV>
void fused_macro_kernel
(
  Worker &Comm4th,
  int m, int n,
  int ic, int jc, int pc,
  int mc, int nc, int kc,
  TA *packA,
  TB *packB,
  TC *C,
  TV *V, int rs_v, int cs_v,
  int batchId,
  FUSEDKERNEL fusedkernel
)
{
  const static int MR = FUSEDKERNEL::mr;
  const static int NR = FUSEDKERNEL::nr;
  const static int PACK_MR = FUSEDKERNEL::pack_mr;
  const static int PACK_NR = FUSEDKERNEL::pack_nr;

  thread_communicator &ic_comm = *Comm4th.comm;

  auto Loop3rd = Comm4th.DistributeOver1DGangs( 0, nc, NR );
  auto Pack3rd = Comm4th.DistributeOver1DGangs( 0, nc, PACK_NR );
  auto Loop2nd = Comm4th.DistributeOver1DThreads( 0, mc, MR );
  auto Pack2nd = Comm4th.DistributeOver1DThreads( 0, mc, PACK_MR );

  for ( int j  = get<0>( Loop3rd ), jp  = get<0>( Pack3rd );
            j  < get<1>( Loop3rd );
            j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) )
  {
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;
    aux.do_packC = 0;

    for ( int i  = get<0>( Loop2nd ), ip  = get<0>( Pack2nd );
              i  < get<1>( Loop2nd );
              i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) )
    {
      /** Global problem sizes and coordinates of this micro tile. */
      aux.m = m;
      aux.n = n;
      aux.i = ic + i;
      aux.j = jc + j;
      aux.b = batchId;

      /** Edge sizes of this micro tile. */
      aux.ib = std::min( mc - i, MR );
      aux.jb = std::min( nc - j, NR );

      /** Where this tile's partial sums live in V. */
      aux.V = V + i * rs_v + j * cs_v;
      aux.ldv = cs_v;

      if ( i + MR >= mc )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * kc;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        fusedkernel
        (
          kc,
          &packA[ ip * kc ],
          &packB[ jp * kc ],
          C,
          &V[ i * rs_v + j * cs_v ], rs_v, cs_v,
          &aux
        );
      }
      else // corner case
      {
        TV vtmp[ MR * NR ];
        if ( pc ) // initialize vtmp with the current partial sums in V
        {
          for ( auto jj = 0; jj < aux.jb; jj ++ )
            for ( auto ii = 0; ii < aux.ib; ii ++ )
              vtmp[ jj * MR + ii ] =
                V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
          aux.V = vtmp;
          aux.ldv = MR;
        }
        fusedkernel
        (
          kc,
          &packA[ ip * kc ],
          &packB[ jp * kc ],
          C,
          vtmp, 1, MR,
          &aux
        );
      }
    } // end 2nd loop
  } // end 3rd loop
}; // end fused_macro_kernel
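
/**
 *  @brief Shared-memory implementation of the blocked (GotoBLAS-style)
 *         algorithm: the 6th (jc, steps of nc), 5th (pc, steps of KC), and
 *         4th (ic, steps of MC) loops with cooperative packing of A and B.
 *         All but the last pc iteration run the semiring rank-k update
 *         into V; the last one dispatches the fused macro kernel.
 */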
template<
  int MC, int NC, int KC,
  typename TPACKA, typename TPACKB, typename TV,
  typename TA, typename TB, typename TC,
  typename SEMIRINGKERNEL, typename MICROKERNEL>
void gnbx_internal
(
  Worker &thread,
  int batchId, int m, int n, int k, int k_stra,
  TA& A,
  TB& B,
  TC& C,
  TV* V, int rs_v, int cs_v,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  const static int MR = SEMIRINGKERNEL::mr;
  const static int NR = SEMIRINGKERNEL::nr;
  const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
  const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
  const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
  const static int PACK_MC = ( MC / MR ) * PACK_MR;
  const static int PACK_NC = ( NC / NR ) * PACK_NR;

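  /** Split the communicator recursively, one gang level per blocked loop. */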
  auto CommGLB = thread.Split();
  auto Comm6th = CommGLB.Split();
  auto Comm5th = Comm6th.Split();
  auto Comm4th = Comm5th.Split();

  int nc = CommGLB.BalanceOver1DGangs( n, NC, NR );
  int pack_nc = ( nc / NR ) * PACK_NR;

  //printf( "CommGLB %s tid %d gid %d ngangs %d\n", CommGLB.comm->name.data(), CommGLB.tid, CommGLB.gid, CommGLB.comm->GetNumGroups() );
  //printf( "Comm6th %s tid %d gid %d ngangs %d\n", Comm6th.comm->name.data(), Comm6th.tid, Comm6th.gid, Comm6th.comm->GetNumGroups() );
  //printf( "Comm5th %s tid %d gid %d ngangs %d\n", Comm5th.comm->name.data(), Comm5th.tid, Comm5th.gid, Comm5th.comm->GetNumGroups() );
  //printf( "Comm4th %s tid %d gid %d ngangs %d\n", Comm4th.comm->name.data(), Comm4th.tid, Comm4th.gid, Comm4th.comm->GetNumGroups() );
  //fflush( stdout );

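  /** Allocate shared, aligned buffers for the packed panels of A and B. */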
  auto *packA = Comm4th.AllocateSharedMemory<ALIGN_SIZE, TPACKA>( KC * ( PACK_MC + 1 ) );
  auto *packB = Comm5th.AllocateSharedMemory<ALIGN_SIZE, TPACKB>( KC * ( pack_nc + 1 ) );

  auto Loop6th = CommGLB.DistributeOver1DGangs( 0, n, nc );
  auto Loop5th = Comm6th.DistributeOver1DGangs( k_stra, k, KC );
  auto Loop4th = Comm5th.DistributeOver1DGangs( 0, m, MC );

  for ( int jc  = get<0>( Loop6th );
            jc  < get<1>( Loop6th );
            jc += get<2>( Loop6th ) )
  {
    auto jb = std::min( n - jc, nc );

    for ( int pc  = get<0>( Loop5th );
              pc  < get<1>( Loop5th );
              pc += get<2>( Loop5th ) )
    {
      auto pb = std::min( k - pc, KC );
      auto is_the_last_pc_iteration = ( pc + KC >= k );
      auto LooppkB = Comm5th.DistributeOver1DThreads( 0, jb, NR );
      auto PackpkB = Comm5th.DistributeOver1DThreads( 0, jb, PACK_NR );

      for ( int j  = get<0>( LooppkB ), jp  = get<0>( PackpkB );
                j  < get<1>( LooppkB );
                j += get<2>( LooppkB ), jp += get<2>( PackpkB ) )
      {
        B.Pack(
            k, pc, pb,
            n, jc + j, std::min( jb - j, NR ),
            &packB[ jp * pb ] );
      }
      Comm5th.Barrier();

      for ( int ic  = get<0>( Loop4th );
                ic  < get<1>( Loop4th );
                ic += get<2>( Loop4th ) )
      {
        auto &ic_comm = *thread.ic_comm;
        auto ib = std::min( m - ic, MC );
        auto LooppkA = Comm4th.DistributeOver1DThreads( 0, ib, MR );
        auto PackpkA = Comm4th.DistributeOver1DThreads( 0, ib, PACK_MR );

        for ( int i  = get<0>( LooppkA ), ip  = get<0>( PackpkA );
                  i  < get<1>( LooppkA );
                  i += get<2>( LooppkA ), ip += get<2>( PackpkA ) )
        {
          A.Pack(
              m, ic + i, std::min( ib - i, MR ),
              k, pc, pb,
              &packA[ ip * pb ] );
        }
        Comm4th.Barrier();

        if ( is_the_last_pc_iteration ) // fused_macro_kernel
        {
          fused_macro_kernel<KC>
          (
            Comm4th,
            m, n,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            &C,
            V + ic * rs_v + jc * cs_v, rs_v, cs_v,
            batchId,
            microkernel
          );
        }
        else // semiring rank-k update
        {
          rank_k_macro_kernel<KC>
          (
            Comm4th,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            V + ic * rs_v + jc * cs_v, rs_v, cs_v,
            semiringkernel
          );
        }
        Comm4th.Barrier();
      } // end 4th loop
      Comm5th.Barrier();
    } // end 5th loop
    Comm6th.Barrier();
  } // end 6th loop
  CommGLB.Barrier();

  Comm4th.FreeSharedMemory( packA );
  Comm5th.FreeSharedMemory( packB );

}; // end gnbx_internal
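
/**
 *  @brief gnbx driver: selects the intermediate buffer V, reads the
 *         parallelization scheme from the environment, and invokes
 *         gnbx_internal inside an OpenMP parallel region.
 */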
template<
  int MC, int NC, int KC,
  typename TPACKA, typename TPACKB, typename TV,
  typename TA, typename TB, typename TC,
  typename SEMIRINGKERNEL, typename MICROKERNEL>
void gnbx
(
  int batchId, int m, int n, int k,
  TA& A,
  TB& B,
  TC& C,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  const static int MR = SEMIRINGKERNEL::mr;
  const static int NR = SEMIRINGKERNEL::nr;
  const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
  const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
  const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
  const static int PACK_MC = ( MC / MR ) * PACK_MR;
  const static int PACK_NC = ( NC / NR ) * PACK_NR;
  const static bool USE_STRASSEN = false;

  if ( m == 0 || n == 0 || k == 0 ) return;

  TV *V = NULL;
  int rs_v = 0;
  int cs_v = 0;

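  /** If k > KC, multiple rank-k updates are required and the partial
   *  results must be accumulated in TV precision; allocate a separate
   *  m-by-n buffer unless C is already a packed MatrixLike that can
   *  serve as that buffer. */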
  if ( k > KC && !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value )
  {
    //printf( "here m %d n %d\n", m, n );
    V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
    rs_v = 1;
    cs_v = m;
  }
  else
  {
    V = reinterpret_cast<TV*>( C.X );
    rs_v = C.rs;
    cs_v = C.cs;
  }

  int k_stra = 0;
  if ( USE_STRASSEN )
  {
    assert( typeid(TPACKA) == typeid(TPACKB) );
    assert( typeid(TC) == typeid(TV) );
    k_stra = k - k % KC;
    if ( k_stra == k ) k_stra -= KC;
  }

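  /** If we are not already inside a parallel region, read the gang counts
   *  for the jc, ic, and jr loops from the environment. */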
  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
  if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
  {
    jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
    ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
    jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
  }

  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );

  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
  {
    Worker thread( &my_comm );

    thread.InitWithCommunicator( &my_comm, omp_get_thread_num(), 0 );

    //if ( USE_STRASSEN )
    //{
    //  strassen::strassen_internal
    //  <MC, NC, KC, MR, NR,
    //  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    //  USE_STRASSEN,
    //  SEMIRINGKERNEL, SEMIRINGKERNEL,
    //  TA, TPACKA, TB, TPACKB, TC, TV>
    //  (
    //    thread,
    //    m, n, k_stra,
    //    A, packakernel,
    //    B, packbkernel,
    //    V, ldv,
    //    semiringkernel, semiringkernel,
    //    nc, pack_nc,
    //    packA_buff,
    //    packB_buff
    //  );
    //}

    gnbx_internal<MC, NC, KC, TPACKA, TPACKB>
    (
      thread,
      batchId, m, n, k, k_stra,
      A,
      B,
      C,
      V, rs_v, cs_v,
      semiringkernel, microkernel
    );
  } // end omp parallel

  if ( k > KC && !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value )
  {
    hmlp_free( V );
  }
}; // end gnbx

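/**
 *  @brief Operator-based gnbx interface: assembles the semiring and fused
 *         micro kernels from the element-wise operators (opkernel, op1,
 *         op2) and the additive identity initV, then forwards to the
 *         kernel-based gnbx above.
 */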
template<
  int MR, int NR, int MC, int NC, int KC,
  typename TPACKA, typename TPACKB, typename TPACKC, typename TV,
  typename TA, typename TB, typename TC,
  typename OPKERNEL, typename OP1, typename OP2>
void gnbx
(
  int batchId, int m, int n, int k,
  TA& A,
  TB& B,
  TC& C,
  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
)
{
  /** NOTE: the two declarations below were missing from this listing and
   *  have been reconstructed; the functor names and template-argument
   *  orders are assumptions based on semiring_mrxnr.hpp and
   *  fused_mrxnr.hpp, which this header includes. */
  semiring_mrxnr<MR, NR, OP1, OP2, TPACKA, TPACKB, TV, TV> semiringkernel;
  gnbx_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TPACKA, TPACKB, TC, TV> gkrmkernel;

  semiringkernel.op1 = op1;
  semiringkernel.op2 = op2;
  semiringkernel.initV = initV;

  gkrmkernel.op1 = op1;
  gkrmkernel.op2 = op2;
  gkrmkernel.opkernel = opkernel;
  gkrmkernel.initV = initV;

  gnbx<MC, NC, KC, TPACKA, TPACKB, TV>
  ( batchId, m, n, k, A, B, C, semiringkernel, gkrmkernel );

}; // end gnbx
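
/**
 *  Usage sketch (hypothetical, for illustration only; the blocking numbers,
 *  operator functors, and MatrixLike setup below are placeholders rather
 *  than definitions from this header):
 *
 *    // Compute C = opkernel( A * B ) over the (op1, op2) semiring, i.e.
 *    // V(i,j) = op1( V(i,j), op2( A(i,p), B(p,j) ) ) starting from initV,
 *    // followed by C(i,j) = opkernel( V(i,j) ).
 *    //
 *    // MatrixLike (packing.hpp) wraps the strided operands A, B, and C.
 *    gnbx<8, 4, 104, 4096, 256, double, double, double, double>
 *    ( 0, m, n, k, A, B, C, opkernel, op1, op2, (double)0 );
 */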
}; // end namespace gnbx
}; // end namespace hmlp

#endif // define GNBX_HPP