// NOTE(review): This file is a Doxygen source-listing dump; the original
// line numbers are fused into the code text and many lines (braces, full
// signatures) were dropped by the extractor.  Code tokens below are kept
// byte-identical; only comments are added.
//
// Fragment: macro kernel covering the 3rd (NR-wide column) and 2nd (MR-tall
// row) loops of a blocked GEMM-like computation, invoking SEMIRINGKERNEL on
// packed panels.  Template parameters are blocking/packing sizes.
31 #include <hmlp_internal.hpp> 32 #include <hmlp_base.hpp> 34 #include <KernelMatrix.hpp> 42 #define min( i, j ) ( (i)<(j) ? (i): (j) ) 49 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
50 typename SEMIRINGKERNEL,
51 typename TA,
typename TB,
typename TC,
typename TV>
// (function name and part of the parameter list elided by the extractor;
// ic/jc/pc are the enclosing 4th/6th/5th-loop block offsets)
55 int ic,
int jc,
int pc,
60 SEMIRINGKERNEL semiringkernel
// Split the n columns among the jr threads of the ic communicator, with
// NR-wide strides for the loop and PACK_NR-wide strides into the packed buffer.
65 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
66 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
67 auto loop2nd = GetRange( 0, m, MR );
68 auto pack2nd = GetRange( 0, m, PACK_MR );
// 3rd loop: j indexes the output column panel, jp the packed-B panel.
70 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
72 j += loop3rd.inc(), jp += pack3rd.inc() )
74 struct aux_s<TA, TB, TC, TV> aux;
// jb/ib are edge-tile sizes; may be smaller than NR/MR on the last iteration.
78 aux.jb = min( n - j, NR );
// 2nd loop: i indexes the output row panel, ip the packed-A panel.
80 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
82 i += loop2nd.inc(), ip += pack2nd.inc() )
84 aux.ib = min( m - i, MR );
// b_next advanced past the panels owned by the other jr threads —
// presumably a prefetch hint for the micro-kernel; confirm in aux_s.
87 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Micro-kernel output tile inside packC: row stride 1, column stride MR.
95 &packC[ j * ldc + i * NR ], 1, MR,
// Fragment: fused_macro_kernel — same 3rd/2nd-loop structure as the plain
// macro kernel above, but drives a MICROKERNEL that additionally receives
// per-row/per-column auxiliary vectors (packAh/packBh via aux.hi/aux.hj),
// i.e. the kernel-function evaluation is fused into the last rank-k update.
// (Byte-identical extraction fragment; only comments added.)
106 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
107 typename MICROKERNEL,
108 typename TA,
typename TB,
typename TC,
typename TV>
109 void fused_macro_kernel
114 int ic,
int jc,
int pc,
// Packed operand panels plus their auxiliary companions (A2/B2, Ah/Bh).
117 TA *packA, TA *packA2, TV *packAh,
118 TB *packB, TB *packB2, TV *packBh,
121 MICROKERNEL microkernel
// Column range split among jr threads; row range covered by this thread.
126 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
127 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
128 auto loop2nd = GetRange( 0, m, MR );
129 auto pack2nd = GetRange( 0, m, PACK_MR );
131 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
133 j += loop3rd.inc(), jp += pack3rd.inc() )
135 struct aux_s<TA, TB, TC, TV> aux;
139 aux.jb = min( n - j, NR );
141 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
143 i += loop2nd.inc(), ip += pack2nd.inc() )
145 aux.ib = min( m - i, MR );
148 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// hi/hj point at the packed per-row / per-column auxiliary values for this
// tile — used by variable-bandwidth kernels (see USE_VAR_BANDWIDTH below).
150 aux.hi = packAh + ip;
151 aux.hj = packBh + jp;
// Output tile in packC with leading dimension MR.
163 packC + j * ldc + i * NR, MR,
// Fragment: per-thread internal worker implementing the outer (6th/5th/4th)
// loops of a BLIS-style blocked algorithm.  Packs A/B panels (gathered via
// amap/bmap index arrays), runs the semiring macro kernel for intermediate
// rank-k updates, and on the last pc iteration runs the fused macro kernel
// and atomically accumulates packu into the global output u.
// (Byte-identical extraction fragment; only comments added.)
175 int MC,
int NC,
int KC,
int MR,
int NR,
176 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
177 bool USE_L2NORM,
bool USE_VAR_BANDWIDTH,
bool USE_STRASSEN,
178 typename SEMIRINGKERNEL,
typename MICROKERNEL,
179 typename TA,
typename TB,
typename TC,
typename TV>
// A2/B2 carry per-point auxiliary values (squared norms when USE_L2NORM —
// TODO confirm); amap/bmap gather-indices select the active rows/columns.
187 TA *A, TA *A2,
int *amap,
188 TB *B, TB *B2,
int *bmap,
190 SEMIRINGKERNEL semiringkernel,
191 MICROKERNEL microkernel,
194 TA *packA, TA *packA2, TA *packAh,
195 TB *packB, TB *packB2, TB *packBh,
197 TV *packC,
int ldpackc,
int padn
// Offset each shared pack buffer to this thread's private slice, derived
// from its position (jc_id, ic_id, jr_id) in the thread hierarchy.
200 packu += ( thread.jc_id * thread.ic_nt * thread.jr_nt ) * PACK_MC * KS_RHS
201 + ( thread.ic_id * thread.jr_nt + thread.jr_id ) * PACK_MC * KS_RHS;
202 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
203 + ( thread.ic_id ) * PACK_MC * KC;
204 packA2 += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
205 packAh += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
206 packB += ( thread.jc_id ) * pack_nc * KC;
207 packB2 += ( thread.jc_id ) * pack_nc;
208 packBh += ( thread.jc_id ) * pack_nc;
209 packw += ( thread.jc_id ) * pack_nc;
210 packC += ( thread.jc_id ) * ldpackc * padn;
// 6th loop: nc-wide column blocks split over jc threads.
// 5th loop: KC-deep rank-k blocks (serial).
// 4th loop: MC-tall row blocks split over ic threads.
212 auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
213 auto loop5th = GetRange( 0, k, KC );
214 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
216 for (
int jc = loop6th.beg();
218 jc += loop6th.inc() )
220 auto &jc_comm = *thread.jc_comm;
221 auto jb = min( n - jc, nc );
223 for (
int pc = loop5th.beg();
225 pc += loop5th.inc() )
227 auto &pc_comm = *thread.pc_comm;
228 auto pb = min( k - pc, KC );
// Final depth iteration triggers packing of the auxiliary data and the
// fused micro-kernel instead of the plain semiring kernel.
229 auto is_the_last_pc_iteration = ( pc + KC >= k );
// Pack B: NR-wide slivers of the current (jb x pb) block, gathered through
// bmap, cooperatively across the pc communicator's threads.
231 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
232 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
234 for (
int j = looppkB.beg(), jp = packpkB.beg();
236 j += looppkB.inc(), jp += packpkB.inc() )
238 pack2D<true, PACK_NR>
240 min( jb - j, NR ), pb,
241 &B[ pc ], k, &bmap[ jc + j ], &packB[ jp * pb ]
245 if ( is_the_last_pc_iteration )
// On the last depth block, also pack the weight vector w (via wmap) and
// the auxiliary B2 values.
247 pack2D<true, PACK_NR, true>
249 min( jb - j, NR ), 1,
250 &w[ 0 ], 1, &wmap[ jc + j ], &packw[ jp * 1 ]
255 pack2D<true, PACK_NR>
257 min( jb - j, NR ), 1,
258 &B2[ 0 ], 1, &bmap[ jc + j ], &packB2[ jp * 1 ]
262 if ( USE_VAR_BANDWIDTH )
// Variable-bandwidth kernels: pack per-column bandwidths kernel->hj.
264 pack2D<true, PACK_NR>
266 min( jb - j, NR ), 1,
267 kernel->hj, 1, &bmap[ jc + j ], &packBh[ jp * 1 ]
274 for (
int ic = loop4th.beg();
276 ic += loop4th.inc() )
278 auto &ic_comm = *thread.ic_comm;
279 auto ib = min( m - ic, MC );
// Pack A: MR-tall slivers of the (ib x pb) block, split over jr threads.
281 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
282 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
284 for (
int i = looppkA.beg(), ip = packpkA.beg();
286 i += looppkA.inc(), ip += packpkA.inc() )
288 pack2D<true, PACK_MR>
290 min( ib - i, MR ), pb,
291 &A[ pc ], k, &amap[ ic + i ], &packA[ ip * pb ]
294 if ( is_the_last_pc_iteration )
298 pack2D<true, PACK_MR>
300 min( ib - i, MR ), 1,
301 &A2[ 0 ], 1, &amap[ ic + i ], &packA2[ ip * 1 ]
305 if ( is_the_last_pc_iteration )
// (line above reads "if ( USE_VAR_BANDWIDTH )" in the original listing)
307 pack2D<true, PACK_MR>
309 min( ib - i, MR ), 1,
310 kernel->hi, 1, &amap[ ic + i ], &packAh[ ip * 1 ]
// Zero this thread's packu accumulator before the fused kernel writes it.
316 if ( is_the_last_pc_iteration )
318 for (
auto i = 0, ip = 0; i < ib; i += MR, ip += PACK_MR )
320 for (
auto ir = 0; ir < min( ib - i, MR ); ir ++ )
322 packu[ ip + ir ] = 0.0;
// Last depth block: fused macro kernel (kernel evaluation + weighting);
// otherwise: plain semiring macro kernel accumulating into packC.
329 if ( is_the_last_pc_iteration )
332 <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
339 packA, packA2, packAh,
340 packB, packB2, packBh,
// Row count rounded up to a whole number of MR-tall micro-panels.
343 ( ( ib - 1 ) / MR + 1 ) * MR,
350 <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
358 ( ( ib - 1 ) / MR + 1 ) * MR,
// Scatter-accumulate packu into the global output u through umap; the
// atomic protects against concurrent writes from other jc/ic teams.
364 if ( is_the_last_pc_iteration )
366 for (
auto i = 0, ip = 0; i < ib; i += MR, ip += PACK_MR )
368 for (
auto ir = 0; ir < min( ib - i, MR ); ir ++ )
370 TC *uptr = &( u[ umap[ ic + i + ir ] ] );
371 #pragma omp atomic update // concurrent write 372 *uptr += packu[ ip + ir ];
// Fragment: top-level driver.  Reads thread counts from environment
// variables, allocates aligned pack buffers sized for all thread teams,
// then launches an OpenMP parallel region in which each Worker runs the
// internal blocked algorithm above.
// (Byte-identical extraction fragment; only comments added.)
391 int MC,
int NC,
int KC,
int MR,
int NR,
392 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
393 bool USE_L2NORM,
bool USE_VAR_BANDWIDTH,
bool USE_STRASSEN,
394 typename SEMIRINGKERNEL,
typename MICROKERNEL,
395 typename TA,
typename TB,
typename TC,
typename TV>
401 TA *A, TA *A2,
int *amap,
402 TB *B, TB *B2,
int *bmap,
404 SEMIRINGKERNEL semiringkernel,
405 MICROKERNEL microkernel
// Defaults: single-threaded in every loop dimension unless overridden.
408 int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
409 int ldpackc = 0, padn = 0, nc = NC, pack_nc = PACK_NC;
412 TC *packu_buff = NULL;
413 TA *packA_buff = NULL, *packA2_buff = NULL, *packAh_buff = NULL;
414 TB *packB_buff = NULL, *packB2_buff = NULL, *packBh_buff = NULL;
415 TC *packw_buff = NULL;
416 TV *packC_buff = NULL;
// Degenerate problem: nothing to compute.
419 if ( m == 0 || n == 0 || k == 0 )
return;
// Thread counts per loop come from KS_*_NT environment variables.
422 jc_nt = hmlp_read_nway_from_env(
"KS_JC_NT" );
423 ic_nt = hmlp_read_nway_from_env(
"KS_IC_NT" );
424 jr_nt = hmlp_read_nway_from_env(
"KS_JR_NT" );
// Shrink nc so each jc thread gets a whole number of NR-wide panels.
428 nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
429 pack_nc = ( nc / NR ) * PACK_NR;
// Aligned pack buffers; "+ 1" rows of slack per team — presumably padding
// to avoid overrun/false sharing, TODO confirm against hmlp_malloc docs.
434 packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,
sizeof(TA) );
435 packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,
sizeof(TB) );
436 packu_buff = hmlp_malloc<ALIGN_SIZE, TC>( 1, ( PACK_MC + 1 ) * jc_nt * ic_nt * jr_nt,
sizeof(TC) );
437 packw_buff = hmlp_malloc<ALIGN_SIZE, TC>( 1, ( pack_nc + 1 ) * jc_nt,
sizeof(TC) );
// Auxiliary (A2/B2) buffers, one value per packed row/column.
443 packA2_buff = hmlp_malloc<ALIGN_SIZE, TA>( 1, ( PACK_MC + 1 ) * jc_nt * ic_nt,
sizeof(TA) );
444 packB2_buff = hmlp_malloc<ALIGN_SIZE, TB>( 1, ( pack_nc + 1 ) * jc_nt,
sizeof(TB) );
447 if ( USE_VAR_BANDWIDTH )
// Per-point bandwidth buffers only needed for variable-bandwidth kernels.
449 packAh_buff = hmlp_malloc<ALIGN_SIZE, TA>( 1, ( PACK_MC + 1 ) * jc_nt * ic_nt,
sizeof(TA) );
450 packBh_buff = hmlp_malloc<ALIGN_SIZE, TB>( 1, ( pack_nc + 1 ) * jc_nt,
sizeof(TB) );
// packC leading dimension rounded up to whole PACK_MR micro-panels.
456 ldpackc = ( ( m - 1 ) / PACK_MR + 1 ) * PACK_MR;
458 if ( n < nc ) padn = ( ( n - 1 ) / PACK_NR + 1 ) * PACK_NR ;
459 packC_buff = hmlp_malloc<ALIGN_SIZE, TV>( ldpackc, padn * jc_nt,
sizeof(TV) );
// Parallel region: one Worker object per thread, bound to the communicator.
466 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 468 Worker thread( &my_comm );
// Strassen path is declared but not implemented — runtime warning only.
472 printf(
"gsks: strassen algorithms haven't been implemented." );
477 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
478 USE_L2NORM, USE_VAR_BANDWIDTH, USE_STRASSEN,
479 SEMIRINGKERNEL, MICROKERNEL,
489 semiringkernel, microkernel,
492 packA_buff, packA2_buff, packAh_buff,
493 packB_buff, packB2_buff, packBh_buff,
495 packC_buff, ldpackc, padn
// Fragment: reference (naive) implementation used to validate the blocked
// kernel.  Gathers A/B/u/w through the index maps, forms C = scale * A^T B
// with plain triple loops (or xgemm — the BLAS wrapper is referenced in the
// tooltips below), applies the kernel function elementwise, then computes
// u += C * w and scatters the result back through umap.
// (Byte-identical extraction fragment; only comments added.)
522 T *A, T *A2,
int *amap,
523 T *B, T *B2,
int *bmap,
528 T rank_k_scale, fone = 1.0, fzero = 0.0;
529 std::vector<T> packA, packB, C, packu, packw;
// Degenerate problem: nothing to compute.
532 if ( m == 0 || n == 0 || k == 0 )
return;
534 packA.resize( k * m );
535 packB.resize( k * n );
// rank_k_scale presumably depends on the kernel type (e.g. Gaussian uses a
// -0.5 factor below) — the switch body is elided by the extractor.
540 switch ( kernel->type )
545 case GAUSSIAN_VAR_BANDWIDTH:
// Gather the m active columns of A (column-major, stride k) and the
// KS_RHS right-hand sides of u, via amap/umap.
555 #pragma omp parallel for 556 for (
int i = 0; i < m; i ++ )
558 for (
int p = 0; p < k; p ++ )
560 packA[ i * k + p ] = A[ amap[ i ] * k + p ];
562 for (
int p = 0; p < KS_RHS; p ++ )
564 packu[ p * m + i ] = u[ umap[ i ] * KS_RHS + p ];
// Same gather for the n active columns of B and the weights w.
571 #pragma omp parallel for 572 for (
int j = 0; j < n; j ++ )
574 for (
int p = 0; p < k; p ++ )
576 packB[ j * k + p ] = B[ bmap[ j ] * k + p ];
578 for (
int p = 0; p < KS_RHS; p ++ )
580 packw[ p * n + j ] = w[ wmap[ j ] * KS_RHS + p ];
// (xgemm call site; argument list elided by the extractor)
592 rank_k_scale, packA.data(), k,
// Fallback triple loop: C(i,j) = sum_p A(p,i) * B(p,j), column-major C.
597 #pragma omp parallel for 598 for (
int j = 0; j < n; j ++ )
600 for (
int i = 0; i < m; i ++ )
602 C[ j * m + i ] = 0.0;
603 for (
int p = 0; p < k; p ++ )
605 C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
609 #pragma omp parallel for 610 for (
int j = 0; j < n; j ++ )
612 for (
int i = 0; i < m; i ++ )
614 C[ j * m + i ] *= rank_k_scale;
// Elementwise kernel evaluation on C.
619 switch ( kernel->type )
// Gaussian: C holds scal * ( ||a||^2 - 2 a.b + ||b||^2 ) after adding the
// squared norms A2/B2, then exponentiate.
623 #pragma omp parallel for 624 for (
int j = 0; j < n; j ++ )
626 for (
int i = 0; i < m; i ++ )
628 C[ j * m + i ] += A2[ amap[ i ] ];
629 C[ j * m + i ] += B2[ bmap[ j ] ];
630 C[ j * m + i ] *= kernel->scal;
632 for (
int i = 0; i < m; i ++ )
634 C[ j * m + i ] = exp( C[ j * m + i ] );
639 case GAUSSIAN_VAR_BANDWIDTH:
// Variable-bandwidth Gaussian: per-point bandwidths hi[i] * hj[j] scale
// the squared distance before exponentiation.
641 #pragma omp parallel for 642 for (
int j = 0; j < n; j ++ )
644 for (
int i = 0; i < m; i ++ )
646 C[ j * m + i ] += A2[ amap[ i ] ];
647 C[ j * m + i ] += B2[ bmap[ j ] ];
648 C[ j * m + i ] *= -0.5;
649 C[ j * m + i ] *= kernel->hi[ i ];
650 C[ j * m + i ] *= kernel->hj[ j ];
652 for (
int i = 0; i < m; i ++ )
654 C[ j * m + i ] = exp( C[ j * m + i ] );
// (xgemv/xgemm accumulation call site; arguments elided)
673 fone, packu.data(), m
// Fallback: packu(i, j) += sum_p C(i, p) * packw(p, j).
676 #pragma omp parallel for 677 for (
int i = 0; i < m; i ++ )
679 for (
int j = 0; j < nrhs; j ++ )
681 for (
int p = 0; p < n; p ++ )
683 packu[ j * m + i ] += C[ p * m + i ] * packw[ j * n + p ];
// Scatter the accumulated potentials back to u through umap.
692 #pragma omp parallel for 693 for (
int i = 0; i < m; i ++ )
695 for (
int p = 0; p < KS_RHS; p ++ )
697 u[ umap[ i ] * KS_RHS + p ] = packu[ p * m + i ];
707 #endif // define GSKS_HXX Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
Definition: KernelMatrix.hpp:54
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88
Definition: thread.hpp:166