// NOTE: this excerpt comes from a numbered source listing. The listing's own
// line numbers are fused into the text and several lines (braces, part of the
// parameter list, loop conditions) are missing. The comments below describe
// only what the visible statements show.
//
// fused_macro_kernel: walks an mc x nc block in MR x NR tiles and invokes
// `fusedkernel` once per tile on the matching slices of packA / packB.
// Full tiles are handed the caller's V directly; partial edge tiles go through
// the scratch buffer `vtmp`.
31 #include <hmlp_internal.hpp> 32 #include <hmlp_base.hpp> 37 #include <primitives/rank_k.hpp> 40 #include <packing.hpp> 41 #include <semiring_mrxnr.hpp> 42 #include <fused_mrxnr.hpp> 139 template<
int KC,
typename FUSEDKERNEL,
typename TA,
typename TB,
typename TC,
typename TV>
140 void fused_macro_kernel
// Visible parameters: block offsets (ic, jc, pc), block extents (mc, nc, kc),
// the output V with its row/column strides, and the kernel functor.
// Comm3rd, packA, packB and C are used below but their declarations are not
// part of this excerpt.
144 int ic,
int jc,
int pc,
145 int mc,
int nc,
int kc,
149 TV *V,
int rs_v,
int cs_v,
151 FUSEDKERNEL fusedkernel
// Tile sizes and packed-panel widths are compile-time members of the kernel type.
155 const static int MR = FUSEDKERNEL::mr;
156 const static int NR = FUSEDKERNEL::nr;
157 const static int PACK_MR = FUSEDKERNEL::pack_mr;
158 const static int PACK_NR = FUSEDKERNEL::pack_nr;
// Sub-communicator for the row-tile loop; the split count is read from the
// environment variable "KS_JR_NT".
160 auto Comm2nd = Comm3rd.
Split( hmlp_read_nway_from_env(
"KS_JR_NT" ) );
// Two ranges over [0, mc): Loop2nd is built with MR and supplies the logical
// row index i, Pack2nd is built with PACK_MR and supplies ip, the index used
// into packA.
164 auto Loop2nd = Comm2nd.DistributeOver1DThreads( 0, mc, MR );
165 auto Pack2nd = Comm2nd.DistributeOver1DThreads( 0, mc, PACK_MR );
// Column-tile loop: j is the logical column index, jp the index into packB.
// Loop3rd / Pack3rd are defined on lines missing from this excerpt
// (presumably the NR / PACK_NR counterparts of the ranges above - TODO confirm).
168 for (
int j = Loop3rd.beg(), jp = Pack3rd.beg();
170 j += Loop3rd.inc(), jp += Pack3rd.inc() )
// Per-column-tile auxiliary record that is passed to the kernel by address.
172 struct aux_s<TA, TB, TC, TV> aux;
// Row-tile loop.
177 for (
int i = Loop2nd.beg(), ip = Pack2nd.beg();
179 i += Loop2nd.inc(), ip += Pack2nd.inc() )
// Clamp the tile extents so the last tile in each direction stays inside
// the mc x nc block.
192 aux.ib = std::min( mc - i, MR );
193 aux.jb = std::min( nc - j, NR );
// Address of this tile's first entry of V.
196 aux.V = V + i * rs_v + j * cs_v;
// On the last row tile of this column, move aux.b_next forward by one
// Pack3rd step times kc. Its initial value is set on a line not shown here
// (presumably a next-panel hint for the kernel - TODO confirm).
200 if ( i + MR >= mc ) aux.b_next += Pack3rd.inc() * kc;
// Full MR x NR tile: the kernel is given V in place with the caller's strides.
204 if ( aux.jb == NR && aux.ib == MR )
206 fusedkernel( kc, &packA[ ip * kc ], &packB[ jp * kc ],
207 C, &V[ i * rs_v + j * cs_v ], rs_v, cs_v, &aux );
// Edge tile: copy the valid ib x jb entries of V into vtmp, laid out with unit
// row stride and column stride MR, then run the kernel on vtmp with those
// strides. The declaration of vtmp and any copy-back are not visible here.
214 for (
auto jj = 0; jj < aux.jb; jj ++ )
215 for (
auto ii = 0; ii < aux.ib; ii ++ )
216 vtmp[ jj * MR + ii ] =
217 V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
221 fusedkernel( kc, &packA[ ip * kc ], &packB[ jp * kc ],
222 C, vtmp, 1, MR, &aux );
// NOTE: this excerpt comes from a numbered source listing; the `template<`
// opener, the function-name line, several parameters, loop conditions and all
// braces are missing. The routine appears to be the one instantiated as
// nbody_internal<MC, NC, KC, ...> by the tci::Parallelize call later in this
// file (same template parameter list) - confirm against the full source.
//
// Visible structure: three nested blocking loops (jc over n, pc over k, ic
// over m). B is packed once per (jc, pc), A once per (jc, pc, ic), and each
// packed pair is handed to a macro kernel: fused_macro_kernel with
// `microkernel` on the last pc step, rank_k_macro_kernel elsewhere.
242 int MC,
int NC,
int KC,
243 typename TPACKA,
typename TPACKB,
typename TV,
244 typename TA,
typename TB,
typename TC,
245 typename SEMIRINGKERNEL,
typename MICROKERNEL>
// Visible parameters. Comm6th, A, B and C are used below but declared on
// lines not included in this excerpt.
249 int batchId,
int m,
int n,
int k,
int k_stra,
253 TV* V,
int rs_v,
int cs_v,
254 SEMIRINGKERNEL semiringkernel,
255 MICROKERNEL microkernel
// Blocking constants taken from the semiring kernel type. PACK_MC / PACK_NC
// scale the cache-block sizes MC / NC from MR / NR units to PACK_MR / PACK_NR
// units.
259 const static int MR = SEMIRINGKERNEL::mr;
260 const static int NR = SEMIRINGKERNEL::nr;
261 const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
262 const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
263 const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
264 const static int PACK_MC = ( MC / MR ) * PACK_MR;
265 const static int PACK_NC = ( NC / NR ) * PACK_NR;
// Communicator hierarchy, one split per loop level; each split count is read
// from an environment variable.
// NOTE(review): "JS_PC_NT" does not follow the "KS_" prefix of the other two
// names - possibly a typo for "KS_PC_NT". This is the name looked up at run
// time, so confirm with users/scripts before changing it.
// NOTE(review): `Comm3th` is spelled differently from the `Comm3rd` used
// inside fused_macro_kernel; it is only a local name, but the inconsistency
// is easy to trip over.
267 auto Comm5th = Comm6th.
Split( hmlp_read_nway_from_env(
"KS_JC_NT" ) );
268 auto Comm4th = Comm5th.
Split( hmlp_read_nway_from_env(
"JS_PC_NT" ) );
269 auto Comm3th = Comm4th.
Split( hmlp_read_nway_from_env(
"KS_IC_NT" ) );
// Run-time column block size, derived from n, NC and NR, and its packed width.
271 int nc = Comm6th.BalanceOver1DGangs( n, NC, NR );
272 int pack_nc = ( nc / NR ) * PACK_NR;
// Packing buffers obtained from the communicators: KC * ( pack_nc + 1 )
// elements for B on Comm4th, KC * ( PACK_MC + 1 ) elements for A on Comm3th.
274 auto *packB = Comm4th.AllocateSharedMemory<ALIGN_SIZE, TPACKB>( KC * ( pack_nc + 1 ) );
276 auto *packA = Comm3th.AllocateSharedMemory<ALIGN_SIZE, TPACKA>( KC * ( PACK_MC + 1 ) );
// Ranges for the pc loop (over [k_stra, k) with block KC) and the ic loop
// (over [0, m) with block MC). Loop6th, used by the jc loop, is defined on a
// line not included here.
286 auto Loop5th = Comm5th.DistributeOver1DGangs( k_stra, k, KC );
288 auto Loop4th = Comm4th.DistributeOver1DGangs( 0, m, MC );
// jc loop: column blocks of the output.
290 for (
int jc = Loop6th.beg();
292 jc += Loop6th.inc() )
// Width of this column block, clamped at the right edge of n.
294 auto jb = std::min( n - jc, nc );
// pc loop: blocks along the k dimension.
296 for (
int pc = Loop5th.beg();
298 pc += Loop5th.inc() )
// Depth of this k block, and whether it is the final one; the final one
// selects the fused kernel below.
300 auto pb = std::min( k - pc, KC );
301 auto is_the_last_pc_iteration = ( pc + KC >= k );
// Pack the current block of B, stepping j by the Loop range built with NR
// and jp by the one built with PACK_NR.
303 auto LooppkB = Comm4th.DistributeOver1DThreads( 0, jb, NR );
304 auto PackpkB = Comm4th.DistributeOver1DThreads( 0, jb, PACK_NR );
306 for (
int j = LooppkB.beg(), jp = PackpkB.beg();
308 j += LooppkB.inc(), jp += PackpkB.inc() )
// The argument list of this call continues on a line missing from the excerpt.
310 B.Pack( k, pc, pb, n, jc + j, std::min( jb - j, NR ),
// ic loop: row blocks of the output.
316 for (
int ic = Loop4th.beg();
318 ic += Loop4th.inc() )
// Height of this row block, clamped at the bottom edge of m.
320 auto ib = std::min( m - ic, MC );
// Pack the current block of A; each packed panel is written at packA[ ip * pb ].
322 auto LooppkA = Comm3th.DistributeOver1DThreads( 0, ib, MR );
323 auto PackpkA = Comm3th.DistributeOver1DThreads( 0, ib, PACK_MR );
325 for (
int i = LooppkA.beg(), ip = PackpkA.beg();
327 i += LooppkA.inc(), ip += PackpkA.inc() )
329 A.Pack( m, ic + i, std::min( ib - i, MR ),
330 k, pc, pb, &packA[ ip * pb ] );
// Final k block: run the fused macro kernel with `microkernel` on the V
// sub-block that starts at (ic, jc).
335 if ( is_the_last_pc_iteration )
338 fused_macro_kernel<KC>( Comm3th,
339 m, n, ic, jc, pc, ib, jb, pb, packA, packB,
340 &C, V + ic * rs_v + jc * cs_v, rs_v, cs_v,
341 batchId, microkernel );
// Rank-k macro kernel on the same V sub-block. Presumably this is the `else`
// branch for non-final k blocks (the `else` keyword and the end of this call
// are not visible) - TODO confirm.
348 rank_k_macro_kernel<KC>( Comm3th,
349 ic, jc, pc, ib, jb, pb, packA, packB,
350 V + ic * rs_v + jc * cs_v, rs_v, cs_v,
// Return both packing buffers to the communicators they were allocated from.
366 Comm4th.FreeSharedMemory( packB );
369 Comm3th.FreeSharedMemory( packA );
// NOTE: this excerpt comes from a numbered source listing; the `template<`
// opener, the function-name line, several parameters/locals (A, B, C, V,
// rs_v, cs_v, k_stra), the conditions around the V setup and all braces are
// missing.
//
// Driver visible here: returns early for an empty problem, sets up the
// workspace V (either freshly allocated or aliased onto C.X), checks type
// assumptions, and launches nbody_internal through tci::Parallelize.
383 int MC,
int NC,
int KC,
384 typename TPACKA,
typename TPACKB,
typename TV,
385 typename TA,
typename TB,
typename TC,
386 typename SEMIRINGKERNEL,
typename MICROKERNEL>
// Visible parameters: batch id, problem sizes and the two kernel functors.
389 int batchId,
int m,
int n,
int k,
393 SEMIRINGKERNEL semiringkernel,
394 MICROKERNEL microkernel
// Blocking constants taken from the semiring kernel type. PACK_MC, PACK_NC
// and USE_STRASSEN are not referenced by any line visible in this excerpt.
397 const static int MR = SEMIRINGKERNEL::mr;
398 const static int NR = SEMIRINGKERNEL::nr;
399 const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
400 const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
401 const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
402 const static int PACK_MC = ( MC / MR ) * PACK_MR;
403 const static int PACK_NC = ( NC / NR ) * PACK_NR;
404 const static bool USE_STRASSEN =
false;
// Nothing to do when any dimension is zero.
407 if ( m == 0 || n == 0 || k == 0 )
return;
// Workspace path 1: allocate an m * n buffer of TV for V.
// NOTE(review): the printf looks like leftover debug output - confirm it is
// intended before keeping it on this path.
// NOTE(review): m * n is evaluated in int and can overflow for large sizes.
// NOTE(review): the matching release of this buffer is not visible in this
// excerpt - verify it is freed after the parallel region.
417 printf(
"here m %d n %d\n", m, n );
418 V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
// Workspace path 2: reuse C's storage (C.X) as V through a pointer cast. The
// condition selecting between the two paths is on lines not shown here.
425 V =
reinterpret_cast<TV*
>( C.X );
// Run-time type checks (compiled out when NDEBUG is defined): both packed
// operand types must match, and TC must match TV.
// NOTE(review): both comparisons involve only template parameters, so a
// static_assert on std::is_same would catch a mismatch at compile time -
// check that these asserts are not inside a conditional path before changing.
434 assert(
typeid(TPACKA) ==
typeid(TPACKB) );
435 assert(
typeid(TC) ==
typeid(TV) );
// If the start offset equals k, step it back by KC. Presumably this keeps the
// [k_stra, k) range that nbody_internal iterates over non-empty, so the final
// fused step still runs; where k_stra is initialised is not visible - TODO
// confirm.
438 if ( k_stra == k ) k_stra -= KC;
// Launch nbody_internal, with the same template arguments, through
// tci::Parallelize; the first argument is NULL.
441 tci::Parallelize( NULL, nbody_internal<MC, NC, KC, TPACKA, TPACKB, TV,
442 TA, TB, TC, SEMIRINGKERNEL, MICROKERNEL>,
443 batchId, m, n, k, k_stra, A, B, C, V, rs_v, cs_v,
444 semiringkernel, microkernel );
Comm Split(int num_groups)
Definition: tci.cpp:149
Definition: hmlp_internal.hpp:38
Definition: packing.hpp:198
Range DistributeOver1DGangs(int beg, int end, int nb)
Definition: tci.cpp:335
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88