HMLP: High-performance Machine Learning Primitives
gkmx.hpp
1 
23 #ifndef GKMX_HPP
24 #define GKMX_HPP
25 
26 #include <assert.h>
27 #include <typeinfo>
28 #include <algorithm>
29 
30 #include <hmlp.h>
31 #include <hmlp_internal.hpp>
32 #include <hmlp_base.hpp>
33 
35 #include <primitives/strassen.hpp>
36 
38 #include <semiring_mrxnr.hpp>
39 #include <fused_mrxnr.hpp>
40 
41 //#define GKMX_CONFIG
42 
43 
44 namespace hmlp
45 {
46 namespace gkmx
47 {
48 
54 template<
55  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
56  typename SEMIRINGKERNEL,
57  typename TA, typename TB, typename TC, typename TV>
59 (
60  Worker &thread,
61  int ic, int jc, int pc,
62  int m, int n, int k,
63  TA *packA,
64  TB *packB,
65  TV *V, int ldv,
66  SEMIRINGKERNEL semiringkernel
67 )
68 {
69  thread_communicator &ic_comm = *thread.ic_comm;
70 
71  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
72  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
73  auto loop2nd = GetRange( 0, m, MR );
74  auto pack2nd = GetRange( 0, m, PACK_MR );
75 
76  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
77  j < loop3rd.end();
78  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
79  {
80  struct aux_s<TA, TB, TC, TV> aux;
81  aux.pc = pc;
82  aux.b_next = packB;
83  aux.do_packC = 0;
84  aux.jb = std::min( n - j, NR );
85 
86  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
87  i < loop2nd.end();
88  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
89  {
90  aux.ib = std::min( m - i, MR );
91  if ( i + MR >= m )
92  {
93  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
94  }
95 
96  if ( aux.jb == NR && aux.ib == MR )
97  {
98  semiringkernel
99  (
100  k,
101  &packA[ ip * k ],
102  &packB[ jp * k ],
103  &V[ j * ldv + i ], 1, ldv,
104  &aux
105  );
106  }
107  else // corner case
108  {
109  TV vtmp[ MR * NR ];
110 
111  if ( pc ) // initilize ctmp
112  {
113  for ( auto jj = 0; jj < aux.jb; jj ++ )
114  for ( auto ii = 0; ii < aux.ib; ii ++ )
115  vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
116  }
117 
118  semiringkernel
119  (
120  k,
121  &packA[ ip * k ],
122  &packB[ jp * k ],
123  vtmp, 1, MR,
124  &aux
125  );
126 
127  for ( auto jj = 0; jj < aux.jb; jj ++ )
128  for ( auto ii = 0; ii < aux.ib; ii ++ )
129  V[ ( j + jj ) * ldv + i + ii ] = vtmp[ jj * MR + ii ];
130  }
131  } // end 2nd loop
132  } // end 3rd loop
133 } // end rank_k_macro_kernel
134 
135 
142 template<
143 int KC, int MR, int NR, int PACK_MR, int PACK_NR,
144 bool REUSE_C,
145 typename FUSEDKERNEL,
146 typename TA, typename TB, typename TC, typename TV>
147 void fused_macro_kernel
148 (
149  Worker &thread,
150  int ic, int jc, int pc,
151  int m, int n, int k,
152  TA *packA,
153  TB *packB,
154  TC *C, int ldc,
155  TV *V, int ldv,
156  int batchId,
157  FUSEDKERNEL fusedkernel
158 )
159 {
160  thread_communicator &ic_comm = *thread.ic_comm;
161 
162  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
163  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
164  auto loop2nd = GetRange( 0, m, MR );
165  auto pack2nd = GetRange( 0, m, PACK_MR );
166 
167  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
168  j < loop3rd.end();
169  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
170  {
171  struct aux_s<TA, TB, TC, TV> aux;
172  aux.pc = pc;
173  aux.b_next = packB;
174  aux.do_packC = 0;
175 
176  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
177  i < loop2nd.end();
178  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
179  {
180  // These auxiluary infos are used to access data in the closure of
181  // opkernel and opreduce.
182  aux.i = ic + i;
183  aux.j = jc + j;
184  aux.b = batchId;
185 
186  aux.ib = std::min( m - i, MR );
187  aux.jb = std::min( n - j, NR );
188 
189  aux.V = V + j * ldv + i;
190  aux.ldv = ldv;
191 
192  if ( i + MR >= m )
193  {
194  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
195  }
196 
197  if ( aux.jb == NR && aux.ib == MR )
198  {
199  fusedkernel
200  (
201  k,
202  &packA[ ip * k ],
203  &packB[ jp * k ],
204  &C[ j * ldc + i ], 1, ldc,
205  //&C[ ( j / NR ) * ldc + i ], ldc, // for conv_relu_pool
206  &aux
207  );
208  }
209  else // corner case
210  {
211  TC ctmp[ MR * NR ];
212  TV vtmp[ MR * NR ];
213 
214  if ( pc ) // initilize ctmp
215  {
216  if ( REUSE_C )
217  {
218  for ( auto jj = 0; jj < aux.jb; jj ++ )
219  for ( auto ii = 0; ii < aux.ib; ii ++ )
220  ctmp[ jj * MR + ii ] = C[ ( j + jj ) * ldc + i + ii ];
221  }
222  else
223  {
224  for ( auto jj = 0; jj < aux.jb; jj ++ )
225  for ( auto ii = 0; ii < aux.ib; ii ++ )
226  vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
227  aux.V = vtmp;
228  aux.ldv = MR;
229  }
230  }
231 
232  fusedkernel
233  (
234  k,
235  &packA[ ip * k ],
236  &packB[ jp * k ],
237  ctmp, 1, MR,
238  &aux
239  );
240 
241  for ( auto jj = 0; jj < aux.jb; jj ++ )
242  for ( auto ii = 0; ii < aux.ib; ii ++ )
243  C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
244 
245  }
246  } // end 2nd loop
247  } // end 3rd loop
248 }; // end fused_macro_kernel
249 
250 
251 
252 
253 
261 template<
262  int MC,
263  int NC,
264  int KC,
265  int MR,
266  int NR,
267  int PACK_MC,
268  int PACK_NC,
269  int PACK_MR,
270  int PACK_NR,
271  int ALIGN_SIZE,
272  bool USE_STRASSEN,
273  bool REUSE_C,
274  typename SEMIRINGKERNEL, typename MICROKERNEL,
275  typename TA, typename TB, typename TC, typename TV>
276 void gkmx_internal
277 (
278  Worker &thread,
279  hmlpOperation_t transA, hmlpOperation_t transB,
280  int m, int n, int k, int k_stra,
281  TA *A, int lda,
282  TB *B, int ldb,
283  TC *C, int ldc,
284  TV *V, int ldv,
285  int batchId,
286  SEMIRINGKERNEL semiringkernel,
287  MICROKERNEL microkernel,
288  int nc, int pack_nc,
289  TA *packA,
290  TB *packB
291 )
292 {
293  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
294  + ( thread.ic_id ) * PACK_MC * KC;
295  packB += ( thread.jc_id ) * pack_nc * KC;
296 
297  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
298  auto loop5th = GetRange( k_stra, k, KC );
299  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
300 
301  for ( int jc = loop6th.beg();
302  jc < loop6th.end();
303  jc += loop6th.inc() ) // beg 6th loop
304  {
305  auto &jc_comm = *thread.jc_comm;
306  auto jb = std::min( n - jc, nc );
307 
308  for ( int pc = loop5th.beg();
309  pc < loop5th.end();
310  pc += loop5th.inc() )
311  {
312  auto &pc_comm = *thread.pc_comm;
313  auto pb = std::min( k - pc, KC );
314  auto is_the_last_pc_iteration = ( pc + KC >= k );
315  auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
316  auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
317 
318  for ( int j = looppkB.beg(), jp = packpkB.beg();
319  j < looppkB.end();
320  j += looppkB.inc(), jp += packpkB.inc() )
321  {
322  if ( transB == HMLP_OP_N )
323  {
324  pack2D<true, PACK_NR> // packB
325  (
326  std::min( jb - j, NR ), pb,
327  &B[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
328  );
329  }
330  else
331  {
332  pack2D<false, PACK_NR> // packB (transB)
333  (
334  std::min( jb - j, NR ), pb,
335  &B[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
336  );
337  }
338  }
339  pc_comm.Barrier();
340 
341  for ( int ic = loop4th.beg();
342  ic < loop4th.end();
343  ic += loop4th.inc() ) // beg 4th loop
344  {
345  auto &ic_comm = *thread.ic_comm;
346  auto ib = std::min( m - ic, MC );
347  auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
348  auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
349 
350  for ( int i = looppkA.beg(), ip = packpkA.beg();
351  i < looppkA.end();
352  i += looppkA.inc(), ip += packpkA.inc() )
353  {
354  if ( transA == HMLP_OP_N )
355  {
356  pack2D<false, PACK_MR> // packA
357  (
358  std::min( ib - i, MR ), pb,
359  &A[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
360  );
361  }
362  else
363  {
364  pack2D<true, PACK_MR> // packA (transA)
365  (
366  std::min( ib - i, MR ), pb,
367  &A[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
368  );
369  }
370  }
371  ic_comm.Barrier();
372 
373  if ( is_the_last_pc_iteration ) // fused_macro_kernel
374  {
375  fused_macro_kernel
376  <KC, MR, NR, PACK_MR, PACK_NR, REUSE_C, MICROKERNEL, TA, TB, TC, TV>
377  (
378  thread,
379  ic, jc, pc,
380  ib, jb, pb,
381  packA,
382  packB,
383  C + jc * ldc + ic, ldc,
384  V + jc * ldv + ic, ldv, // if REUSE_C, then V = C.
385  batchId,
386  microkernel
387  );
388  }
389  else // semiring rank-k update
390  {
392  <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
393  (
394  thread,
395  ic, jc, pc,
396  ib, jb, pb,
397  packA,
398  packB,
399  //C + jc * ldc + ic, ldc,
400  V + jc * ldv + ic, ldv,
401  semiringkernel
402  );
403  }
404  ic_comm.Barrier(); // sync all jr_id!!
405  } // end 4th loop
406  pc_comm.Barrier();
407  } // end 5th loop
408  } // end 6th loop
409 } // end gkmx_internal
410 
411 
412 
413 
414 
421 template<
422  int MC,
423  int NC,
424  int KC,
425  int MR,
426  int NR,
427  int PACK_MC,
428  int PACK_NC,
429  int PACK_MR,
430  int PACK_NR,
431  int ALIGN_SIZE,
432  bool USE_STRASSEN = false,
433  bool REUSE_C,
434  typename SEMIRINGKERNEL, typename MICROKERNEL,
435  typename TA, typename TB, typename TC, typename TV = TC>
436 void gkmx
437 (
438  hmlpOperation_t transA, hmlpOperation_t transB,
439  int m, int n, int k,
440  TA *A, int lda,
441  TB *B, int ldb,
442  TC *C, int ldc,
443  int batchId,
444  SEMIRINGKERNEL semiringkernel,
445  MICROKERNEL microkernel
446 )
447 {
448  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
449  int k_stra = 0;
450  int ldv = 0;
451  int nc = NC, pack_nc = PACK_NC;
452  char *str;
453 
454  TA *packA_buff = NULL;
455  TB *packB_buff = NULL;
456  TV *V = NULL;
457 
458  // Early return if possible
459  if ( m == 0 || n == 0 || k == 0 ) return;
460 
461  // type checking (currently assume TC == TV)
462  if ( typeid(TC) != typeid(TV) && k > KC )
463  {
464  printf( "gkmx: currently k(%d) must be smaller than %d when TC != TV\n", k, KC );
465  exit( 1 );
466  }
467 
468  if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
469  {
470  // Check the environment variable.
471  jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
472  ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
473  jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
474  }
475 
476  if ( jc_nt > 1 )
477  {
478  nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
479  pack_nc = ( nc / NR ) * PACK_NR;
480  }
481 
482  // allocate packing memory
483  packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC * ( PACK_MC + 1 ) * jc_nt * ic_nt );
484  packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC * ( pack_nc + 1 ) * jc_nt );
485 
486 
487  // allocate V if k > KC
488  if ( k > KC && !std::is_same<TC, TV>::value && !REUSE_C )
489  {
490  V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
491  ldv = m;
492  }
493  else // TODO: do not free V in this case.
494  {
495  V = reinterpret_cast<TV*>( C );
496  ldv = ldc;
497  }
498 
499  // allocate tree communicator
500  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
501 
502 
503  if ( USE_STRASSEN )
504  {
505  assert( typeid(TA) == typeid(TB) );
506  assert( typeid(TC) == typeid(TV) );
507  k_stra = k - k % KC;
508 
509  if ( k_stra == k ) k_stra -= KC;
510 
511  if ( k_stra )
512  {
513  #pragma omp parallel for
514  for ( int i = 0; i < n * ldv; i ++ ) V[ i ] = 0.0;
515  }
516  }
517 
518 
519  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
520  {
521  Worker thread( &my_comm );
522 
523  if ( USE_STRASSEN )
524  {
525  strassen::strassen_internal
526  <MC, NC, KC, MR, NR,
527  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
528  USE_STRASSEN,
529  SEMIRINGKERNEL, SEMIRINGKERNEL,
530  TA, TB, TC, TV>
531  (
532  thread,
533  transA, transB,
534  m, n, k_stra,
535  A, lda,
536  B, ldb,
537  V, ldv,
538  semiringkernel, semiringkernel,
539  nc, pack_nc,
540  packA_buff,
541  packB_buff
542  );
543  }
544 
545  gkmx_internal
546  <MC, NC, KC, MR, NR,
547  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
548  USE_STRASSEN, REUSE_C,
549  SEMIRINGKERNEL, MICROKERNEL,
550  TA, TB, TC, TV>
551  (
552  thread,
553  transA, transB,
554  m, n, k, k_stra,
555  A, lda,
556  B, ldb,
557  C, ldc,
558  V, ldv,
559  batchId,
560  semiringkernel, microkernel,
561  nc, pack_nc,
562  packA_buff,
563  packB_buff
564  );
565  } // end omp parallel
566 
567  hmlp_free( packA_buff );
568  hmlp_free( packB_buff );
569  //hmlp_free( V );
570 }; // end gkmx
571 
572 
573 
574 
575 
/**
 *  Single-problem GKMM entry point with default blocking parameters.
 *
 *  Copies the user operators into two kernel objects (semiringkernel gets
 *  op1, op2 and initV; gkmmkernel additionally gets opkernel) and forwards
 *  everything to gkmx.
 *
 *  NOTE(review): in this listing the inner line numbers jump 604 -> 607 and
 *  619 -> 622, so the declarations of semiringkernel / gkmmkernel and two
 *  lines of the gkmx template-argument list are not visible here. Confirm
 *  them against the original header before editing this function.
 */
579 template<
580  int MC = 104,
581  int NC = 1024,
582  int KC = 256,
583  int MR = 8,
584  int NR = 4,
585  int PACK_MC = 104,
586  int PACK_NC = 1024,
587  int PACK_MR = 8,
588  int PACK_NR = 4,
589  int ALIGN_SIZE = 32,
590  bool USE_STRASSEN = false,
591  bool REUSE_C = false,
592  typename OPKERNEL, typename OP1, typename OP2,
593  typename TA, typename TB, typename TC, typename TV>
594 void gkmm
595 (
596  hmlpOperation_t transA, hmlpOperation_t transB,
597  int m, int n, int k,
598  TA *A, int lda,
599  TB *B, int ldb,
600  TC *C, int ldc,
601  int batchId,
602  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
603 )
604 {
607 
// Kernel used for the intermediate rank-k updates.
608  semiringkernel.op1 = op1;
609  semiringkernel.op2 = op2;
610  semiringkernel.initV = initV;
611 
// Fused kernel used for the final update; it also applies opkernel.
612  gkmmkernel.op1 = op1;
613  gkmmkernel.op2 = op2;
614  gkmmkernel.opkernel = opkernel;
615  gkmmkernel.initV = initV;
616 
// Forward to the generic driver with the two kernels.
617  gkmx
618  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
619  USE_STRASSEN, REUSE_C,
622  TA, TB, TC, TV>
623  (
624  transA, transB,
625  m, n, k,
626  A, lda,
627  B, ldb,
628  C, ldc,
629  batchId,
630  semiringkernel, gkmmkernel
631  );
632 };
633 
634 
642 template<
643  int MC, int NC, int KC, int MR, int NR,
644  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
645  bool USE_STRASSEN, bool REUSE_C,
646  typename OPKERNEL, typename OP1, typename OP2,
647  typename TA, typename TB, typename TC, typename TV>
648 void gkmm
649 (
650  hmlpOperation_t transA, hmlpOperation_t transB,
651  int m, int n, int k,
652  TA *Aarray[], int lda,
653  TB *Barray[], int ldb,
654  TC *Carray[], int ldc,
655  int batchSize,
656  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
657 )
658 {
659  #pragma omp parallel for
660  for ( auto b = 0; b < batchSize; b ++ )
661  {
662  gkmm
663  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
664  USE_STRASSEN,
665  OPKERNEL, OP1, OP2,
666  TA, TB, TC, TV>
667  (
668  transA, transB,
669  m, n, k,
670  Aarray[ b ], lda,
671  Barray[ b ], ldb,
672  Carray[ b ], ldc,
673  b,
674  opkernel, op1, op2, initV
675  );
676  }
677 }; // end gkmm
678 
679 
687 template<
688  int MC,
689  int NC,
690  int KC, int MR, int NR,
691  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
692  bool USE_STRASSEN, bool REUSE_C,
693  typename OPKERNEL, typename OP1, typename OP2,
694  typename TA, typename TB, typename TC, typename TV>
695 void gkmm
696 (
697  hmlpOperation_t transA, hmlpOperation_t transB,
698  int m, int n, int k,
699  TA *Aarray, int lda, int loa,
700  TB *Barray, int ldb, int lob,
701  TC *Carray, int ldc, int loc,
702  int batchSize,
703  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
704 )
705 {
706  #pragma omp parallel for
707  for ( auto b = 0; b < batchSize; b ++ )
708  {
709  gkmm
710  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
711  USE_STRASSEN, REUSE_C,
712  OPKERNEL, OP1, OP2,
713  TA, TB, TC, TV>
714  (
715  transA, transB,
716  m, n, k,
717  Aarray + b * loa, lda,
718  Barray + b * lob, ldb,
719  Carray + b * loc, ldc,
720  b,
721  opkernel, op1, op2, initV
722  );
723  }
724 }; // end gkmm
725 
726 
727 
728 
729 
730 
731 
732 
733 
734 
735 
736 
737 
738 
739 
740 
741 
742 
743 
/**
 *  Single-problem GKRM entry point with default blocking parameters.
 *
 *  Copies the user operators into two kernel objects (semiringkernel gets
 *  op1, op2 and initV; gkrmkernel additionally gets opkernel, opreduce and
 *  initC) and forwards everything to gkmx. Note that the ldc parameter is
 *  not forwarded: gkmx is called with 0 in its place.
 *
 *  NOTE(review): in this listing the inner line numbers jump 776 -> 779 and
 *  793 -> 796, so the declarations of semiringkernel / gkrmkernel and two
 *  lines of the gkmx template-argument list are not visible here.
 *
 *  NOTE(review): gkmx declares both USE_STRASSEN and REUSE_C before its
 *  kernel types, but only USE_STRASSEN is visible in the argument list
 *  below. Confirm against the original header that REUSE_C is supplied.
 */
751 template<
752  int MC = 104,
753  int NC = 1024,
754  int KC = 256,
755  int MR = 8,
756  int NR = 4,
757  int PACK_MC = 104,
758  int PACK_NC = 1024,
759  int PACK_MR = 8,
760  int PACK_NR = 4,
761  int ALIGN_SIZE = 32,
762  bool USE_STRASSEN = false,
763  typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
764  typename TA, typename TB, typename TC, typename TV = TC>
765 void gkrm
766 (
767  hmlpOperation_t transA, hmlpOperation_t transB,
768  int m, int n, int k,
769  TA *A, int lda,
770  TB *B, int ldb,
771  TC *C, int ldc,
772  int batchId,
773  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV,
774  OPREDUCE opreduce, TC initC
775 )
776 {
779 
// Kernel used for the intermediate rank-k updates.
780  semiringkernel.op1 = op1;
781  semiringkernel.op2 = op2;
782  semiringkernel.initV = initV;
783 
// Fused kernel used for the final update; it also carries the reduction
// operator and its initial value.
784  gkrmkernel.op1 = op1;
785  gkrmkernel.op2 = op2;
786  gkrmkernel.opkernel = opkernel;
787  gkrmkernel.initV = initV;
788  gkrmkernel.opreduce = opreduce;
789  gkrmkernel.initC = initC;
790 
// Forward to the generic driver with the two kernels.
791  gkmx
792  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
793  USE_STRASSEN,
796  TA, TB, TC, TV>
797  (
798  transA, transB,
799  m, n, k,
800  A, lda,
801  B, ldb,
802  C, 0, // TODO: is there a better way to do this?
803  batchId,
804  semiringkernel, gkrmkernel
805  );
806 }; // end gkrm
807 
808 
809 
810 
814 template<
815  typename OPKERNEL, typename OP1, typename OP2,
816  typename TA, typename TB, typename TC, typename TV = TC>
817 void gkmm_ref
818 (
819  hmlpOperation_t transA, hmlpOperation_t transB,
820  int m, int n, int k,
821  TA *A, int lda,
822  TB *B, int ldb,
823  TC *C, int ldc,
824  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
825 )
826 {
827  for ( int i = 0; i < m; i ++ )
828  {
829  for ( int j = 0; j < n; j ++ )
830  {
831  auto v = initV;
832  for ( int p = 0; p < k; p ++ )
833  {
834  TA a;
835  TB b;
836  if ( transA == HMLP_OP_N ) a = A[ p * lda + i ];
837  else a = A[ i * lda + p ];
838  if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ];
839  else b = B[ p * ldb + j ];
840  v = op1( v, op2( a, b ) );
841  }
842  C[ j * ldc + i ] = opkernel( v );
843  }
844  }
845 }; // end gkmm_ref
846 
847 
853 template<
854  typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
855  typename TA, typename TB, typename TC, typename TV = TC>
856 void gkrm_ref
857 (
858  hmlpOperation_t transA, hmlpOperation_t transB,
859  int m, int n, int k,
860  TA *A, int lda,
861  TB *B, int ldb,
862  TC *C, int ldc,
863  int batchId,
864  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV,
865  OPREDUCE opreduce, TC initC
866  )
867 {
868  for ( int i = 0; i < m; i ++ )
869  {
870  auto c = initC;
871  for ( int j = 0; j < n; j ++ )
872  {
873  auto v = initV;
874  for ( int p = 0; p < k; p ++ )
875  {
876  TA a;
877  TB b;
878  if ( transA == HMLP_OP_N ) a = A[ p * lda + i ];
879  else a = A[ i * lda + p ];
880  if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ];
881  else b = B[ p * ldb + j ];
882  v = op1( v, op2( a, b ) );
883  }
884  c = opreduce( c, opkernel( v ) );
885  }
886  C[ i ] = c;
887  }
888 }; // end gkrm_ref
889 
890 
891 }; // end namespace gkmx
892 }; // end namespace hmlp
893 
894 #endif // define GKMX_HPP
Definition: semiring_mrxnr.hpp:11
Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
This kernel takes opkernel, op1 and op2 to implement an MR-by-NR GKMM operation.
Definition: fused_mrxnr.hpp:12
Definition: fused_mrxnr.hpp:127
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88
Definition: gofmm.hpp:83
Definition: thread.hpp:166