// NOTE(review): in this view the function-name line, braces, loop conditions
// and the kernel call sites of the first definition are not visible; the
// comments below describe only the statements that are shown.
//
// Macro kernel parameterised on a SEMIRINGKERNEL functor. It walks the packC
// panel tile by tile: an outer loop over column offsets j (range "loop3rd")
// and an inner loop over row offsets i (range "loop2nd").
30 #include <hmlp_internal.hpp> 31 #include <hmlp_base.hpp> 34 #include <primitives/strassen.hpp> 41 #define min( i, j ) ( (i)<(j) ? (i): (j) ) 47 int KC,
// MR / NR are the row / column steps of the tile loops; PACK_MR / PACK_NR are
// the steps of the companion ranges that advance ip / jp.
int MR,
int NR,
int PACK_MR,
int PACK_NR,
48 typename SEMIRINGKERNEL,
49 typename TA,
typename TB,
typename TC,
typename TV>
// ic, jc, pc: presumably the offsets of the enclosing blocked loops — TODO
// confirm against the caller.
53 int ic,
int jc,
int pc,
58 SEMIRINGKERNEL semiringkernel
// Column ranges over [0, n): one stepping by NR (index j), one stepping by
// PACK_NR (index jp). Both receive thread.jr_id and the ic communicator's
// thread count — presumably to split column tiles among threads (confirm
// against GetRange).
63 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
64 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
// Row ranges over [0, m): stepping by MR (index i) and by PACK_MR (index ip).
65 auto loop2nd = GetRange( 0, m, MR );
66 auto pack2nd = GetRange( 0, m, PACK_MR );
// Outer loop: j and jp advance together.
68 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
70 j += loop3rd.inc(), jp += pack3rd.inc() )
72 struct aux_s<TA, TB, TC, TV> aux;
// Width of the current column tile, clipped to the columns that remain.
76 aux.jb = min( n - j, NR );
// Inner loop: i and ip advance together.
78 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
80 i += loop2nd.inc(), ip += pack2nd.inc() )
// Height of the current row tile, clipped to the rows that remain.
82 aux.ib = min( m - i, MR );
// Advances aux.b_next by (thread count * PACK_NR * k) elements; presumably a
// look-ahead pointer into packed B — TODO confirm.
85 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Full MR-by-NR tile: the packC tile address is passed with strides (1, ldc),
// apparently so the kernel works on packC in place.
88 if ( aux.jb == NR && aux.ib == MR )
95 &packC[ j * ldc + i ], 1, ldc,
// Otherwise: 32-byte-aligned MR*NR scratch tile.
// NOTE(review): the element type is hard-coded to double although TC / TV are
// template parameters — confirm this is intended for non-double instances.
101 double c[ MR * NR ] __attribute__((aligned(32)));
// Copy the aux.ib-by-aux.jb tile of packC (leading dimension ldc) into cbuff
// (leading dimension MR). The declaration of cbuff is not visible here;
// presumably it aliases c.
104 for (
auto jj = 0; jj < aux.jb; jj ++ )
105 for (
auto ii = 0; ii < aux.ib; ii ++ )
106 cbuff[ jj * MR + ii ] = packC[ ( j + jj ) * ldc + i + ii ];
// Copy the same tile from cbuff back into packC.
116 for (
auto jj = 0; jj < aux.jb; jj ++ )
117 for (
auto ii = 0; ii < aux.ib; ii ++ )
118 packC[ ( j + jj ) * ldc + i + ii ] = cbuff[ jj * MR + ii ];
// fused_macro_kernel: macro kernel parameterised on a MICROKERNEL functor.
// Like the tile loops elsewhere in this file it iterates column tiles (j, jp)
// and row tiles (i, ip); for every tile visible here it stages the packC tile
// in a scratch buffer, passes packed A/B panels plus A2/B2, bmap, D and I
// offsets to a call whose callee line is not visible, and writes the scratch
// tile back to packC.
// NOTE(review): braces, loop conditions and some parameters are missing from
// this view.
128 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
129 typename MICROKERNEL,
130 typename TA,
typename TB,
typename TC,
typename TV>
131 void fused_macro_kernel
// m, n, k: extents used by the tile loops and the packed-panel offsets below.
// r: not referenced in the visible lines; presumably the neighbour count.
135 int m,
int n,
int k,
int r,
// packA / packB are indexed by ip * k / jp * k; packA2 / packB2 by ip / jp,
// i.e. one value per row / column.
136 TA *packA, TA *packA2,
137 TB *packB, TB *packB2,
// D and I are offset by i * ldr in the call below; presumably the neighbour
// distance and index lists with row stride ldr — TODO confirm.
139 TV *D,
int *I,
int ldr,
141 MICROKERNEL microkernel
// 32-byte-aligned MR*NR scratch tile.
// NOTE(review): hard-coded to double although TC / TV are template parameters.
144 double c[ MR * NR ] __attribute__((aligned(32)));
// Column ranges over [0, n), stepping by NR (j) and PACK_NR (jp); thread.jr_id
// and the ic communicator's thread count are passed to GetRange, presumably to
// split column tiles among threads.
148 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
149 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
// Row ranges over [0, m), stepping by MR (i) and PACK_MR (ip).
150 auto loop2nd = GetRange( 0, m, MR );
151 auto pack2nd = GetRange( 0, m, PACK_MR );
153 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
155 j += loop3rd.inc(), jp += pack3rd.inc() )
157 struct aux_s<TA, TB, TC, TV> aux;
// Column-tile width, clipped to the columns that remain.
161 aux.jb = min( n - j, NR );
163 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
165 i += loop2nd.inc(), ip += pack2nd.inc() )
// Row-tile height, clipped to the rows that remain.
167 aux.ib = min( m - i, MR );
// Advances aux.b_next by (thread count * PACK_NR * k) elements; presumably a
// look-ahead pointer into packed B — TODO confirm.
172 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Stage the aux.ib-by-aux.jb tile of packC (leading dimension ldc) in cbuff
// (leading dimension MR). cbuff's declaration is not visible; presumably it
// aliases c.
175 for (
auto jj = 0; jj < aux.jb; jj ++ )
176 for (
auto ii = 0; ii < aux.ib; ii ++ )
177 cbuff[ jj * MR + ii ] = packC[ ( j + jj ) * ldc + i + ii ];
// Arguments for the current tile: packed A panel and A2 entries at row offset
// ip, packed B panel and B2 entries at column offset jp, the bmap entries for
// column j, and the D / I rows for row i.
182 packA + ip * k, packA2 + ip,
183 packB + jp * k, packB2 + jp, bmap + j,
185 D + i * ldr, I + i * ldr, ldr,
// Write the staged tile back into packC.
189 for (
auto jj = 0; jj < aux.jb; jj ++ )
190 for (
auto ii = 0; ii < aux.ib; ii ++ )
191 packC[ ( j + jj ) * ldc + i + ii ] = cbuff[ jj * MR + ii ];
// Blocked driver (function-name line not visible in this view; presumably the
// per-thread internal routine of this header). Visible structure:
//   6th loop: jc over [0, n) in steps of NC
//   5th loop: pc over [k_stra, k) in steps of KC — packs panels of B (and B2
//             on the last pc iteration)
//   4th loop: ic over [0, m) in steps of MC — packs panels of A (and A2 on the
//             last pc iteration), then issues two templated kernel calls, one
//             instantiated with SEMIRINGKERNEL and one with MICROKERNEL.
// NOTE(review): braces, loop conditions, callee names and the condition that
// selects between the two kernel calls are missing from this view.
202 int MC,
int NC,
int KC,
int MR,
int NR,
203 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
205 typename SEMIRINGKERNEL,
typename MICROKERNEL,
206 typename TA,
typename TB,
typename TC,
typename TV>
// k_stra is the starting offset of the pc loop below.
210 int m,
int n,
int k,
int k_stra,
int r,
// A / B are read through the index maps amap / bmap with leading dimension k;
// A2 / B2 hold one value per mapped index (packed with leading dimension 1).
211 TA *A, TA *A2,
int *amap,
212 TB *B, TB *B2,
int *bmap,
214 SEMIRINGKERNEL semiringkernel,
215 MICROKERNEL microkernel,
216 TA *packA, TA *packA2,
217 TB *packB, TB *packB2,
218 TC *packC,
int ldpackc,
int padn,
// Move each packing pointer to this thread's slice: the A buffers are indexed
// by (jc_id * ic_nt + ic_id), the B buffers by jc_id.
223 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
224 + ( thread.ic_id ) * PACK_MC * KC;
225 packA2 += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
226 packB += ( thread.jc_id ) * PACK_NC * KC;
227 packB2 += ( thread.jc_id ) * PACK_NC;
// Ranges of the three blocked loops; only the 4th loop is given a thread id
// and thread count (ic_id / ic_nt).
229 auto loop6th = GetRange( 0, n, NC );
230 auto loop5th = GetRange( k_stra, k, KC );
231 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
233 for (
int jc = loop6th.beg();
235 jc += loop6th.inc() )
// Width of the current column block, clipped to the columns that remain.
237 auto jb = min( n - jc, NC );
239 for (
int pc = loop5th.beg();
241 pc += loop5th.inc() )
243 auto &pc_comm = *thread.pc_comm;
// Depth of the current k block, and whether this block reaches the end of k.
244 auto pb = min( k - pc, KC );
245 auto is_the_last_pc_iteration = ( pc + KC >= k );
// Ranges for packing B: column offsets j (step NR) and packed offsets jp
// (step PACK_NR), given thread.ic_jr and the pc communicator's thread count.
247 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
248 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
250 for (
int j = looppkB.beg(), jp = packpkB.beg();
252 j += looppkB.inc(), jp += packpkB.inc() )
// Pack min(jb - j, NR) columns by pb values from B, starting at offset pc with
// leading dimension k, selected through bmap[jc + j ...], into packB at
// jp * pb.
254 pack2D<true, PACK_NR>
256 min( jb - j, NR ), pb,
257 &B[ pc ], k, &bmap[ jc + j ], &packB[ jp * pb ]
// On the last pc block also gather the matching B2 entries (one per column).
261 if ( is_the_last_pc_iteration )
264 pack2D<true, PACK_NR>
266 min( jb - j, NR ), 1,
267 &B2[ 0 ], 1, &bmap[ jc + j ], &packB2[ jp * 1 ]
275 for (
int ic = loop4th.beg();
277 ic += loop4th.inc() )
279 auto &ic_comm = *thread.ic_comm;
// Height of the current row block, clipped to the rows that remain.
280 auto ib = min( m - ic, MC );
// Ranges for packing A: row offsets i (step MR) and packed offsets ip (step
// PACK_MR); the thread count passed to GetRange here is the literal 1.
282 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, 1 );
283 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, 1 );
285 for (
int i = looppkA.beg(), ip = packpkA.beg();
287 i += looppkA.inc(), ip += packpkA.inc() )
// Pack min(ib - i, MR) rows by pb values from A, starting at offset pc with
// leading dimension k, selected through amap[ic + i ...], into packA at
// ip * pb.
289 pack2D<true, PACK_MR>
291 min( ib - i, MR ), pb,
292 &A[ pc ], k, &amap[ ic + i ], &packA[ ip * pb ]
// On the last pc block also gather the matching A2 entries (one per row).
295 if ( is_the_last_pc_iteration )
297 pack2D<true, PACK_MR>
299 min( ib - i, MR ), 1,
300 &A2[ 0 ], 1, &amap[ ic + i ], &packA2[ ip * 1 ]
// Kernel call instantiated with SEMIRINGKERNEL; it receives the packC block
// at column jc, row ic (leading dimension ldpackc).
311 <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
318 packC + jc * ldpackc + ic,
// Kernel call instantiated with MICROKERNEL; besides the same packC block it
// receives the packed B panels, bmap at column jc and the D / I rows at ic.
// Presumably this is the call taken on the last pc block — TODO confirm, the
// selecting condition is not visible.
326 <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
332 packB, packB2, bmap + jc,
333 D + ic * ldr, I + ic * ldr, ldr,
334 packC + jc * ldpackc + ic,
// Top-level entry (function-name line not visible in this view). Visible
// steps: return early on an empty problem, read the KS_IC_NT environment
// variable, allocate the packing buffers and an m-by-n packC buffer, zero
// packC, then open an OpenMP parallel region in which each thread builds a
// Worker and calls either strassen::strassen_internal or the templated
// blocked driver.
// NOTE(review): braces, several declarations (str, ic_nt, k_stra, my_comm) and
// the callee name of the second call are missing from this view.
356 int MC,
int NC,
int KC,
int MR,
int NR,
357 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
359 typename SEMIRINGKERNEL,
typename MICROKERNEL,
360 typename TA,
typename TB,
typename TC,
typename TV>
362 int m,
int n,
int k,
int r,
363 TA *A, TA *A2,
int *amap,
364 TB *B, TB *B2,
int *bmap,
366 SEMIRINGKERNEL semiringkernel,
367 MICROKERNEL microkernel
372 int ldpackc = 0, padn = 0;
// Packing buffers start out null and are allocated after the early-return
// check below.
376 TA *packA_buff = NULL, *packA2_buff = NULL;
377 TB *packB_buff = NULL, *packB2_buff = NULL;
378 TC *packC_buff = NULL;
// Nothing to do when any dimension is zero.
381 if ( m == 0 || n == 0 || k == 0 )
return;
// Optional override of ic_nt from the KS_IC_NT environment variable, parsed
// as a base-10 integer.
384 str = getenv(
"KS_IC_NT" );
385 if ( str ) ic_nt = (int)strtol( str, NULL, 10 );
// Allocation requests: the A-side buffers are sized with ( PACK_MC + 1 ) *
// ic_nt, the B-side buffers with ( PACK_NC + 1 ); the "2" buffers use 1 in
// place of KC as the first argument.
391 packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * ic_nt,
sizeof(TA) );
392 packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( PACK_NC + 1 ),
sizeof(TB) );
393 packA2_buff = hmlp_malloc<ALIGN_SIZE, TA>( 1, ( PACK_MC + 1 ) * ic_nt,
sizeof(TA) );
394 packB2_buff = hmlp_malloc<ALIGN_SIZE, TB>( 1, ( PACK_NC + 1 ),
sizeof(TB) );
// Full m-by-n C buffer.
// NOTE(review): no release of these buffers is visible in this view — confirm
// they are freed before the function returns.
396 packC_buff = hmlp_malloc<ALIGN_SIZE, TC>( m, n,
sizeof(TC) );
// If k_stra has reached k, step it back by KC; presumably so the blocked
// driver's k loop, which starts at k_stra, still runs a final block — TODO
// confirm (the computation of k_stra is not visible).
406 if ( k_stra == k ) k_stra -= KC;
// Zero-initialise packC in parallel.
410 #pragma omp parallel for 411 for (
int i = 0; i < m * n; i ++ ) packC_buff[ i ] = 0.0;
// Parallel region sized by my_comm; each thread constructs its own Worker.
416 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 418 Worker thread( &my_comm );
// Strassen path, taken only when USE_STRASSEN is set and k exceeds KC; the
// semiring kernel is supplied for both kernel slots.
420 if ( USE_STRASSEN && k > KC )
422 strassen::strassen_internal
424 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
426 SEMIRINGKERNEL, SEMIRINGKERNEL,
430 HMLP_OP_T, HMLP_OP_N,
435 semiringkernel, semiringkernel,
// Call of the blocked driver with the full template argument list, both
// kernels and the buffers allocated above.
443 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
445 SEMIRINGKERNEL, MICROKERNEL,
453 semiringkernel, microkernel,
454 packA_buff, packA2_buff,
455 packB_buff, packB2_buff,
456 packC_buff, ldpackc, padn,
// Straightforward, timed implementation (function-name line not visible in
// this view; presumably the reference version used to check the blocked
// code). Visible phases:
//   1. gather the rows of A / B selected by amap / bmap into packA / packB;
//   2. form C[j*m + i] as the inner product of gathered row i of A and row j
//      of B;
//   3. rewrite C as -2 * C + A2[amap[i]] + B2[bmap[j]];
//   4. run heap_select on every column of C to update D and I.
// NOTE(review): braces, the declarations of i / j / p, D and I, and the
// preprocessor lines around phase 2 are missing from this view.
476 int m,
int n,
int k,
int r,
477 T *A, T *A2,
int *amap,
478 T *B, T *B2,
int *bmap,
// Per-phase wall-clock timers.
483 double beg, time_collect, time_dgemm, time_square, time_heap;
484 std::vector<T> packA, packB, C;
485 double fneg2 = -2.0, fzero = 0.0, fone = 1.0;
// Nothing to do when any dimension is zero.
488 if ( m == 0 || n == 0 || k == 0 )
return;
// NOTE(review): only packA and packB are resized in the visible lines —
// confirm C is sized to m * n before it is written below.
490 packA.resize( k * m );
491 packB.resize( k * n );
// Phase 1: gather. Row i of packA is row amap[i] of A; row j of packB is row
// bmap[j] of B (k values per row).
495 beg = omp_get_wtime();
496 #pragma omp parallel for private( p ) 497 for ( i = 0; i < m; i ++ ) {
498 for ( p = 0; p < k; p ++ ) {
499 packA[ i * k + p ] = A[ amap[ i ] * k + p ];
502 #pragma omp parallel for private( p ) 503 for ( j = 0; j < n; j ++ ) {
504 for ( p = 0; p < k; p ++ ) {
505 packB[ j * k + p ] = B[ bmap[ j ] * k + p ];
508 time_collect = omp_get_wtime() - beg;
// Phase 2: inner products. The first line below is a fragment of a call's
// argument list (presumably a GEMM wrapper); the triple loop after it computes
// the same products directly. Presumably only one of the two is compiled in —
// TODO confirm.
511 beg = omp_get_wtime();
517 fone, packA.data(), k,
522 #pragma omp parallel for private( i, p ) 523 for ( j = 0; j < n; j ++ ) {
524 for ( i = 0; i < m; i ++ ) {
525 C[ j * m + i ] = 0.0;
526 for ( p = 0; p < k; p ++ ) {
527 C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
532 time_dgemm = omp_get_wtime() - beg;
// Phase 3: C(i, j) <- -2 * C(i, j) + A2[amap[i]] + B2[bmap[j]]. If A2 / B2
// hold squared 2-norms this yields squared distances — assumption, confirm
// against the caller.
534 beg = omp_get_wtime();
535 #pragma omp parallel for private( i ) 536 for ( j = 0; j < n; j ++ )
538 for ( i = 0; i < m; i ++ )
540 C[ j * m + i ] *= -2.0;
541 C[ j * m + i ] += A2[ amap[ i ] ];
542 C[ j * m + i ] += B2[ bmap[ j ] ];
545 time_square = omp_get_wtime() - beg;
// Phase 4: for each column j, heap_select receives the m values of that
// column, amap, and the r-wide D / I slices at offset j * r. Dynamic
// scheduling is requested for this loop.
548 beg = omp_get_wtime();
549 #pragma omp parallel for schedule( dynamic ) 550 for ( j = 0; j < n; j ++ )
552 heap_select<T>( m, r, &C[ j * m ], amap, &D[ j * r ], &I[ j * r ] );
554 time_heap = omp_get_wtime() - beg;
562 #endif // define GSKNN_HXX Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88
Definition: thread.hpp:166