HMLP: High-performance Machine Learning Primitives
strassen.hpp
1 
22 #ifndef STRASSEN_HPP
23 #define STRASSEN_HPP
24 
/**
 *  STRAPRIM( A0, A1, gamma, B0, B1, delta, C0, C1, alpha0, alpha1 )
 *
 *  Calls the dense straprim() overload on md x nd x kd quadrants. It is meant
 *  to be expanded inside a strassen_internal() body: it reads the enclosing
 *  template parameters (MC ... TV) and the locals thread, transA, transB,
 *  md, nd, kd, lda, ldb, ldc, stra_semiringkernel, stra_microkernel, nc,
 *  pack_nc, packA_buff and packB_buff. The expansion already ends with ';'.
 *
 *  The last template argument must be TV (the type of C0/C1), not TB.
 */
#define STRAPRIM( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \
  straprim \
  <MC, NC, KC, MR, NR, \
  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \
  USE_STRASSEN, \
  STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \
  TA, TB, TC, TV> \
  ( \
    thread, \
    transA, transB, \
    md, nd, kd, \
    A0, A1, lda, gamma, \
    B0, B1, ldb, delta, \
    C0, C1, ldc, alpha0, alpha1, \
    stra_semiringkernel, stra_microkernel, \
    nc, pack_nc, \
    packA_buff, \
    packB_buff \
  );

/**
 *  STRAPRIM_MAP( A0, A1, gamma, B0, B1, delta, C0, C1, alpha0, alpha1 )
 *
 *  Same as STRAPRIM but calls the straprim() overload that also takes the
 *  index maps; it additionally reads the locals amap and bmap of the
 *  enclosing strassen_internal(). The expansion already ends with ';'.
 *
 *  The last template argument must be TV (the type of C0/C1), not TB.
 */
#define STRAPRIM_MAP( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \
  straprim \
  <MC, NC, KC, MR, NR, \
  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \
  USE_STRASSEN, \
  STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \
  TA, TB, TC, TV> \
  ( \
    thread, \
    transA, transB, \
    md, nd, kd, \
    A0, A1, lda, gamma, amap, \
    B0, B1, ldb, delta, bmap, \
    C0, C1, ldc, alpha0, alpha1, \
    stra_semiringkernel, stra_microkernel, \
    nc, pack_nc, \
    packA_buff, \
    packB_buff \
  );

65 #include <hmlp.h>
66 #include <hmlp_internal.hpp>
67 #include <hmlp_base.hpp>
68 
69 namespace hmlp
70 {
71 namespace strassen
72 {
73 
74 //#define min( i, j ) ( (i)<(j) ? (i): (j) )
75 
79 template<
80  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
81  typename SEMIRINGKERNEL,
82  typename TA, typename TB, typename TC, typename TV>
84 (
85  Worker &thread,
86  int ic, int jc, int pc,
87  int m, int n, int k,
88  TA *packA,
89  TB *packB,
90  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
91  SEMIRINGKERNEL semiringkernel
92 )
93 {
94  thread_communicator &ic_comm = *thread.ic_comm;
95 
96  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
97  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
98  auto loop2nd = GetRange( 0, m, MR );
99  auto pack2nd = GetRange( 0, m, PACK_MR );
100 
101  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
102  j < loop3rd.end();
103  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
104  {
105  struct aux_s<TA, TB, TC, TV> aux;
106  aux.pc = pc;
107  aux.b_next = packB;
108  aux.do_packC = 0;
109  aux.jb = std::min( n - j, NR );
110 
111  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
112  i < loop2nd.end();
113  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
114  {
115  aux.ib = std::min( m - i, MR );
116  if ( aux.ib != MR )
117  {
118  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
119  }
120 
121  if ( aux.jb == NR && aux.ib == MR )
122  {
123 
124  if ( alpha1 == 0 || C1 == NULL ) {
125  TV *c_list[1], alpha_list[1];
126  c_list[0] = &C0[ j * ldc + i ];
127  alpha_list[0] = alpha0;
128 
129  semiringkernel
130  (
131  k,
132  &packA[ ip * k ],
133  &packB[ jp * k ],
134  1, c_list, ldc, alpha_list,
135  &aux
136  );
137 
138  } else {
139 
140  TV *c_list[2], alpha_list[2];
141  c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ];
142  alpha_list[0] = alpha0; alpha_list[1] = alpha1;
143  semiringkernel
144  (
145  k,
146  &packA[ ip * k ],
147  &packB[ jp * k ],
148  2, c_list, ldc, alpha_list,
149  &aux
150  );
151 
152  }
153 
154  //semiringkernel
155  //(
156  // k,
157  // &packA[ ip * k ],
158  // &packB[ jp * k ],
159  // &C0[ j * ldc + i ], &C1[ j * ldc + i ], ldc, alpha0, alpha1,
160  // &aux
161  //);
162 
163 
164  }
165  else // corner case
166  {
167 
168  //printf( "Enter corner case!\n" );
169  // TODO: this should be initC.
170  TV ctmp[ MR * NR ] = { (TV)0.0 };
171 
172  TV *c_list[1], alpha_list[1];
173  c_list[0] = ctmp;
174  alpha_list[0] = 1;
175 
176  semiringkernel
177  (
178  k,
179  &packA[ ip * k ],
180  &packB[ jp * k ],
181  //ctmp, MR,
182  1, c_list, MR, alpha_list,
183  &aux
184  );
185 
186 
189  //semiringkernel
190  //(
191  // k,
192  // &packA[ ip * k ],
193  // &packB[ jp * k ],
194  // //ctmp, MR,
195  // ctmp, NULL, MR, 1, 0,
196  // &aux
197  //);
198  //if ( pc )
199  {
200  for ( auto jj = 0; jj < aux.jb; jj ++ )
201  {
202  for ( auto ii = 0; ii < aux.ib; ii ++ )
203  {
204  C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ];
205 
206  if ( alpha1 != 0 && C1 != NULL ) {
207  C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ];
208  }
209  }
210  }
211  }
212  //else
213  //{
214  // for ( auto jj = 0; jj < aux.jb; jj ++ )
215  // {
216  // for ( auto ii = 0; ii < aux.ib; ii ++ )
217  // {
218  // C0[ ( j + jj ) * ldc + i + ii ] = alpha0 * ctmp[ jj * MR + ii ];
219 
220  // if ( alpha1 != 0 && C1 != NULL ) {
221  // C1[ ( j + jj ) * ldc + i + ii ] = alpha1 * ctmp[ jj * MR + ii ];
222  // }
223  // }
224  // }
225  //}
226  }
227  } // end 2nd loop
228  } // end 3rd loop
229 } // end rank_k_macro_kernel
230 
234 //template<int KC, int MR, int NR, int PACK_MR, int PACK_NR,
235 // typename MICROKERNEL,
236 // typename TA, typename TB, typename TC, typename TV>
237 //void fused_macro_kernel
238 //(
239 // Worker &thread,
240 // int ic, int jc, int pc,
241 // int m, int n, int k,
242 // TA *packA,
243 // TB *packB,
244 // TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
245 // MICROKERNEL microkernel
246 //)
247 //{
248 // thread_communicator &ic_comm = *thread.ic_comm;
249 //
250 // auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
251 // auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
252 // auto loop2nd = GetRange( 0, m, MR );
253 // auto pack2nd = GetRange( 0, m, PACK_MR );
254 //
255 // for ( int j = loop3rd.beg(), jp = pack3rd.beg();
256 // j < loop3rd.end();
257 // j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
258 // {
259 // struct aux_s<TA, TB, TC, TV> aux;
260 // aux.pc = pc;
261 // aux.b_next = packB;
262 // aux.do_packC = 0;
263 // aux.jb = std::min( n - j, NR );
264 //
265 // for ( int i = loop2nd.beg(), ip = pack2nd.beg();
266 // i < loop2nd.end();
267 // i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
268 // {
269 // aux.ib = std::min( m - i, MR );
270 // if ( aux.ib != MR )
271 // {
272 // aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
273 // }
274 //
275 // if ( aux.jb == NR && aux.ib == MR )
276 // {
277 //
278 // if ( alpha1 == 0 || C1 == NULL ) {
279 //
280 // double *c_list[1], alpha_list[1];
281 // c_list[0] = &C0[ j * ldc + i ];
282 // alpha_list[0] = alpha0;
283 //
284 // microkernel
285 // (
286 // k,
287 // &packA[ ip * k ],
288 // &packB[ jp * k ],
289 // 1, c_list, ldc, alpha_list,
290 // &aux
291 // );
292 // } else {
293 //
294 // double *c_list[2], alpha_list[2];
295 // c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ];
296 // alpha_list[0] = alpha0; alpha_list[1] = alpha1;
297 //
298 // microkernel
299 // (
300 // k,
301 // &packA[ ip * k ],
302 // &packB[ jp * k ],
303 // 2, c_list, ldc, alpha_list,
304 // &aux
305 // );
306 //
307 // }
308 //
309 //
310 // //microkernel
311 // //(
312 // // k,
313 // // &packA[ ip * k ],
314 // // &packB[ jp * k ],
315 // // &C0[ j * ldc + i ], &C1[ j * ldc + i ], ldc, alpha0, alpha1,
316 // // &aux
317 // //);
318 // }
319 // else // corner case
320 // {
321 // //printf( "Enter corner case!\n" );
322 // // TODO: this should be initC.
323 // TV ctmp[ MR * NR ] = { (TV)0.0 };
324 //
325 // double *c_list[1], alpha_list[1];
326 // c_list[0] = ctmp;
327 // alpha_list[0] = 1;
328 //
329 // microkernel
330 // (
331 // k,
332 // &packA[ ip * k ],
333 // &packB[ jp * k ],
334 // //ctmp, MR,
335 // 1, c_list, MR, alpha_list,
336 // &aux
337 // );
338 //
339 // ////rank_k_int_d8x4 rankk_microkernel;
340 // ////rankk_microkernel
341 // //microkernel
342 // //(
343 // // k,
344 // // &packA[ ip * k ],
345 // // &packB[ jp * k ],
346 // // //ctmp, MR,
347 // // ctmp, NULL, MR, 1, 0,
348 // // &aux
349 // //);
350 //
351 // //if ( pc )
352 // {
353 // for ( auto jj = 0; jj < aux.jb; jj ++ )
354 // {
355 // for ( auto ii = 0; ii < aux.ib; ii ++ )
356 // {
357 // C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ];
358 //
359 // if ( alpha1 != 0 && C1 != NULL ) {
360 // C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ];
361 // }
362 // }
363 // }
364 // }
365 // //else
366 // //{
367 // // for ( auto jj = 0; jj < aux.jb; jj ++ )
368 // // {
369 // // for ( auto ii = 0; ii < aux.ib; ii ++ )
370 // // {
371 // // C0[ ( j + jj ) * ldc + i + ii ] = alpha0 * ctmp[ jj * MR + ii ];
372 //
373 // // if ( alpha1 != 0 && C1 != NULL ) {
374 // // C1[ ( j + jj ) * ldc + i + ii ] = alpha1 * ctmp[ jj * MR + ii ];
375 // // }
376 // // }
377 // // }
378 // //}
379 // }
380 // } // end 2nd loop
381 // } // end 3rd loop
382 //} // end fused_macro_kernel
383 
384 
385 /*
386  *
387  */
388 template<
389  int MC, int NC, int KC, int MR, int NR,
390  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
391  bool USE_STRASSEN,
392  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
393  typename TA, typename TB, typename TC, typename TV>
394 void straprim
395 (
396  Worker &thread,
397  hmlpOperation_t transA, hmlpOperation_t transB,
398  int m, int n, int k,
399  TA *A0, TA *A1, int lda, TA gamma,
400  TB *B0, TB *B1, int ldb, TB delta,
401  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
402  STRA_SEMIRINGKERNEL stra_semiringkernel,
403  STRA_MICROKERNEL stra_microkernel,
404  int nc, int pack_nc,
405  TA *packA,
406  TB *packB
407 )
408 {
409  //printf( "m: %d, n: %d, k: %d\n", m, n, k );
410 
411  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
412  + ( thread.ic_id ) * PACK_MC * KC;
413  packB += ( thread.jc_id ) * pack_nc * KC;
414 
415  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
416  auto loop5th = GetRange( 0, k, KC );
417  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
418 
419  for ( int jc = loop6th.beg();
420  jc < loop6th.end();
421  jc += loop6th.inc() ) // beg 6th loop
422  {
423  auto &jc_comm = *thread.jc_comm;
424  auto jb = std::min( n - jc, nc );
425 
426  for ( int pc = loop5th.beg();
427  pc < loop5th.end();
428  pc += loop5th.inc() )
429  {
430  auto &pc_comm = *thread.pc_comm;
431  auto pb = std::min( k - pc, KC );
432  auto is_the_last_pc_iteration = ( pc + KC >= k );
433  auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
434  auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
435 
436  for ( int j = looppkB.beg(), jp = packpkB.beg();
437  j < looppkB.end();
438  j += looppkB.inc(), jp += packpkB.inc() )
439  {
440 
441  //printf( "before packB\n" );
442  if ( transB == HMLP_OP_N )
443  {
444 
445  if ( delta == 0 || B1 == NULL ) {
446  pack2D<true, PACK_NR> // packB
447  (
448  std::min( jb - j, NR ), pb,
449  &B0[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
450  );
451  } else {
452 
453  pack2D<true, PACK_NR> // packB
454  (
455  std::min( jb - j, NR ), pb,
456  &B0[ ( jc + j ) * ldb + pc ], &B1[ ( jc + j ) * ldb + pc ], ldb, delta, &packB[ jp * pb ]
457  );
458 
459  }
460 
461  }
462  else
463  {
464  if ( delta == 0 || B1 == NULL ) {
465  pack2D<false, PACK_NR> // packB (transB)
466  (
467  std::min( jb - j, NR ), pb,
468  &B0[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
469  );
470  } else {
471 
472  //printf( "before pack2D\n" );
473  //printf( "B1[%d]=%lf\n", pc * ldb + ( jc + j ), B1[ pc * ldb + ( jc + j ) ] );
474 
475  pack2D<false, PACK_NR> // packB (transB)
476  (
477  std::min( jb - j, NR ), pb,
478  &B0[ pc * ldb + ( jc + j ) ], &B1[ pc * ldb + ( jc + j ) ], ldb, delta, &packB[ jp * pb ]
479  );
480  //printf( "after pack2D\n" );
481 
482  }
483 
484  }
485  //printf( "After packB\n" );
486  }
487  pc_comm.Barrier();
488 
489  //printf( "packB:\n" );
490  //hmlp_printmatrix( 4, 1, packB, PACK_NR );
491 
492 
493 
494  for ( int ic = loop4th.beg();
495  ic < loop4th.end();
496  ic += loop4th.inc() ) // beg 4th loop
497  {
498  auto &ic_comm = *thread.ic_comm;
499  auto ib = std::min( m - ic, MC );
500  auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
501  auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
502 
503  for ( int i = looppkA.beg(), ip = packpkA.beg();
504  i < looppkA.end();
505  i += looppkA.inc(), ip += packpkA.inc() )
506  {
507 
508  //printf( "Before packA\n" );
509 
510  if ( transA == HMLP_OP_N )
511  {
512 
513  if ( gamma == 0 || A1 == NULL ) {
514  pack2D<false, PACK_MR> // packA
515  (
516  std::min( ib - i, MR ), pb,
517  &A0[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
518  );
519  } else {
520 
521  //printf( "flag1\n" );
522  pack2D<false, PACK_MR> // packA
523  (
524  std::min( ib - i, MR ), pb,
525  &A0[ pc * lda + ( ic + i ) ], &A1[ pc * lda + ( ic + i ) ], lda, gamma, &packA[ ip * pb ]
526  );
527  //printf( "flag2\n" );
528  }
529 
530  }
531  else
532  {
533 
534  if ( gamma == 0 || A1 == NULL ) {
535  pack2D<true, PACK_MR> // packA (transA)
536  (
537  std::min( ib - i, MR ), pb,
538  &A0[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
539  );
540  } else {
541  pack2D<true, PACK_MR> // packA (transA)
542  (
543  std::min( ib - i, MR ), pb,
544  &A0[ ( ic + i ) * lda + pc ], &A1[ ( ic + i ) * lda + pc ], lda, gamma, &packA[ ip * pb ]
545  );
546  }
547 
548  }
549 
550  //printf( "After packA\n" );
551  }
552  ic_comm.Barrier();
553 
554 // if ( is_the_last_pc_iteration ) // fused_macro_kernel
555 // {
556 // if ( alpha1 == 0 || C1 == NULL ) {
557 //
558 // //hmlp::gkmx::fused_macro_kernel
559 // //<KC, MR, NR, PACK_MR, PACK_NR, RANK_MICROKERNEL, TA, TB, TC, TV>
560 // //(
561 // // thread,
562 // // ic, jc, pc,
563 // // ib, jb, pb,
564 // // packA,
565 // // packB,
566 // // C0 + jc * ldc + ic, ldc,
567 // // rank_microkernel
568 // //);
569 //
570 // //printf( "before fused macro kernel\n" );
571 // fused_macro_kernel
572 // <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
573 // (
574 // thread,
575 // ic, jc, pc,
576 // ib, jb, pb,
577 // packA,
578 // packB,
579 // C0 + jc * ldc + ic,
580 // NULL, ldc, alpha0, 0,
581 // stra_microkernel
582 // );
583 // //printf( "after fused macro kernel\n" );
584 //
585 // } else {
586 // fused_macro_kernel
587 // <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
588 // (
589 // thread,
590 // ic, jc, pc,
591 // ib, jb, pb,
592 // packA,
593 // packB,
594 // C0 + jc * ldc + ic,
595 // C1 + jc * ldc + ic, ldc, alpha0, alpha1,
596 // stra_microkernel
597 // );
598 // }
599 //
600 // }
601 // else // semiring rank-k update
602 // {
603 
604  if ( alpha1 == 0 || C1 == NULL )
605  {
606  //hmlp::gkmx::rank_k_macro_kernel
607  //<KC, MR, NR, PACK_MR, PACK_NR, RANK_SEMIRINGKERNEL, TA, TB, TC, TV>
608  //(
609  // thread,
610  // ic, jc, pc,
611  // ib, jb, pb,
612  // packA,
613  // packB,
614  // C0 + jc * ldc + ic, ldc,
615  // rank_semiringkernel
616  //);
617 
619  //strassen_macro_kernel
620  <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
621  (
622  thread,
623  ic, jc, pc,
624  ib, jb, pb,
625  packA,
626  packB,
627  C0 + jc * ldc + ic,
628  NULL, ldc, alpha0, 0,
629  stra_semiringkernel
630  );
631 
632  }
633  else
634  {
635 
637  //strassen_macro_kernel
638  <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
639  (
640  thread,
641  ic, jc, pc,
642  ib, jb, pb,
643  packA,
644  packB,
645  C0 + jc * ldc + ic,
646  C1 + jc * ldc + ic, ldc, alpha0, alpha1,
647  stra_semiringkernel
648  );
649 
650  }
651 
652 // }
653  ic_comm.Barrier(); // sync all jr_id!!
654  } // end 4th loop
655  pc_comm.Barrier();
656  } // end 5th loop
657  } // end 6th loop
658 } // end strassen_internal
659 
660 
661 
662 
663 /*
664  *
665  */
666 template<
667  int MC, int NC, int KC, int MR, int NR,
668  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
669  bool USE_STRASSEN,
670  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
671  typename TA, typename TB, typename TC, typename TV>
672 void straprim
673 (
674  Worker &thread,
675  hmlpOperation_t transA, hmlpOperation_t transB,
676  int m, int n, int k,
677  TA *A0, TA *A1, int lda, TA gamma, int *amap,
678  TB *B0, TB *B1, int ldb, TB delta, int *bmap,
679  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
680  STRA_SEMIRINGKERNEL stra_semiringkernel,
681  STRA_MICROKERNEL stra_microkernel,
682  int nc, int pack_nc,
683  TA *packA,
684  TB *packB
685 )
686 {
687  //printf( "m: %d, n: %d, k: %d\n", m, n, k );
688 
689  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
690  + ( thread.ic_id ) * PACK_MC * KC;
691  packB += ( thread.jc_id ) * pack_nc * KC;
692 
693  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
694  auto loop5th = GetRange( 0, k, KC );
695  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
696 
697  for ( int jc = loop6th.beg();
698  jc < loop6th.end();
699  jc += loop6th.inc() ) // beg 6th loop
700  {
701  auto &jc_comm = *thread.jc_comm;
702  auto jb = std::min( n - jc, nc );
703 
704  for ( int pc = loop5th.beg();
705  pc < loop5th.end();
706  pc += loop5th.inc() )
707  {
708  auto &pc_comm = *thread.pc_comm;
709  auto pb = std::min( k - pc, KC );
710  auto is_the_last_pc_iteration = ( pc + KC >= k );
711  auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
712  auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
713 
714  for ( int j = looppkB.beg(), jp = packpkB.beg();
715  j < looppkB.end();
716  j += looppkB.inc(), jp += packpkB.inc() )
717  {
718 
719  //printf( "before packB\n" );
720  if ( transB == HMLP_OP_N )
721  {
722 
723  if ( delta == 0 || B1 == NULL ) {
724  // ldb == k
725  pack2D<true, PACK_NR> // packB
726  (
727  std::min( jb - j, NR ), pb,
728  &B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ]
729  );
730  } else {
731  pack2D<true, PACK_NR> // packB
732  (
733  std::min( jb - j, NR ), pb,
734  &B0[ pc ], &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ]
735  );
736  }
737 
738  }
739  else
740  {
741  if ( delta == 0 || B1 == NULL ) {
742  pack2D<false, PACK_NR> // packB (transB)
743  (
744  std::min( jb - j, NR ), pb,
745  &B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ]
746  );
747  } else {
748  pack2D<false, PACK_NR> // packB (transB)
749  (
750  std::min( jb - j, NR ), pb,
751  &B0[ pc ], &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ]
752  );
753 
754 
755  }
756 
757  }
758 
759  }
760  pc_comm.Barrier();
761 
762  for ( int ic = loop4th.beg();
763  ic < loop4th.end();
764  ic += loop4th.inc() ) // beg 4th loop
765  {
766  auto &ic_comm = *thread.ic_comm;
767  auto ib = std::min( m - ic, MC );
768  auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
769  auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
770 
771  for ( int i = looppkA.beg(), ip = packpkA.beg();
772  i < looppkA.end();
773  i += looppkA.inc(), ip += packpkA.inc() )
774  {
775 
776  //assert( lda == k );
777  //For transpose cases, lda should be equal to k.
778 
779  if ( transA == HMLP_OP_N )
780  {
781 
782  if ( gamma == 0 || A1 == NULL ) {
783  pack2D<false, PACK_MR> // packA
784  (
785  std::min( ib - i, MR ), pb,
786  &A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ]
787  );
788  } else {
789  pack2D<false, PACK_MR> // packA
790  (
791  std::min( ib - i, MR ), pb,
792  &A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ]
793  );
794  }
795 
796  }
797  else
798  {
799 
800  if ( gamma == 0 || A1 == NULL ) {
801  pack2D<true, PACK_MR> // packA (transA)
802  (
803  std::min( ib - i, MR ), pb,
804  &A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ]
805  );
806  } else {
807  pack2D<true, PACK_MR> // packA (transA)
808  (
809  std::min( ib - i, MR ), pb,
810  &A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ]
811  );
812 
813  }
814 
815  }
816 
817  }
818  ic_comm.Barrier();
819 
820 // if ( is_the_last_pc_iteration ) // fused_macro_kernel
821 // {
822 // if ( alpha1 == 0 || C1 == NULL ) {
823 //
824 // //hmlp::gkmx::fused_macro_kernel
825 // //<KC, MR, NR, PACK_MR, PACK_NR, RANK_MICROKERNEL, TA, TB, TC, TV>
826 // //(
827 // // thread,
828 // // ic, jc, pc,
829 // // ib, jb, pb,
830 // // packA,
831 // // packB,
832 // // C0 + jc * ldc + ic, ldc,
833 // // rank_microkernel
834 // //);
835 //
836 // //printf( "before fused macro kernel\n" );
837 // fused_macro_kernel
838 // <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
839 // (
840 // thread,
841 // ic, jc, pc,
842 // ib, jb, pb,
843 // packA,
844 // packB,
845 // C0 + jc * ldc + ic,
846 // NULL, ldc, alpha0, 0,
847 // stra_microkernel
848 // );
849 // //printf( "after fused macro kernel\n" );
850 //
851 // } else {
852 // fused_macro_kernel
853 // <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
854 // (
855 // thread,
856 // ic, jc, pc,
857 // ib, jb, pb,
858 // packA,
859 // packB,
860 // C0 + jc * ldc + ic,
861 // C1 + jc * ldc + ic, ldc, alpha0, alpha1,
862 // stra_microkernel
863 // );
864 // }
865 //
866 // }
867 // else // semiring rank-k update
868 // {
869 
870  if ( alpha1 == 0 || C1 == NULL )
871  {
872  //hmlp::gkmx::rank_k_macro_kernel
873  //<KC, MR, NR, PACK_MR, PACK_NR, RANK_SEMIRINGKERNEL, TA, TB, TC, TV>
874  //(
875  // thread,
876  // ic, jc, pc,
877  // ib, jb, pb,
878  // packA,
879  // packB,
880  // C0 + jc * ldc + ic, ldc,
881  // rank_semiringkernel
882  //);
883 
885  //strassen_macro_kernel
886  <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
887  (
888  thread,
889  ic, jc, pc,
890  ib, jb, pb,
891  packA,
892  packB,
893  C0 + jc * ldc + ic,
894  NULL, ldc, alpha0, 0,
895  stra_semiringkernel
896  );
897 
898  }
899  else
900  {
901 
903  //strassen_macro_kernel
904  <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
905  (
906  thread,
907  ic, jc, pc,
908  ib, jb, pb,
909  packA,
910  packB,
911  C0 + jc * ldc + ic,
912  C1 + jc * ldc + ic, ldc, alpha0, alpha1,
913  stra_semiringkernel
914  );
915 
916  }
917 
918 // }
919  ic_comm.Barrier(); // sync all jr_id!!
920  } // end 4th loop
921  pc_comm.Barrier();
922  } // end 5th loop
923  } // end 6th loop
924 } // end strassen_internal
925 
926 template<typename TA, typename TB, typename TV>
927 void hmlp_dynamic_peeling
928 (
929  hmlpOperation_t transA, hmlpOperation_t transB,
930  int m, int n, int k,
931  TA *A, int lda,
932  TB *B, int ldb,
933  TV *C, int ldc,
934  int dim1, int dim2, int dim3
935 )
936 {
937  //printf( "Enter dynamic peeling\n" );
938  int mr = m % dim1;
939  int kr = k % dim2;
940  int nr = n % dim3;
941  int ms = m - mr;
942  int ns = n - nr;
943  int ks = k - kr;
944  TA *A_extra;
945  TB *B_extra;
946  TV *C_extra;
947 
948  char transA_val, transB_val;
949  char *char_transA = &transA_val, *char_transB = &transB_val;
950 
951 
952  //printf( "flag d1\n" );
953 
954  // Adjust part handled by fast matrix multiplication.
955  // Add far column of A outer product bottom row B
956  if ( kr > 0 ) {
957  // In Strassen, this looks like C([1, 2], [1, 2]) += A([1, 2], 3) * B(3, [1, 2])
958 
959  //printf( "flag d2\n" );
960 
961  if ( transA == HMLP_OP_N ) {
962  A_extra = &A[ 0 + ks * lda ];//ms * kr
963  *char_transA = 'N';
964  } else {
965  A_extra = &A[ 0 * lda + ks ];//ms * kr
966  *char_transA = 'T';
967  }
968 
969  //printf( "flag d3\n" );
970  if ( transB == HMLP_OP_N ) {
971  B_extra = &B[ ks + 0 * ldb ];//kr * ns
972  *char_transB = 'N';
973  } else {
974  B_extra = &B[ ks * ldb + 0 ];//kr * ns
975  *char_transB = 'T';
976  }
977 
978  //printf( "flag d4\n" );
979  C_extra = &C[ 0 + 0 * ldc ];//ms * ns
980  if ( ms > 0 && ns > 0 )
981  {
982  //bl_dgemm( ms, ns, kr, A_extra, lda, B_extra, ldb, C_extra, ldc );
983  xgemm( char_transA, char_transB, ms, ns, kr, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc );
984  }
985  }
986 
987  //printf( "flag d5\n" );
988 
989  // Adjust for far right columns of C
990  if ( nr > 0 ) {
991  // In Strassen, this looks like C(:, 3) = A * B(:, 3)
992 
993  if ( transA == HMLP_OP_N ) {
994  *char_transA = 'N';
995  } else {
996  *char_transA = 'T';
997  }
998  //printf( "flag d6\n" );
999 
1000  if ( transB == HMLP_OP_N ) {
1001  B_extra = &B[ 0 + ns * ldb ];//k * nr
1002  *char_transB = 'N';
1003  } else {
1004  B_extra = &B[ 0 * ldb + ns ];//k * nr
1005  *char_transB = 'T';
1006  }
1007 
1008 
1009  //printf( "flag d7\n" );
1010 
1011  C_extra = &C[ 0 + ns * ldc ];//m * nr
1012  //bl_dgemm( m, nr, k, A, lda, B_extra, ldb, C_extra, ldc );
1013  xgemm( char_transA, char_transB, m, nr, k, 1.0, A, lda, B_extra, ldb, 1.0, C_extra, ldc );
1014 
1015  }
1016 
1017  //printf( "flag d8\n" );
1018 
1019  // Adjust for bottom rows of C
1020  if ( mr > 0 ) {
1021  // In Strassen, this looks like C(3, [1, 2]) = A(3, :) * B(:, [1, 2])
1022 
1023 
1024  //printf( "flag d8.1\n" );
1025  if ( transA == HMLP_OP_N ) {
1026 
1027  //printf( "flag d8.15\n" );
1028  A_extra = &A[ ms + 0 * lda ];// mr * k
1029  //printf( "flag d8.16\n" );
1030  *char_transA = 'N';
1031 
1032  //printf( "flag d8.2\n" );
1033  } else {
1034  A_extra = &A[ ms * lda + 0 ];// mr * k
1035  *char_transA = 'T';
1036  //printf( "flag d8.3\n" );
1037  }
1038 
1039  //printf( "flag d8.4\n" );
1040 
1041  if ( transB == HMLP_OP_N ) {
1042  B_extra = &B[ 0 + 0 * ldb ];// k * ns
1043  *char_transB = 'N';
1044 
1045  //printf( "flag d8.5\n" );
1046 
1047  } else {
1048  B_extra = &B[ 0 * ldb + 0 ];// k * ns
1049  *char_transB = 'T';
1050 
1051  //printf( "flag d8.6\n" );
1052  }
1053 
1054  //printf( "flag d9\n" );
1055 
1056  TV *C_extra = &C[ ms + 0 * ldc ];// mr * ns
1057  if ( ns > 0 )
1058  {
1059  //bl_dgemm( mr, ns, k, A_extra, lda, B_extra, ldb, C_extra, ldc );
1060  xgemm( char_transA, char_transB, mr, ns, k, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc );
1061  }
1062  }
1063  //printf( "Leave dynamic peeling\n" );
1064 }
1065 
1066 template<
1067  int MC, int NC, int KC, int MR, int NR,
1068  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
1069  bool USE_STRASSEN,
1070  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
1071  typename TA, typename TB, typename TC, typename TV>
1072 void strassen_internal
1073 (
1074  Worker &thread,
1075  hmlpOperation_t transA, hmlpOperation_t transB,
1076  int m, int n, int k,
1077  TA *A, int lda, int *amap,
1078  TB *B, int ldb, int *bmap,
1079  TV *C, int ldc,
1080  STRA_SEMIRINGKERNEL stra_semiringkernel,
1081  STRA_MICROKERNEL stra_microkernel,
1082  int nc, int pack_nc,
1083  TA *packA_buff,
1084  TB *packB_buff
1085 )
1086 {
1087 
1088  int ms, ks, ns;
1089  int md, kd, nd;
1090  int mr, kr, nr;
1091 
1092  mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 );
1093  md = m - mr, kd = k - kr, nd = n - nr;
1094 
1095  // Partition code.
1096  ms=md, ks=kd, ns=nd;
1097  TA *A00, *A01, *A10, *A11;
1098  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 0, &A00 );
1099  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 1, &A01 );
1100  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 0, &A10 );
1101  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 1, &A11 );
1102 
1103  TB *B00, *B01, *B10, *B11;
1104  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 0, &B00 );
1105  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 1, &B01 );
1106  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 0, &B10 );
1107  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 1, &B11 );
1108 
1109  TV *C00, *C01, *C10, *C11;
1110  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 0, &C00 );
1111  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 1, &C01 );
1112  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 0, &C10 );
1113  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 1, &C11 );
1114 
1115  md = md / 2, kd = kd / 2, nd = nd / 2;
1116 
1117  // M1: C00 = 1*C00+1*(A00+A11)(B00+B11); C11 = 1*C11+1*(A00+A11)(B00+B11)
1118  STRAPRIM_MAP( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 );
1119  // M2: C10 = 1*C10+1*(A10+A11)B00; C11 = 1*C11-1*(A10+A11)B00
1120  STRAPRIM_MAP( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 )
1121  // M3: C01 = 1*C01+1*A00(B01-B11); C11 = 1*C11+1*A00(B01-B11)
1122  STRAPRIM_MAP( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 )
1123  // M4: C00 = 1*C00+1*A11(B10-B00); C10 = 1*C10+1*A11(B10-B00)
1124  STRAPRIM_MAP( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 )
1125  // M5: C00 = 1*C00-1*(A00+A01)B11; C01 = 1*C01+1*(A00+A01)B11
1126  STRAPRIM_MAP( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 )
1127  // M6: C11 = 1*C11+(A10-A00)(B00+B01)
1128  STRAPRIM_MAP( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 )
1129  // M7: C00 = 1*C00+(A01-A11)(B10+B11)
1130  STRAPRIM_MAP( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 )
1131 
1132  if ( omp_get_thread_num() == 0 ) { //Chief thread
1133  hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 );
1134  }
1135 
1136 }
1137 
1138 template<
1139  int MC, int NC, int KC, int MR, int NR,
1140  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
1141  bool USE_STRASSEN,
1142  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
1143  typename TA, typename TB, typename TC, typename TV>
1144 void strassen_internal
1145 (
1146  Worker &thread,
1147  hmlpOperation_t transA, hmlpOperation_t transB,
1148  int m, int n, int k,
1149  TA *A, int lda,
1150  TB *B, int ldb,
1151  TV *C, int ldc,
1152  STRA_SEMIRINGKERNEL stra_semiringkernel,
1153  STRA_MICROKERNEL stra_microkernel,
1154  int nc, int pack_nc,
1155  TA *packA_buff,
1156  TB *packB_buff
1157 )
1158 {
1159 
1160  int ms, ks, ns;
1161  int md, kd, nd;
1162  int mr, kr, nr;
1163 
1164  mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 );
1165  md = m - mr, kd = k - kr, nd = n - nr;
1166 
1167  // Partition code.
1168  ms=md, ks=kd, ns=nd;
1169  TA *A00, *A01, *A10, *A11;
1170  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 0, &A00 );
1171  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 1, &A01 );
1172  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 0, &A10 );
1173  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 1, &A11 );
1174 
1175  TB *B00, *B01, *B10, *B11;
1176  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 0, &B00 );
1177  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 1, &B01 );
1178  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 0, &B10 );
1179  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 1, &B11 );
1180 
1181  TV *C00, *C01, *C10, *C11;
1182  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 0, &C00 );
1183  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 1, &C01 );
1184  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 0, &C10 );
1185  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 1, &C11 );
1186 
1187  md = md / 2, kd = kd / 2, nd = nd / 2;
1188 
1189  // M1: C00 = 1*C00+1*(A00+A11)(B00+B11); C11 = 1*C11+1*(A00+A11)(B00+B11)
1190  STRAPRIM( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 );
1191 
1192  //printf( "A00:\n" );
1193  //hmlp_printmatrix( md, kd, A00, m );
1194  //printf( "A11:\n" );
1195  //hmlp_printmatrix( md, kd, A11, m );
1196  //printf( "B00:\n" );
1197  //hmlp_printmatrix( kd, nd, B00, k );
1198  //printf( "B11:\n" );
1199  //hmlp_printmatrix( kd, nd, B11, k );
1200  //printf( "C00:\n" );
1201  //hmlp_printmatrix( md, nd, C00, m );
1202  //printf( "C01:\n" );
1203  //hmlp_printmatrix( md, nd, C11, m );
1204 
1205  // M2: C10 = 1*C10+1*(A10+A11)B00; C11 = 1*C11-1*(A10+A11)B00
1206  STRAPRIM( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 )
1207 
1208  // M3: C01 = 1*C01+1*A00(B01-B11); C11 = 1*C11+1*A00(B01-B11)
1209  STRAPRIM( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 )
1210  // M4: C00 = 1*C00+1*A11(B10-B00); C10 = 1*C10+1*A11(B10-B00)
1211  STRAPRIM( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 )
1212  // M5: C00 = 1*C00-1*(A00+A01)B11; C01 = 1*C01+1*(A00+A01)B11
1213  STRAPRIM( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 )
1214  // M6: C11 = 1*C11+(A10-A00)(B00+B01)
1215  STRAPRIM( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 )
1216  // M7: C00 = 1*C00+(A01-A11)(B10+B11)
1217  STRAPRIM( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 )
1218 
1219  //printf( "C00:" );
1220  //hmlp_printmatrix( md, nd, C00, m );
1221 
1222  //printf( "before dynamic peeling\n" );
1223 
1224  if ( omp_get_thread_num() == 0 ) { //Chief thread
1225  hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 );
1226  }
1227 
1228 }
1229 
1230 
1235 template<
1236  int MC, int NC, int KC, int MR, int NR,
1237  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
1238  bool USE_STRASSEN,
1239  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
1240  typename TA, typename TB, typename TC, typename TV>
1241 void strassen
1242 (
1243  hmlpOperation_t transA, hmlpOperation_t transB,
1244  int m, int n, int k,
1245  TA *A, int lda,
1246  TB *B, int ldb,
1247  TV *C, int ldc,
1248  STRA_SEMIRINGKERNEL stra_semiringkernel,
1249  STRA_MICROKERNEL stra_microkernel
1250 )
1251 {
1252  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
1253  int nc = NC, pack_nc = PACK_NC;
1254  char *str;
1255 
1256  TA *packA_buff = NULL;
1257  TB *packB_buff = NULL;
1258 
1259  // Early return if possible
1260  if ( m == 0 || n == 0 || k == 0 ) return;
1261 
1262  // Check the environment variable.
1263  jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
1264  ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
1265  jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
1266 
1267 
1268  if ( jc_nt > 1 )
1269  {
1270  nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
1271  pack_nc = ( nc / NR ) * PACK_NR;
1272  }
1273 
1274  // allocate packing memory
1275  packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) );
1276  packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt, sizeof(TB) );
1277 
1278  // allocate tree communicator
1279  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
1280 
1281  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
1282  {
1283  Worker thread( &my_comm );
1284 
1285  strassen_internal
1286  <MC, NC, KC, MR, NR,
1287  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
1288  USE_STRASSEN,
1289  STRA_SEMIRINGKERNEL, STRA_MICROKERNEL,
1290  TA, TB, TC, TB>
1291  (
1292  thread,
1293  transA, transB,
1294  m, n, k,
1295  A, lda,
1296  B, ldb,
1297  C, ldc,
1298  stra_semiringkernel, stra_microkernel,
1299  nc, pack_nc,
1300  packA_buff,
1301  packB_buff
1302  );
1303 
1304  }
1305  // end omp
1306 } // end strassen
1307 
1308 
1309 }; // end namespace strassen
1310 }; // end namespace hmlp
1311 
1312 #endif // define STRASSEN_HPP
1313 
1314 
Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void hmlp_acquire_mpart(hmlpOperation_t transX, int m, int n, T *src_buff, int lda, int x, int y, int i, int j, T **dst_buff)
Split into m x n, get the subblock starting from ith row and jth column. (for STRASSEN) ...
Definition: util.hpp:143
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
Definition: gofmm.hpp:83
Definition: thread.hpp:166