// NOTE(review): this text looks like an extracted source listing. The leading
// integers (31, 32, 55, ...) appear to be line numbers of that listing, and the
// gaps between them suggest that some lines (braces, parts of the parameter
// list, callee names) are not present in this view.
//
// rank_k_macro_kernel: walks a block of V tile by tile. The column range is
// split in steps of NR and the row range in steps of MR; each (i, j) tile is
// either handed to a kernel directly or staged through the buffer vtmp.
// Template parameters visible here: KC (int), SEMIRINGKERNEL (supplies the
// mr, nr, pack_mr and pack_nr constants), and the element types TA, TB, TV.
31 #include <hmlp_internal.hpp> 32 #include <hmlp_base.hpp> 38 #include <packing.hpp> 39 #include <semiring_mrxnr.hpp> 40 #include <fused_mrxnr.hpp> 55 template<
int KC,
typename SEMIRINGKERNEL,
typename TA,
typename TB,
typename TV>
56 void rank_k_macro_kernel
// Visible parameters: block offsets ic, jc, pc; the output pointer V with its
// row stride rs_v and column stride cs_v; and the semiring kernel object.
// The names m, n, k, Comm4th and ic_comm used below are declared on lines
// that are not visible in this view.
59 int ic,
int jc,
int pc,
63 TV *V,
int rs_v,
int cs_v,
64 SEMIRINGKERNEL semiringkernel
// Tile sizes and packed tile sizes are read from the kernel type.
68 const static int MR = SEMIRINGKERNEL::mr;
69 const static int NR = SEMIRINGKERNEL::nr;
70 const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
71 const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
// Each Distribute call returns a tuple that the loops below read as
// get<0> = start, get<1> = end bound, get<2> = increment.
// Loop3rd and Pack3rd cover [0, n) in steps of NR and PACK_NR;
// Loop2nd and Pack2nd cover [0, m) in steps of MR and PACK_MR.
77 auto Loop3rd = Comm4th.DistributeOver1DGangs( 0, n, NR );
78 auto Pack3rd = Comm4th.DistributeOver1DGangs( 0, n, PACK_NR );
79 auto Loop2nd = Comm4th.DistributeOver1DThreads( 0, m, MR );
80 auto Pack2nd = Comm4th.DistributeOver1DThreads( 0, m, PACK_MR );
// Column loop: j advances in the unpacked index space and jp advances in
// lockstep using the increment of Pack3rd.
83 for (
int j = get<0>( Loop3rd ), jp = get<0>( Pack3rd );
84 j < get<1>( Loop3rd );
85 j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) )
// Per-column-panel auxiliary record; both of its last two type arguments
// are TV in this routine.
87 struct aux_s<TA, TB, TV, TV> aux;
// Width of the current column panel: NR, or the remainder at the edge.
91 aux.jb = std::min( n - j, NR );
// Row loop: i and ip advance together, mirroring the column loop.
94 for (
int i = get<0>( Loop2nd ), ip = get<0>( Pack2nd );
95 i < get<1>( Loop2nd );
96 i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) )
// Height of the current row panel: MR, or the remainder at the edge.
98 aux.ib = std::min( m - i, MR );
// Advance aux.b_next by GetNumThreads() * PACK_NR * k elements.
// Presumably a look-ahead pointer into the packed B buffer - TODO confirm.
101 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Full MR x NR tile test.
104 if ( aux.jb == NR && aux.ib == MR )
// Argument line of a call whose callee is not visible in this view: the
// address of tile (i, j) of V followed by the row and column strides.
111 &V[ i * rs_v + j * cs_v ], rs_v, cs_v,
// Gather the aux.ib x aux.jb region of V into vtmp, which is indexed with a
// leading dimension of MR (jj * MR + ii). Presumably this is the branch for
// tiles that are not full size; the else keyword and the declaration of
// vtmp are not visible here.
121 for (
auto jj = 0; jj < aux.jb; jj ++ )
122 for (
auto ii = 0; ii < aux.ib; ii ++ )
123 vtmp[ jj * MR + ii ] =
124 V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
// Scatter vtmp back into the same aux.ib x aux.jb region of V.
136 for (
auto jj = 0; jj < aux.jb; jj ++ )
137 for (
auto ii = 0; ii < aux.ib; ii ++ )
138 V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ] = vtmp[ jj * MR + ii ];
// fused_macro_kernel: same tile traversal as rank_k_macro_kernel, but driven
// by a FUSEDKERNEL type and by the block sizes mc, nc, kc that are passed in.
// NOTE(review): the leading integers appear to be line numbers of an extracted
// listing; gaps mean that braces, some parameters and the kernel calls are
// not visible in this view.
// Template parameters: KC (int), FUSEDKERNEL (supplies mr, nr, pack_mr,
// pack_nr), and the element types TA, TB, TC, TV.
154 template<
int KC,
typename FUSEDKERNEL,
typename TA,
typename TB,
typename TC,
typename TV>
155 void fused_macro_kernel
// Visible parameters: block offsets ic, jc, pc; block sizes mc, nc, kc; the
// pointer V with row stride rs_v and column stride cs_v; the fused kernel
// object. Comm4th and ic_comm are declared on lines not visible here.
159 int ic,
int jc,
int pc,
160 int mc,
int nc,
int kc,
164 TV *V,
int rs_v,
int cs_v,
166 FUSEDKERNEL fusedkernel
// Tile sizes and packed tile sizes are read from the kernel type.
170 const static int MR = FUSEDKERNEL::mr;
171 const static int NR = FUSEDKERNEL::nr;
172 const static int PACK_MR = FUSEDKERNEL::pack_mr;
173 const static int PACK_NR = FUSEDKERNEL::pack_nr;
// Tuples read below as get<0> = start, get<1> = end bound, get<2> = increment.
// Columns cover [0, nc) in steps of NR and PACK_NR; rows cover [0, mc) in
// steps of MR and PACK_MR.
179 auto Loop3rd = Comm4th.DistributeOver1DGangs( 0, nc, NR );
180 auto Pack3rd = Comm4th.DistributeOver1DGangs( 0, nc, PACK_NR );
181 auto Loop2nd = Comm4th.DistributeOver1DThreads( 0, mc, MR );
182 auto Pack2nd = Comm4th.DistributeOver1DThreads( 0, mc, PACK_MR );
// Column loop: j and jp advance together.
185 for (
int j = get<0>( Loop3rd ), jp = get<0>( Pack3rd );
186 j < get<1>( Loop3rd );
187 j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) )
// Auxiliary record for this column panel, parameterised on TA, TB, TC, TV.
189 struct aux_s<TA, TB, TC, TV> aux;
// Row loop: i and ip advance together.
195 for (
int i = get<0>( Loop2nd ), ip = get<0>( Pack2nd );
196 i < get<1>( Loop2nd );
197 i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) )
// Tile extents, clipped at the block edge. Unlike rank_k_macro_kernel, both
// extents are assigned inside the row loop here.
212 aux.ib = std::min( mc - i, MR );
213 aux.jb = std::min( nc - j, NR );
// Record the address of tile (i, j) of V in the auxiliary record.
218 aux.V = V + i * rs_v + j * cs_v;
// Advance aux.b_next by GetNumThreads() * PACK_NR * kc elements.
// Presumably a look-ahead pointer into the packed B buffer - TODO confirm.
223 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * kc;
// Full MR x NR tile test.
226 if ( aux.jb == NR && aux.ib == MR )
// Argument line of a call whose callee is not visible in this view: the
// address of tile (i, j) of V followed by the row and column strides.
234 &V[ i * rs_v + j * cs_v ], rs_v, cs_v,
// Gather the aux.ib x aux.jb region of V into vtmp using a leading dimension
// of MR. Presumably the branch for tiles that are not full size; the else
// keyword, the declaration of vtmp and whatever follows this copy are not
// visible in this view.
243 for (
auto jj = 0; jj < aux.jb; jj ++ )
244 for (
auto ii = 0; ii < aux.ib; ii ++ )
245 vtmp[ jj * MR + ii ] =
246 V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
// Template parameter list and body fragment of a routine whose name line is
// not visible in this view. Presumably this is gnbx_internal, which is invoked
// elsewhere in this file with batchId, m, n, k, k_stra - TODO confirm.
// NOTE(review): the leading integers appear to be line numbers of an extracted
// listing; gaps mean that braces, some parameters and the packing calls are
// not visible here.
// The visible code allocates two packing buffers, runs three nested blocked
// loops over n, k and m, and dispatches each block to fused_macro_kernel or
// rank_k_macro_kernel.
276 int MC,
int NC,
int KC,
277 typename TPACKA,
typename TPACKB,
typename TV,
278 typename TA,
typename TB,
typename TC,
279 typename SEMIRINGKERNEL,
typename MICROKERNEL>
// Visible parameters: a batch id, the sizes m, n, k, the start offset k_stra
// of the k loop, the pointer V with its strides, and the two kernel objects.
283 int batchId,
int m,
int n,
int k,
int k_stra,
287 TV* V,
int rs_v,
int cs_v,
288 SEMIRINGKERNEL semiringkernel,
289 MICROKERNEL microkernel
// Tile sizes, packed tile sizes and alignment come from SEMIRINGKERNEL.
293 const static int MR = SEMIRINGKERNEL::mr;
294 const static int NR = SEMIRINGKERNEL::nr;
295 const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
296 const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
297 const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
// Packed block sizes: number of whole MR (or NR) tiles in MC (or NC) times
// the packed tile size. PACK_NC is not referenced in the lines visible here.
298 const static int PACK_MC = ( MC / MR ) * PACK_MR;
299 const static int PACK_NC = ( NC / NR ) * PACK_NR;
// Chain of communicators, each obtained by calling Split() on the previous
// one, starting from thread.
302 auto CommGLB = thread.
Split();
303 auto Comm6th = CommGLB.
Split();
304 auto Comm5th = Comm6th.
Split();
305 auto Comm4th = Comm5th.
Split();
// Runtime column block size chosen by BalanceOver1DGangs from n, NC and NR,
// and the corresponding packed size.
309 int nc = CommGLB.BalanceOver1DGangs( n, NC, NR );
310 int pack_nc = ( nc / NR ) * PACK_NR;
// Packing buffers: packA holds KC * ( PACK_MC + 1 ) elements of TPACKA and is
// requested from Comm4th; packB holds KC * ( pack_nc + 1 ) elements of TPACKB
// and is requested from Comm5th.
326 auto *packA = Comm4th.AllocateSharedMemory<ALIGN_SIZE, TPACKA>( KC * ( PACK_MC + 1 ) );
327 auto *packB = Comm5th.AllocateSharedMemory<ALIGN_SIZE, TPACKB>( KC * ( pack_nc + 1 ) );
// Loop ranges, read below as get<0> = start, get<1> = end, get<2> = step:
// Loop6th covers [0, n) in steps of nc, Loop5th covers [k_stra, k) in steps
// of KC, Loop4th covers [0, m) in steps of MC.
330 auto Loop6th = CommGLB.DistributeOver1DGangs( 0, n, nc );
331 auto Loop5th = Comm6th.DistributeOver1DGangs( k_stra, k, KC );
332 auto Loop4th = Comm5th.DistributeOver1DGangs( 0, m, MC );
// Outermost loop over column blocks.
335 for (
int jc = get<0>( Loop6th );
336 jc < get<1>( Loop6th );
337 jc += get<2>( Loop6th ) )
// Width of this column block, clipped at n.
339 auto jb = std::min( n - jc, nc );
// Loop over blocks of the k dimension.
343 for (
int pc = get<0>( Loop5th );
344 pc < get<1>( Loop5th );
345 pc += get<2>( Loop5th ) )
// Depth of this k block, clipped at k, and a flag that is true when the
// block starting at pc reaches the end of k.
347 auto pb = std::min( k - pc, KC );
348 auto is_the_last_pc_iteration = ( pc + KC >= k );
// Ranges over the jb columns of this block in steps of NR and PACK_NR,
// distributed over the threads of Comm5th.
349 auto LooppkB = Comm5th.DistributeOver1DThreads( 0, jb, NR );
350 auto PackpkB = Comm5th.DistributeOver1DThreads( 0, jb, PACK_NR );
// Loop over NR-wide column panels; j and jp advance together.
352 for (
int j = get<0>( LooppkB ), jp = get<0>( PackpkB );
353 j < get<1>( LooppkB );
354 j += get<2>( LooppkB ), jp += get<2>( PackpkB ) )
// Argument line of a call whose callee is not visible here (presumably the
// routine that packs B - TODO confirm): total n, the global column index,
// and the panel width clipped to jb.
359 n, jc + j, std::min( jb - j, NR ),
// Loop over row blocks.
366 for (
int ic = get<0>( Loop4th );
367 ic < get<1>( Loop4th );
368 ic += get<2>( Loop4th ) )
// Alias for the communicator stored in thread.ic_comm; height of this row
// block clipped at m; ranges over its rows in steps of MR and PACK_MR.
370 auto &ic_comm = *thread.ic_comm;
371 auto ib = std::min( m - ic, MC );
372 auto LooppkA = Comm4th.DistributeOver1DThreads( 0, ib, MR );
373 auto PackpkA = Comm4th.DistributeOver1DThreads( 0, ib, PACK_MR );
// Loop over MR-tall row panels; i and ip advance together.
375 for (
int i = get<0>( LooppkA ), ip = get<0>( PackpkA );
376 i < get<1>( LooppkA );
377 i += get<2>( LooppkA ), ip += get<2>( PackpkA ) )
// Argument line of a call whose callee is not visible here (presumably the
// routine that packs A - TODO confirm): total m, the global row index, and
// the panel height clipped to ib.
381 m, ic + i, std::min( ib - i, MR ),
// On the final k block the fused macro kernel is used.
387 if ( is_the_last_pc_iteration )
389 fused_macro_kernel<KC>
// Both macro kernel calls receive the address of block (ic, jc) of V and
// the strides of V.
398 V + ic * rs_v + jc * cs_v, rs_v, cs_v,
// Presumably the alternative branch for earlier k blocks; the else keyword
// is not visible in this view.
406 rank_k_macro_kernel<KC>
413 V + ic * rs_v + jc * cs_v, rs_v, cs_v,
// Release each packing buffer through the communicator that allocated it.
426 Comm4th.FreeSharedMemory( packA );
427 Comm5th.FreeSharedMemory( packB );
// Template parameter list and body fragment of a driver routine whose name
// line is not visible in this view. Presumably this is gnbx, which is invoked
// elsewhere in this file with batchId, m, n, k, A, B, C and two kernel
// objects - TODO confirm.
// NOTE(review): the leading integers appear to be line numbers of an extracted
// listing; gaps mean that braces, several declarations (V, k_stra, my_comm, C)
// and the conditions around some statements are not visible here.
// The visible code validates the sizes, selects the output buffer V, reads
// thread counts from the environment, opens an OpenMP parallel region and
// calls gnbx_internal.
442 int MC,
int NC,
int KC,
443 typename TPACKA,
typename TPACKB,
typename TV,
444 typename TA,
typename TB,
typename TC,
445 typename SEMIRINGKERNEL,
typename MICROKERNEL>
// Visible parameters: a batch id, the sizes m, n, k, and the two kernel
// objects.
448 int batchId,
int m,
int n,
int k,
452 SEMIRINGKERNEL semiringkernel,
453 MICROKERNEL microkernel
// Tile sizes, packed tile sizes and alignment come from SEMIRINGKERNEL.
456 const static int MR = SEMIRINGKERNEL::mr;
457 const static int NR = SEMIRINGKERNEL::nr;
458 const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
459 const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
460 const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
461 const static int PACK_MC = ( MC / MR ) * PACK_MR;
462 const static int PACK_NC = ( NC / NR ) * PACK_NR;
// Compile-time switch, fixed to false here.
463 const static bool USE_STRASSEN =
false;
// Nothing to do when any dimension is zero.
466 if ( m == 0 || n == 0 || k == 0 )
return;
// Two alternative assignments of V follow; the conditions that choose
// between them are not visible in this view.
// First alternative: a freshly allocated buffer of m * n elements of TV.
// NOTE(review): m * n is evaluated in int, so it could overflow for large
// sizes - confirm the intended size range.
477 V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
// Second alternative: reuse the storage behind C.X, reinterpreted as TV.
484 V =
reinterpret_cast<TV*
>( C.X );
// Runtime checks that the two packing types match and that TC matches TV.
// Presumably these guard the path that reuses the storage of C - TODO
// confirm.
493 assert(
typeid(TPACKA) ==
typeid(TPACKB) );
494 assert(
typeid(TC) ==
typeid(TV) );
// When k_stra equals k, step it back by KC. Presumably this keeps the k loop,
// which starts at k_stra, from being empty - TODO confirm. The line that
// first assigns k_stra is not visible here.
497 if ( k_stra == k ) k_stra -= KC;
// Per-loop thread counts, all defaulting to 1.
500 int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
// Only when the current team has a single thread and more than one thread is
// available are the counts read from the environment variables KS_JC_NT,
// KS_IC_NT and KS_JR_NT. pc_nt is not assigned in the lines visible here.
501 if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
504 jc_nt = hmlp_read_nway_from_env(
"KS_JC_NT" );
505 ic_nt = hmlp_read_nway_from_env(
"KS_IC_NT" );
506 jr_nt = hmlp_read_nway_from_env(
"KS_JR_NT" );
// Parallel region sized by my_comm.GetNumThreads(); each thread builds a
// Worker bound to my_comm and initialises it with its OpenMP thread number.
512 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 514 Worker thread( &my_comm );
517 thread.InitWithCommunicator( &my_comm, omp_get_thread_num(), 0 );
// Hand the problem to gnbx_internal; the remaining template arguments are
// not listed explicitly in the visible call.
540 gnbx_internal<MC, NC, KC, TPACKA, TPACKB>
543 batchId, m, n, k, k_stra,
548 semiringkernel, microkernel
// Template parameter list and body fragment of a convenience wrapper whose
// name line is not visible in this view. It takes an operator kernel, two
// operators and an initial value, copies them into two kernel objects, and
// forwards everything to gnbx.
// NOTE(review): the leading integers appear to be line numbers of an extracted
// listing; the declarations of semiringkernel and gkrmkernel and the
// parameters A, B, C are on lines not visible here.
566 int MR,
int NR,
int MC,
int NC,
int KC,
// TPACKC is declared here but is not passed on in the visible call to gnbx.
567 typename TPACKA,
typename TPACKB,
typename TPACKC,
typename TV,
568 typename TA,
typename TB,
typename TC,
569 typename OPKERNEL,
typename OP1,
typename OP2>
// Visible parameters: a batch id, the sizes m, n, k, the operator kernel,
// the two operators and the initial value initV.
572 int batchId,
int m,
int n,
int k,
576 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
// Configure the semiring kernel with both operators and the initial value.
582 semiringkernel.op1 = op1;
583 semiringkernel.op2 = op2;
584 semiringkernel.initV = initV;
// Configure the second kernel the same way and also give it the operator
// kernel.
586 gkrmkernel.op1 = op1;
587 gkrmkernel.op2 = op2;
588 gkrmkernel.opkernel = opkernel;
589 gkrmkernel.initV = initV;
// Forward to gnbx with the block sizes and the packing and value types given
// explicitly; the remaining template arguments are left to deduction.
591 gnbx<MC, NC, KC, TPACKA, TPACKB, TV>
592 ( batchId, m, n, k, A, B, C, semiringkernel, gkrmkernel );
Definition: semiring_mrxnr.hpp:11
Worker Split()
Definition: thread.cpp:391
Definition: thread.hpp:107
Definition: fused_mrxnr.hpp:216
Definition: hmlp_internal.hpp:38
Definition: packing.hpp:198
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88
Definition: thread.hpp:166