31 #include <hmlp_internal.hpp> 32 #include <hmlp_base.hpp> 35 #include <primitives/strassen.hpp> 38 #include <semiring_mrxnr.hpp> 39 #include <fused_mrxnr.hpp> 55 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
56 typename SEMIRINGKERNEL,
57 typename TA,
typename TB,
typename TC,
typename TV>
61 int ic,
int jc,
int pc,
66 SEMIRINGKERNEL semiringkernel
71 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
72 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
73 auto loop2nd = GetRange( 0, m, MR );
74 auto pack2nd = GetRange( 0, m, PACK_MR );
76 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
78 j += loop3rd.inc(), jp += pack3rd.inc() )
80 struct aux_s<TA, TB, TC, TV> aux;
84 aux.jb = std::min( n - j, NR );
86 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
88 i += loop2nd.inc(), ip += pack2nd.inc() )
90 aux.ib = std::min( m - i, MR );
93 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
96 if ( aux.jb == NR && aux.ib == MR )
103 &V[ j * ldv + i ], 1, ldv,
113 for (
auto jj = 0; jj < aux.jb; jj ++ )
114 for (
auto ii = 0; ii < aux.ib; ii ++ )
115 vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
127 for (
auto jj = 0; jj < aux.jb; jj ++ )
128 for (
auto ii = 0; ii < aux.ib; ii ++ )
129 V[ ( j + jj ) * ldv + i + ii ] = vtmp[ jj * MR + ii ];
143 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
145 typename FUSEDKERNEL,
146 typename TA,
typename TB,
typename TC,
typename TV>
147 void fused_macro_kernel
150 int ic,
int jc,
int pc,
157 FUSEDKERNEL fusedkernel
162 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
163 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
164 auto loop2nd = GetRange( 0, m, MR );
165 auto pack2nd = GetRange( 0, m, PACK_MR );
167 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
169 j += loop3rd.inc(), jp += pack3rd.inc() )
171 struct aux_s<TA, TB, TC, TV> aux;
176 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
178 i += loop2nd.inc(), ip += pack2nd.inc() )
186 aux.ib = std::min( m - i, MR );
187 aux.jb = std::min( n - j, NR );
189 aux.V = V + j * ldv + i;
194 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
197 if ( aux.jb == NR && aux.ib == MR )
204 &C[ j * ldc + i ], 1, ldc,
218 for (
auto jj = 0; jj < aux.jb; jj ++ )
219 for (
auto ii = 0; ii < aux.ib; ii ++ )
220 ctmp[ jj * MR + ii ] = C[ ( j + jj ) * ldc + i + ii ];
224 for (
auto jj = 0; jj < aux.jb; jj ++ )
225 for (
auto ii = 0; ii < aux.ib; ii ++ )
226 vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
241 for (
auto jj = 0; jj < aux.jb; jj ++ )
242 for (
auto ii = 0; ii < aux.ib; ii ++ )
243 C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
274 typename SEMIRINGKERNEL,
typename MICROKERNEL,
275 typename TA,
typename TB,
typename TC,
typename TV>
279 hmlpOperation_t transA, hmlpOperation_t transB,
280 int m,
int n,
int k,
int k_stra,
286 SEMIRINGKERNEL semiringkernel,
287 MICROKERNEL microkernel,
293 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
294 + ( thread.ic_id ) * PACK_MC * KC;
295 packB += ( thread.jc_id ) * pack_nc * KC;
297 auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
298 auto loop5th = GetRange( k_stra, k, KC );
299 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
301 for (
int jc = loop6th.beg();
303 jc += loop6th.inc() )
305 auto &jc_comm = *thread.jc_comm;
306 auto jb = std::min( n - jc, nc );
308 for (
int pc = loop5th.beg();
310 pc += loop5th.inc() )
312 auto &pc_comm = *thread.pc_comm;
313 auto pb = std::min( k - pc, KC );
314 auto is_the_last_pc_iteration = ( pc + KC >= k );
315 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
316 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
318 for (
int j = looppkB.beg(), jp = packpkB.beg();
320 j += looppkB.inc(), jp += packpkB.inc() )
322 if ( transB == HMLP_OP_N )
324 pack2D<true, PACK_NR>
326 std::min( jb - j, NR ), pb,
327 &B[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
332 pack2D<false, PACK_NR>
334 std::min( jb - j, NR ), pb,
335 &B[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
341 for (
int ic = loop4th.beg();
343 ic += loop4th.inc() )
345 auto &ic_comm = *thread.ic_comm;
346 auto ib = std::min( m - ic, MC );
347 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
348 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
350 for (
int i = looppkA.beg(), ip = packpkA.beg();
352 i += looppkA.inc(), ip += packpkA.inc() )
354 if ( transA == HMLP_OP_N )
356 pack2D<false, PACK_MR>
358 std::min( ib - i, MR ), pb,
359 &A[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
364 pack2D<true, PACK_MR>
366 std::min( ib - i, MR ), pb,
367 &A[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
373 if ( is_the_last_pc_iteration )
376 <KC, MR, NR, PACK_MR, PACK_NR, REUSE_C, MICROKERNEL, TA, TB, TC, TV>
383 C + jc * ldc + ic, ldc,
384 V + jc * ldv + ic, ldv,
392 <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
400 V + jc * ldv + ic, ldv,
432 bool USE_STRASSEN =
false,
434 typename SEMIRINGKERNEL,
typename MICROKERNEL,
435 typename TA,
typename TB,
typename TC,
typename TV = TC>
438 hmlpOperation_t transA, hmlpOperation_t transB,
444 SEMIRINGKERNEL semiringkernel,
445 MICROKERNEL microkernel
448 int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
451 int nc = NC, pack_nc = PACK_NC;
454 TA *packA_buff = NULL;
455 TB *packB_buff = NULL;
459 if ( m == 0 || n == 0 || k == 0 )
return;
462 if (
typeid(TC) !=
typeid(TV) && k > KC )
464 printf(
"gkmx: currently k(%d) must be smaller than %d when TC != TV\n", k, KC );
468 if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
471 jc_nt = hmlp_read_nway_from_env(
"KS_JC_NT" );
472 ic_nt = hmlp_read_nway_from_env(
"KS_IC_NT" );
473 jr_nt = hmlp_read_nway_from_env(
"KS_JR_NT" );
478 nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
479 pack_nc = ( nc / NR ) * PACK_NR;
483 packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC * ( PACK_MC + 1 ) * jc_nt * ic_nt );
484 packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC * ( pack_nc + 1 ) * jc_nt );
488 if ( k > KC && !std::is_same<TC, TV>::value && !REUSE_C )
490 V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
495 V =
reinterpret_cast<TV*
>( C );
505 assert(
typeid(TA) ==
typeid(TB) );
506 assert(
typeid(TC) ==
typeid(TV) );
509 if ( k_stra == k ) k_stra -= KC;
513 #pragma omp parallel for 514 for (
int i = 0; i < n * ldv; i ++ ) V[ i ] = 0.0;
519 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 521 Worker thread( &my_comm );
525 strassen::strassen_internal
527 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
529 SEMIRINGKERNEL, SEMIRINGKERNEL,
538 semiringkernel, semiringkernel,
547 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
548 USE_STRASSEN, REUSE_C,
549 SEMIRINGKERNEL, MICROKERNEL,
560 semiringkernel, microkernel,
590 bool USE_STRASSEN =
false,
591 bool REUSE_C =
false,
592 typename OPKERNEL,
typename OP1,
typename OP2,
593 typename TA,
typename TB,
typename TC,
typename TV>
596 hmlpOperation_t transA, hmlpOperation_t transB,
602 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
608 semiringkernel.op1 = op1;
609 semiringkernel.op2 = op2;
610 semiringkernel.initV = initV;
612 gkmmkernel.op1 = op1;
613 gkmmkernel.op2 = op2;
614 gkmmkernel.opkernel = opkernel;
615 gkmmkernel.initV = initV;
618 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
619 USE_STRASSEN, REUSE_C,
630 semiringkernel, gkmmkernel
643 int MC,
int NC,
int KC,
int MR,
int NR,
644 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
645 bool USE_STRASSEN,
bool REUSE_C,
646 typename OPKERNEL,
typename OP1,
typename OP2,
647 typename TA,
typename TB,
typename TC,
typename TV>
650 hmlpOperation_t transA, hmlpOperation_t transB,
652 TA *Aarray[],
int lda,
653 TB *Barray[],
int ldb,
654 TC *Carray[],
int ldc,
656 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
659 #pragma omp parallel for 660 for (
auto b = 0; b < batchSize; b ++ )
663 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
674 opkernel, op1, op2, initV
690 int KC,
int MR,
int NR,
691 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
692 bool USE_STRASSEN,
bool REUSE_C,
693 typename OPKERNEL,
typename OP1,
typename OP2,
694 typename TA,
typename TB,
typename TC,
typename TV>
697 hmlpOperation_t transA, hmlpOperation_t transB,
699 TA *Aarray,
int lda,
int loa,
700 TB *Barray,
int ldb,
int lob,
701 TC *Carray,
int ldc,
int loc,
703 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
706 #pragma omp parallel for 707 for (
auto b = 0; b < batchSize; b ++ )
710 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
711 USE_STRASSEN, REUSE_C,
717 Aarray + b * loa, lda,
718 Barray + b * lob, ldb,
719 Carray + b * loc, ldc,
721 opkernel, op1, op2, initV
762 bool USE_STRASSEN =
false,
763 typename OPKERNEL,
typename OP1,
typename OP2,
typename OPREDUCE,
764 typename TA,
typename TB,
typename TC,
typename TV = TC>
767 hmlpOperation_t transA, hmlpOperation_t transB,
773 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV,
774 OPREDUCE opreduce, TC initC
780 semiringkernel.op1 = op1;
781 semiringkernel.op2 = op2;
782 semiringkernel.initV = initV;
784 gkrmkernel.op1 = op1;
785 gkrmkernel.op2 = op2;
786 gkrmkernel.opkernel = opkernel;
787 gkrmkernel.initV = initV;
788 gkrmkernel.opreduce = opreduce;
789 gkrmkernel.initC = initC;
792 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
804 semiringkernel, gkrmkernel
815 typename OPKERNEL,
typename OP1,
typename OP2,
816 typename TA,
typename TB,
typename TC,
typename TV = TC>
819 hmlpOperation_t transA, hmlpOperation_t transB,
824 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
827 for (
int i = 0; i < m; i ++ )
829 for (
int j = 0; j < n; j ++ )
832 for (
int p = 0; p < k; p ++ )
836 if ( transA == HMLP_OP_N ) a = A[ p * lda + i ];
837 else a = A[ i * lda + p ];
838 if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ];
839 else b = B[ p * ldb + j ];
840 v = op1( v, op2( a, b ) );
842 C[ j * ldc + i ] = opkernel( v );
854 typename OPKERNEL,
typename OP1,
typename OP2,
typename OPREDUCE,
855 typename TA,
typename TB,
typename TC,
typename TV = TC>
858 hmlpOperation_t transA, hmlpOperation_t transB,
864 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV,
865 OPREDUCE opreduce, TC initC
868 for (
int i = 0; i < m; i ++ )
871 for (
int j = 0; j < n; j ++ )
874 for (
int p = 0; p < k; p ++ )
878 if ( transA == HMLP_OP_N ) a = A[ p * lda + i ];
879 else a = A[ i * lda + p ];
880 if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ];
881 else b = B[ p * ldb + j ];
882 v = op1( v, op2( a, b ) );
884 c = opreduce( c, opkernel( v ) );
894 #endif // define GKMX_HPP Definition: semiring_mrxnr.hpp:11
Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
This kernel takes opkernel, op1 and op2 to implement an MR-by-NR GKMM operation.
Definition: fused_mrxnr.hpp:12
Definition: fused_mrxnr.hpp:127
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88
Definition: thread.hpp:166