26 #include <hmlp_internal.hpp> 27 #include <hmlp_packing.hpp> 28 #include <base/util.hpp> 29 #include <hmlp_thread.hpp> 30 #include <hmlp_runtime.hpp> 37 #define min( i, j ) ( (i)<(j) ? (i): (j) ) 44 template<
int FOLD,
bool ZEROPAD = false,
typename T>
49 template<
int FOLD,
bool ZEROPAD = false,
typename T>
// NOTE(review): this listing carries embedded source line numbers and is
// missing lines (the `template<` opener, the function name, most parameters
// and all braces are not visible). Comments describe only the visible lines.
//
// Template parameters: integer constants KC, MR, NR, PACK_MR, PACK_NR, the
// kernel type SEMIRINGKERNEL, and the element types TA, TB, TC, TV.
60 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
61 typename SEMIRINGKERNEL,
62 typename TA,
typename TB,
typename TC,
typename TV>
// Integer arguments ic, jc, pc (presumably the origins of the caller's current
// blocks -- confirm against the call site).
66 int ic,
int jc,
int pc,
// Kernel object, received by value.
71 SEMIRINGKERNEL semiringkernel
// Communicator reached through thread.ic_comm; its thread count is used below
// to split the column range and to stride aux.b_next.
74 thread_communicator &ic_comm = *thread.ic_comm;
// Ranges over [0, n): stepping by NR and by PACK_NR, both parameterized by
// thread.jr_id and the communicator's thread count (presumably so that each
// thread gets a disjoint set of column tiles -- GetRange is defined elsewhere).
76 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
77 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
// Ranges over [0, m): stepping by MR and by PACK_MR, with no thread arguments.
78 auto loop2nd = GetRange( 0, m, MR );
79 auto pack2nd = GetRange( 0, m, PACK_MR );
// Column loop: j and jp advance together, each by its own range increment.
81 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
83 j += loop3rd.inc(), jp += pack3rd.inc() )
// Auxiliary record declared once per column iteration.
85 struct aux_s<TA, TB, TC, TV> aux;
// Number of columns in this tile: NR, or the remainder n - j if smaller.
89 aux.jb = min( n - j, NR );
// Row loop: i and ip advance together, each by its own range increment.
91 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
93 i += loop2nd.inc(), ip += pack2nd.inc() )
// Number of rows in this tile: MR, or the remainder m - i if smaller.
95 aux.ib = min( m - i, MR );
// Moves aux.b_next forward by (thread count) * PACK_NR * k elements
// (presumably to the next packed-B panel this thread will touch; any guarding
// condition is not visible in this listing).
98 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Full MR x NR tile: the address of C element (i, j), &C[ j * ldc + i ], is
// passed together with ldc (the callee name is not visible in this listing).
101 if ( aux.jb == NR && aux.ib == MR )
108 &C[ j * ldc + i ], ldc,
// Partial tile: an MR * NR temporary whose first element is set to zero
// explicitly and whose remaining elements are value-initialized by the
// aggregate initializer.
115 TV ctmp[ MR * NR ] = { (TV)0.0 };
// Adds the valid jb x ib part of ctmp (leading dimension MR) into C.
// NOTE(review): the condition selecting this accumulate path over the
// overwrite path below is not visible here.
126 for (
auto jj = 0; jj < aux.jb; jj ++ )
128 for (
auto ii = 0; ii < aux.ib; ii ++ )
130 C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
// Same index pattern, but assigns into C instead of adding.
136 for (
auto jj = 0; jj < aux.jb; jj ++ )
138 for (
auto ii = 0; ii < aux.ib; ii ++ )
140 C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
// NOTE(review): this listing carries embedded source line numbers and is
// missing lines (most parameters and all braces are not visible). Comments
// describe only the visible lines.
//
// fused_macro_kernel: walks the n-by-m tile grid in steps of NR columns and
// MR rows; full tiles are given a pointer straight into C, partial tiles go
// through a temporary buffer that is then added or copied into C.
//
// Template parameters: integer constants KC, MR, NR, PACK_MR, PACK_NR, the
// kernel type MICROKERNEL, and the element types TA, TB, TC, TV.
152 template<
int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
153 typename MICROKERNEL,
154 typename TA,
typename TB,
typename TC,
typename TV>
155 void fused_macro_kernel
// Integer arguments ic, jc, pc (presumably the origins of the caller's current
// blocks -- confirm against the call site).
158 int ic,
int jc,
int pc,
// Kernel object, received by value.
163 MICROKERNEL microkernel
// Communicator reached through thread.ic_comm; its thread count is used below
// to split the column range and to stride aux.b_next.
166 thread_communicator &ic_comm = *thread.ic_comm;
// Ranges over [0, n): stepping by NR and by PACK_NR, both parameterized by
// thread.jr_id and the communicator's thread count (presumably so that each
// thread gets a disjoint set of column tiles -- GetRange is defined elsewhere).
168 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
169 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
// Ranges over [0, m): stepping by MR and by PACK_MR, with no thread arguments.
170 auto loop2nd = GetRange( 0, m, MR );
171 auto pack2nd = GetRange( 0, m, PACK_MR );
// Column loop: j and jp advance together, each by its own range increment.
173 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
175 j += loop3rd.inc(), jp += pack3rd.inc() )
// Auxiliary record declared once per column iteration.
177 struct aux_s<TA, TB, TC, TV> aux;
// Number of columns in this tile: NR, or the remainder n - j if smaller.
181 aux.jb = min( n - j, NR );
// Row loop: i and ip advance together, each by its own range increment.
183 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
185 i += loop2nd.inc(), ip += pack2nd.inc() )
// Number of rows in this tile: MR, or the remainder m - i if smaller.
187 aux.ib = min( m - i, MR );
// Moves aux.b_next forward by (thread count) * PACK_NR * k elements
// (presumably to the next packed-B panel this thread will touch; any guarding
// condition is not visible in this listing).
190 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Full MR x NR tile: the address of C element (i, j), &C[ j * ldc + i ], is
// passed together with ldc (the callee expression is not visible here).
193 if ( aux.jb == NR && aux.ib == MR )
200 &C[ j * ldc + i ], ldc,
// Partial tile: an MR * NR temporary whose first element is set to zero
// explicitly and whose remaining elements are value-initialized by the
// aggregate initializer.
207 TV ctmp[ MR * NR ] = { (TV)0.0 };
// Adds the valid jb x ib part of ctmp (leading dimension MR) into C.
// NOTE(review): the condition selecting this accumulate path over the
// overwrite path below is not visible here.
219 for (
auto jj = 0; jj < aux.jb; jj ++ )
221 for (
auto ii = 0; ii < aux.ib; ii ++ )
223 C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
// Same index pattern, but assigns into C instead of adding.
229 for (
auto jj = 0; jj < aux.jb; jj ++ )
231 for (
auto ii = 0; ii < aux.ib; ii ++ )
233 C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
// NOTE(review): this listing carries embedded source line numbers and is
// missing lines (the `template<` opener, the function name, most parameters,
// loop conditions, braces and callee names are not visible). Comments describe
// only the visible lines.
//
// Template parameters: integer constants MC, NC, KC, MR, NR and their PACK_*
// counterparts, ALIGN_SIZE, two kernel types, and element types TA, TB, TC, TV.
248 int MC,
int NC,
int KC,
int MR,
int NR,
249 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
251 typename SEMIRINGKERNEL,
typename MICROKERNEL,
252 typename TA,
typename TB,
typename TC,
typename TV>
// Flags compared against HMLP_OP_N below to choose how A and B are packed.
256 hmlpOperation_t transA, hmlpOperation_t transB,
// The two kernel objects, received by value.
261 SEMIRINGKERNEL semiringkernel,
262 MICROKERNEL microkernel,
// Offset packA by ( jc_id * ic_nt + ic_id ) slots of PACK_MC * KC elements and
// packB by jc_id slots of pack_nc * KC elements, i.e. a slot selected by this
// thread's jc/ic ids.
268 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
269 + ( thread.ic_id ) * PACK_MC * KC;
270 packB += ( thread.jc_id ) * pack_nc * KC;
// Outer ranges: [0, n) by nc parameterized by jc_id/jc_nt, [0, k) by KC with
// no thread arguments, and [0, m) by MC parameterized by ic_id/ic_nt.
272 auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
273 auto loop5th = GetRange( 0, k, KC );
274 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
// 6th loop over jc.
280 for (
int jc = loop6th.beg();
282 jc += loop6th.inc() )
284 auto &jc_comm = *thread.jc_comm;
// Width of this column block: nc, or the remainder n - jc if smaller.
285 auto jb = min( n - jc, nc );
// 5th loop over pc.
290 for (
int pc = loop5th.beg();
292 pc += loop5th.inc() )
294 auto &pc_comm = *thread.pc_comm;
// Depth of this block: KC, or the remainder k - pc if smaller.
295 auto pb = min( k - pc, KC );
// True when this KC-sized step reaches or passes k.
296 auto is_the_last_pc_iteration = ( pc + KC >= k );
// Ranges used to pack B: [0, jb) by NR and by PACK_NR, parameterized by
// thread.ic_jr and the pc communicator's thread count.
301 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
302 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
304 for (
int j = looppkB.beg(), jp = packpkB.beg();
306 j += looppkB.inc(), jp += packpkB.inc() )
// transB == HMLP_OP_N: pack2D<true, PACK_NR> reads from B[ ( jc + j ) * ldb + pc ].
308 if ( transB == HMLP_OP_N )
310 pack2D<true, PACK_NR>
312 min( jb - j, NR ), pb,
313 &B[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
// Alternative (presumably the else branch; the keyword is not visible):
// pack2D<false, PACK_NR> reads from B[ pc * ldb + ( jc + j ) ].
// Both variants write to &packB[ jp * pb ].
318 pack2D<false, PACK_NR>
320 min( jb - j, NR ), pb,
321 &B[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
// 4th loop over ic.
330 for (
int ic = loop4th.beg();
332 ic += loop4th.inc() )
334 auto &ic_comm = *thread.ic_comm;
// Height of this row block: MC, or the remainder m - ic if smaller.
335 auto ib = min( m - ic, MC );
// Ranges used to pack A: [0, ib) by MR and by PACK_MR, parameterized by
// thread.jr_id and thread.jr_nt.
340 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
341 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
343 for (
int i = looppkA.beg(), ip = packpkA.beg();
345 i += looppkA.inc(), ip += packpkA.inc() )
// transA == HMLP_OP_N: pack2D<false, PACK_MR> reads from A[ pc * lda + ( ic + i ) ].
347 if ( transA == HMLP_OP_N )
349 pack2D<false, PACK_MR>
351 min( ib - i, MR ), pb,
352 &A[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
// Alternative (presumably the else branch; the keyword is not visible):
// pack2D<true, PACK_MR> reads from A[ ( ic + i ) * lda + pc ].
// Both variants write to &packA[ ip * pb ].
357 pack2D<true, PACK_MR>
359 min( ib - i, MR ), pb,
360 &A[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
// On the last pc iteration a callee instantiated with MICROKERNEL is used;
// the other visible call is instantiated with SEMIRINGKERNEL. Both receive
// C + jc * ldc + ic and ldc. The callee names are not visible in this listing.
366 if ( is_the_last_pc_iteration )
369 <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
376 C + jc * ldc + ic, ldc,
383 <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
390 C + jc * ldc + ic, ldc,
// NOTE(review): this listing carries embedded source line numbers and is
// missing lines (the `template<` opener, the function name, most parameters,
// braces and callee names are not visible). Comments describe only the visible
// lines.
//
// Template parameters: integer constants MC, NC, KC, MR, NR and their PACK_*
// counterparts, ALIGN_SIZE, two kernel types, and element types TA, TB, TC, TV.
443 int MC,
int NC,
int KC,
int MR,
int NR,
444 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
446 typename SEMIRINGKERNEL,
typename MICROKERNEL,
447 typename TA,
typename TB,
typename TC,
typename TV>
453 hmlpOperation_t transA, hmlpOperation_t transB,
// The two kernel objects, received by value.
458 SEMIRINGKERNEL semiringkernel,
459 MICROKERNEL microkernel
// Thread counts per loop level, all defaulting to 1. pc_nt is never reassigned
// in the visible lines.
462 int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
// Column block size and its packed counterpart start at the template defaults.
463 int nc = NC, pack_nc = PACK_NC;
// Packing buffers, allocated further down.
466 TA *packA_buff = NULL;
467 TB *packB_buff = NULL;
// Nothing to do when any dimension is zero; returns before allocating.
470 if ( m == 0 || n == 0 || k == 0 )
return;
// Thread counts for the jc, ic and jr levels come from the environment
// variables KS_JC_NT, KS_IC_NT and KS_JR_NT.
473 jc_nt = hmlp_read_nway_from_env(
"KS_JC_NT" );
474 ic_nt = hmlp_read_nway_from_env(
"KS_IC_NT" );
475 jr_nt = hmlp_read_nway_from_env(
"KS_JR_NT" );
// nc = ceil( n / ( NR * jc_nt ) ) * NR for n >= 1: n divided among jc_nt
// parts in whole NR-sized tiles (any guarding condition is not visible here).
480 nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
// Same number of tiles, measured in PACK_NR-sized units.
481 pack_nc = ( nc / NR ) * PACK_NR;
// Buffer allocation via hmlp_malloc (defined elsewhere; presumably an aligned
// allocation of the product of the first two arguments -- confirm). Sizes
// scale with the jc/ic thread counts.
// NOTE(review): the matching release of these buffers is not visible in this
// listing -- confirm they are freed after the parallel region.
485 packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,
sizeof(TA) );
486 packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,
sizeof(TB) );
// Communicator built from the four per-level thread counts.
489 thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
// OpenMP parallel region sized by the communicator's thread count; a Worker is
// constructed from the communicator's address.
492 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 494 Worker thread( &my_comm );
// NOTE(review): the condition guarding this message is not visible. The text
// is prefixed "cnn:" although the trailing #endif marks this file as GKMX_HPP,
// and it has no trailing newline -- confirm whether both are intended.
498 printf(
"cnn: strassen algorithms haven't been implemented." );
// Template arguments and kernel objects forwarded to a callee whose name is
// not visible in this listing.
504 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
506 SEMIRINGKERNEL, MICROKERNEL,
518 semiringkernel, microkernel,
541 #endif // define GKMX_HPP void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
Definition: hmlp_internal.hpp:38