// NOTE(review): this span is a doxygen source-listing extraction. The leading
// integers are line numbers of the original header and many original lines
// are elided, so the code below is not compilable as shown; edit the real
// header instead.
//
// STRAPRIM / STRAPRIM_MAP expand one Strassen product into a call of a
// templated routine (its name is on an elided line) with explicit template
// arguments <MC, NC, KC, MR, NR, PACK_*, ALIGN_SIZE, STRA_SEMIRINGKERNEL,
// STRA_MICROKERNEL, ...>, forwarding the output pair (C0, C1, ldc, alpha0,
// alpha1) and both kernels. STRAPRIM_MAP additionally forwards
// (A0, A1, lda, gamma, amap) and (B0, B1, ldb, delta, bmap).
25 #define STRAPRIM( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \ 27 <MC, NC, KC, MR, NR, \ 28 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \ 30 STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \ 38 C0, C1, ldc, alpha0, alpha1, \ 39 stra_semiringkernel, stra_microkernel, \ 45 #define STRAPRIM_MAP( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \ 47 <MC, NC, KC, MR, NR, \ 48 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \ 50 STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \ 56 A0, A1, lda, gamma, amap, \ 57 B0, B1, ldb, delta, bmap, \ 58 C0, C1, ldc, alpha0, alpha1, \ 59 stra_semiringkernel, stra_microkernel, \ 66 #include <hmlp_internal.hpp> 67 #include <hmlp_base.hpp> 80 int KC,
// The template parameter list that starts at the end of the line above
// belongs to a macro-kernel fragment. It walks an m x n block of C in
// NR-wide column tiles and MR-tall row tiles and adds each tile's product
// into C0 (scaled by alpha0) and, when alpha1 != 0 and C1 != NULL, into C1
// (scaled by alpha1).
int MR,
int NR,
int PACK_MR,
int PACK_NR,
81 typename SEMIRINGKERNEL,
82 typename TA,
typename TB,
typename TC,
typename TV>
86 int ic,
int jc,
int pc,
// C0/C1 are addressed as X[ j * ldc + i ] (column offset times ldc plus row).
90 TV *C0, TV *C1,
int ldc, TV alpha0, TV alpha1,
91 SEMIRINGKERNEL semiringkernel
// 3rd loop: columns of the block in steps of NR, split across
// ic_comm.GetNumThreads() threads by thread.jr_id; jp is the matching
// offset advanced in PACK_NR-sized steps.
96 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
97 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
// 2nd loop: rows in steps of MR (no thread split given here); ip is the
// matching offset advanced in PACK_MR-sized steps.
98 auto loop2nd = GetRange( 0, m, MR );
99 auto pack2nd = GetRange( 0, m, PACK_MR );
101 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
103 j += loop3rd.inc(), jp += pack3rd.inc() )
105 struct aux_s<TA, TB, TC, TV> aux;
// Tile width; smaller than NR only on the right fringe.
109 aux.jb = std::min( n - j, NR );
111 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
113 i += loop2nd.inc(), ip += pack2nd.inc() )
// Tile height; smaller than MR only on the bottom fringe.
115 aux.ib = std::min( m - i, MR );
// Steps aux.b_next past the columns covered by all ic_comm threads
// (presumably a prefetch pointer into packed B -- TODO confirm).
118 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
// Full MR x NR tile: the kernel (call on elided lines) accumulates directly
// into C with leading dimension ldc.
121 if ( aux.jb == NR && aux.ib == MR )
// Single target: alpha1 == 0 or C1 absent, so only C0 (scaled by alpha0).
124 if ( alpha1 == 0 || C1 == NULL ) {
125 TV *c_list[1], alpha_list[1];
126 c_list[0] = &C0[ j * ldc + i ];
127 alpha_list[0] = alpha0;
134 1, c_list, ldc, alpha_list,
// Two targets: the same tile product goes to C0 (alpha0) and C1 (alpha1).
140 TV *c_list[2], alpha_list[2];
141 c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ];
142 alpha_list[0] = alpha0; alpha_list[1] = alpha1;
148 2, c_list, ldc, alpha_list,
// Fringe tile: the kernel writes into a zero-initialised MR x NR scratch
// buffer whose leading dimension is MR (the "MR" argument below) ...
170 TV ctmp[ MR * NR ] = { (TV)0.0 };
172 TV *c_list[1], alpha_list[1];
182 1, c_list, MR, alpha_list,
// ... and only the valid aux.ib x aux.jb part is added back:
// C0 += alpha0 * ctmp, plus C1 += alpha1 * ctmp when alpha1 != 0 and
// C1 != NULL.
200 for (
auto jj = 0; jj < aux.jb; jj ++ )
202 for (
auto ii = 0; ii < aux.ib; ii ++ )
204 C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ];
206 if ( alpha1 != 0 && C1 != NULL ) {
207 C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ];
// NOTE(review): doxygen extraction fragment; leading integers are original
// line numbers and many lines (including the routine's name) are elided.
//
// Strassen primitive: computes one product from the operand pairs
// (A0, A1, gamma) and (B0, B1, delta) and accumulates it into C0 (alpha0)
// and optionally C1 (alpha1). It uses a 6th/5th/4th loop blocking with
// packed A/B panels. How A0/A1 and B0/B1 are combined happens inside pack2D,
// which is not shown (presumably A0 + gamma*A1 and B0 + delta*B1 -- TODO
// confirm).
389 int MC,
int NC,
int KC,
int MR,
int NR,
390 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
392 typename STRA_SEMIRINGKERNEL,
typename STRA_MICROKERNEL,
393 typename TA,
typename TB,
typename TC,
typename TV>
397 hmlpOperation_t transA, hmlpOperation_t transB,
399 TA *A0, TA *A1,
int lda, TA gamma,
400 TB *B0, TB *B1,
int ldb, TB delta,
401 TV *C0, TV *C1,
int ldc, TV alpha0, TV alpha1,
402 STRA_SEMIRINGKERNEL stra_semiringkernel,
403 STRA_MICROKERNEL stra_microkernel,
// Offset into the shared packing buffers: one PACK_MC x KC slice of packA per
// (jc_id, ic_id) pair and one pack_nc x KC slice of packB per jc_id.
411 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
412 + ( thread.ic_id ) * PACK_MC * KC;
413 packB += ( thread.jc_id ) * pack_nc * KC;
// 6th loop: n in nc-wide blocks split over jc_nt threads.
// 5th loop: k in KC-deep blocks, not split.
// 4th loop: m in MC-tall blocks split over ic_nt threads.
415 auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
416 auto loop5th = GetRange( 0, k, KC );
417 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
419 for (
int jc = loop6th.beg();
421 jc += loop6th.inc() )
423 auto &jc_comm = *thread.jc_comm;
424 auto jb = std::min( n - jc, nc );
426 for (
int pc = loop5th.beg();
428 pc += loop5th.inc() )
430 auto &pc_comm = *thread.pc_comm;
431 auto pb = std::min( k - pc, KC );
// True on the final KC block of k; consumed on elided lines.
432 auto is_the_last_pc_iteration = ( pc + KC >= k );
// NR-wide B panels of this jb x pb block, split over pc_comm's threads
// by thread.ic_jr.
433 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
434 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
436 for (
int j = looppkB.beg(), jp = packpkB.beg();
438 j += looppkB.inc(), jp += packpkB.inc() )
// Pack B. The single-operand pack2D is used when delta == 0 or B1 is NULL;
// otherwise B0 and B1 are packed together with weight delta. The address
// pattern and the first pack2D flag depend on transB.
442 if ( transB == HMLP_OP_N )
445 if ( delta == 0 || B1 == NULL ) {
446 pack2D<true, PACK_NR>
448 std::min( jb - j, NR ), pb,
449 &B0[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
453 pack2D<true, PACK_NR>
455 std::min( jb - j, NR ), pb,
456 &B0[ ( jc + j ) * ldb + pc ], &B1[ ( jc + j ) * ldb + pc ], ldb, delta, &packB[ jp * pb ]
464 if ( delta == 0 || B1 == NULL ) {
465 pack2D<false, PACK_NR>
467 std::min( jb - j, NR ), pb,
468 &B0[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
475 pack2D<false, PACK_NR>
477 std::min( jb - j, NR ), pb,
478 &B0[ pc * ldb + ( jc + j ) ], &B1[ pc * ldb + ( jc + j ) ], ldb, delta, &packB[ jp * pb ]
494 for (
int ic = loop4th.beg();
496 ic += loop4th.inc() )
498 auto &ic_comm = *thread.ic_comm;
499 auto ib = std::min( m - ic, MC );
// MR-tall A panels of this ib x pb block, split by jr_id over jr_nt threads.
500 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
501 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
503 for (
int i = looppkA.beg(), ip = packpkA.beg();
505 i += looppkA.inc(), ip += packpkA.inc() )
// Pack A with the same single/dual scheme (gamma, A1). Note that the pack2D
// flag is false for HMLP_OP_N here, whereas it is true for B.
510 if ( transA == HMLP_OP_N )
513 if ( gamma == 0 || A1 == NULL ) {
514 pack2D<false, PACK_MR>
516 std::min( ib - i, MR ), pb,
517 &A0[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
522 pack2D<false, PACK_MR>
524 std::min( ib - i, MR ), pb,
525 &A0[ pc * lda + ( ic + i ) ], &A1[ pc * lda + ( ic + i ) ], lda, gamma, &packA[ ip * pb ]
534 if ( gamma == 0 || A1 == NULL ) {
535 pack2D<true, PACK_MR>
537 std::min( ib - i, MR ), pb,
538 &A0[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
541 pack2D<true, PACK_MR>
543 std::min( ib - i, MR ), pb,
544 &A0[ ( ic + i ) * lda + pc ], &A1[ ( ic + i ) * lda + pc ], lda, gamma, &packA[ ip * pb ]
// Macro-kernel call (template args <KC, MR, NR, PACK_MR, PACK_NR, ...>).
// When only C0 is updated, C1 is passed as NULL with alpha1 = 0. Otherwise
// the C1 block at (ic, jc) is passed along with alpha0 and alpha1.
604 if ( alpha1 == 0 || C1 == NULL )
620 <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
628 NULL, ldc, alpha0, 0,
638 <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
646 C1 + jc * ldc + ic, ldc, alpha0, alpha1,
// NOTE(review): doxygen extraction fragment; leading integers are original
// line numbers and many lines (including the routine's name) are elided.
//
// Map variant of the Strassen primitive. The loop structure matches the
// non-map version, but packing reads operand indices through amap/bmap.
// Every pack2D call receives the base &X0[ pc ] plus &amap[ ic + i ] or
// &bmap[ jc + j ] instead of a computed row/column offset, so the index
// arithmetic is left to pack2D (not shown).
667 int MC,
int NC,
int KC,
int MR,
int NR,
668 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
670 typename STRA_SEMIRINGKERNEL,
typename STRA_MICROKERNEL,
671 typename TA,
typename TB,
typename TC,
typename TV>
675 hmlpOperation_t transA, hmlpOperation_t transB,
677 TA *A0, TA *A1,
int lda, TA gamma,
int *amap,
678 TB *B0, TB *B1,
int ldb, TB delta,
int *bmap,
679 TV *C0, TV *C1,
int ldc, TV alpha0, TV alpha1,
680 STRA_SEMIRINGKERNEL stra_semiringkernel,
681 STRA_MICROKERNEL stra_microkernel,
// Per-thread offsets into the shared packing buffers (same scheme as the
// non-map primitive).
689 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
690 + ( thread.ic_id ) * PACK_MC * KC;
691 packB += ( thread.jc_id ) * pack_nc * KC;
// 6th loop: n in nc-wide blocks split over jc_nt threads.
// 5th loop: k in KC-deep blocks, not split.
// 4th loop: m in MC-tall blocks split over ic_nt threads.
693 auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
694 auto loop5th = GetRange( 0, k, KC );
695 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
697 for (
int jc = loop6th.beg();
699 jc += loop6th.inc() )
701 auto &jc_comm = *thread.jc_comm;
702 auto jb = std::min( n - jc, nc );
704 for (
int pc = loop5th.beg();
706 pc += loop5th.inc() )
708 auto &pc_comm = *thread.pc_comm;
709 auto pb = std::min( k - pc, KC );
// True on the final KC block of k; consumed on elided lines.
710 auto is_the_last_pc_iteration = ( pc + KC >= k );
711 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
712 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
714 for (
int j = looppkB.beg(), jp = packpkB.beg();
716 j += looppkB.inc(), jp += packpkB.inc() )
// Pack B through bmap. The single-operand form is used when delta == 0 or
// B1 is NULL; otherwise B0/B1 are packed together with weight delta.
720 if ( transB == HMLP_OP_N )
723 if ( delta == 0 || B1 == NULL ) {
725 pack2D<true, PACK_NR>
727 std::min( jb - j, NR ), pb,
728 &B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ]
731 pack2D<true, PACK_NR>
733 std::min( jb - j, NR ), pb,
734 &B0[ pc ], &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ]
741 if ( delta == 0 || B1 == NULL ) {
742 pack2D<false, PACK_NR>
744 std::min( jb - j, NR ), pb,
745 &B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ]
748 pack2D<false, PACK_NR>
750 std::min( jb - j, NR ), pb,
751 &B0[ pc ], &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ]
762 for (
int ic = loop4th.beg();
764 ic += loop4th.inc() )
766 auto &ic_comm = *thread.ic_comm;
767 auto ib = std::min( m - ic, MC );
768 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
769 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
771 for (
int i = looppkA.beg(), ip = packpkA.beg();
773 i += looppkA.inc(), ip += packpkA.inc() )
// Pack A through amap, using the same single/dual (gamma, A1) scheme.
779 if ( transA == HMLP_OP_N )
782 if ( gamma == 0 || A1 == NULL ) {
783 pack2D<false, PACK_MR>
785 std::min( ib - i, MR ), pb,
786 &A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ]
789 pack2D<false, PACK_MR>
791 std::min( ib - i, MR ), pb,
792 &A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ]
800 if ( gamma == 0 || A1 == NULL ) {
801 pack2D<true, PACK_MR>
803 std::min( ib - i, MR ), pb,
804 &A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ]
807 pack2D<true, PACK_MR>
809 std::min( ib - i, MR ), pb,
810 &A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ]
// Macro-kernel call. C1 is passed as NULL with alpha1 = 0 when only C0 is
// updated; otherwise the C1 block at (ic, jc) is passed as well.
870 if ( alpha1 == 0 || C1 == NULL )
886 <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
894 NULL, ldc, alpha0, 0,
904 <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
912 C1 + jc * ldc + ic, ldc, alpha0, alpha1,
// NOTE(review): doxygen extraction fragment; leading integers are original
// line numbers and many lines are elided.
//
// Fringe ("peeling") update for a Strassen split. Three xgemm calls with
// alpha = beta = 1.0 add the parts of C += op(A) * op(B) that lie outside the
// ms x ns x ks core. ms/ns/ks and mr/nr/kr are computed on elided lines,
// presumably the parts of m/n/k divisible by dim1/dim2/dim3 and their
// remainders -- TODO confirm.
926 template<
typename TA,
typename TB,
typename TV>
927 void hmlp_dynamic_peeling
929 hmlpOperation_t transA, hmlpOperation_t transB,
934 int dim1,
int dim2,
int dim3
// Character transpose flags handed to xgemm (values set on elided lines).
948 char transA_val, transB_val;
949 char *char_transA = &transA_val, *char_transB = &transB_val;
// (1) k-fringe: C[0:ms, 0:ns] += A[0:ms, ks:ks+kr] * B[ks:ks+kr, 0:ns].
// The offset of column (or, when transposed, row) ks is chosen per transA/transB.
961 if ( transA == HMLP_OP_N ) {
962 A_extra = &A[ 0 + ks * lda ];
965 A_extra = &A[ 0 * lda + ks ];
970 if ( transB == HMLP_OP_N ) {
971 B_extra = &B[ ks + 0 * ldb ];
974 B_extra = &B[ ks * ldb + 0 ];
979 C_extra = &C[ 0 + 0 * ldc ];
980 if ( ms > 0 && ns > 0 )
983 xgemm( char_transA, char_transB, ms, ns, kr, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc );
// (2) n-fringe: C[0:m, ns:ns+nr] += A * B[:, ns:ns+nr] over the full m and k.
993 if ( transA == HMLP_OP_N ) {
1000 if ( transB == HMLP_OP_N ) {
1001 B_extra = &B[ 0 + ns * ldb ];
1004 B_extra = &B[ 0 * ldb + ns ];
1011 C_extra = &C[ 0 + ns * ldc ];
1013 xgemm( char_transA, char_transB, m, nr, k, 1.0, A, lda, B_extra, ldb, 1.0, C_extra, ldc );
// (3) m-fringe: C[ms:ms+mr, 0:ns] += A[ms:ms+mr, :] * B[:, 0:ns] over the
// full k. Together with (1) and (2) this covers everything outside the core.
1025 if ( transA == HMLP_OP_N ) {
1028 A_extra = &A[ ms + 0 * lda ];
1034 A_extra = &A[ ms * lda + 0 ];
1041 if ( transB == HMLP_OP_N ) {
1042 B_extra = &B[ 0 + 0 * ldb ];
1048 B_extra = &B[ 0 * ldb + 0 ];
1056 TV *C_extra = &C[ ms + 0 * ldc ];
1060 xgemm( char_transA, char_transB, mr, ns, k, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc );
// NOTE(review): doxygen extraction fragment; leading integers are original
// line numbers and many lines are elided.
//
// One level of Strassen's algorithm for the map variant (A and B are
// addressed through amap/bmap). The even-sized core (ms x ns x ks) is split
// into 2 x 2 quadrants, and seven STRAPRIM_MAP products update C. The odd
// fringe is then handled by hmlp_dynamic_peeling.
1067 int MC,
int NC,
int KC,
int MR,
int NR,
1068 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
1070 typename STRA_SEMIRINGKERNEL,
typename STRA_MICROKERNEL,
1071 typename TA,
typename TB,
typename TC,
typename TV>
1072 void strassen_internal
1075 hmlpOperation_t transA, hmlpOperation_t transB,
1076 int m,
int n,
int k,
1077 TA *A,
int lda,
int *amap,
1078 TB *B,
int ldb,
int *bmap,
1080 STRA_SEMIRINGKERNEL stra_semiringkernel,
1081 STRA_MICROKERNEL stra_microkernel,
1082 int nc,
int pack_nc,
// Split each dimension into an even core (md/kd/nd) and a 0-or-1 remainder.
1092 mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 );
1093 md = m - mr, kd = k - kr, nd = n - nr;
// Keep the core extents for the peeling step.
1096 ms=md, ks=kd, ns=nd;
// Quadrant pointers; assigned on elided lines (presumably via
// hmlp_acquire_mpart -- TODO confirm).
1097 TA *A00, *A01, *A10, *A11;
1103 TB *B00, *B01, *B10, *B11;
1109 TV *C00, *C01, *C10, *C11;
// From here on md/kd/nd are the quadrant sizes.
1115 md = md / 2, kd = kd / 2, nd = nd / 2;
// Each call is STRAPRIM_MAP(A0, A1, gamma, B0, B1, delta, C0, C1, alpha0,
// alpha1). It forms M = (A0 (+) gamma*A1) * (B0 (+) delta*B1), assuming
// pack2D combines the pairs this way (TODO confirm), then applies
// C0 += alpha0*M and C1 += alpha1*M. Read that way:
//   M1 = (A00+A11)(B00+B11): C00 += M1, C11 += M1
//   M2 = (A10+A11) B00     : C10 += M2, C11 -= M2
//   M3 = A00 (B01-B11)     : C01 += M3, C11 += M3
//   M4 = A11 (B10-B00)     : C00 += M4, C10 += M4
//   M5 = (A00+A01) B11     : C00 -= M5, C01 += M5
//   M6 = (A10-A00)(B00+B01): C11 += M6
//   M7 = (A01-A11)(B10+B11): C00 += M7
1118 STRAPRIM_MAP( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 );
1120 STRAPRIM_MAP( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 )
1122 STRAPRIM_MAP( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 )
1124 STRAPRIM_MAP( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 )
1126 STRAPRIM_MAP( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 )
1128 STRAPRIM_MAP( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 )
1130 STRAPRIM_MAP( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 )
// The odd fringe is computed by OpenMP thread 0 only; any synchronization
// with the other threads is on elided lines.
// NOTE(review): unlike the products above, this call receives A and B but
// not amap/bmap. For non-identity maps the peeled rows/columns therefore
// appear to read the unmapped operands -- verify.
1132 if ( omp_get_thread_num() == 0 ) {
1133 hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 );
// NOTE(review): doxygen extraction fragment; leading integers are original
// line numbers and many lines are elided.
//
// One level of Strassen's algorithm (non-map variant). The even-sized core
// (ms x ns x ks) is split into 2 x 2 quadrants, seven STRAPRIM products
// update C, and hmlp_dynamic_peeling then handles the odd fringe.
1139 int MC,
int NC,
int KC,
int MR,
int NR,
1140 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
1142 typename STRA_SEMIRINGKERNEL,
typename STRA_MICROKERNEL,
1143 typename TA,
typename TB,
typename TC,
typename TV>
1144 void strassen_internal
1147 hmlpOperation_t transA, hmlpOperation_t transB,
1148 int m,
int n,
int k,
1152 STRA_SEMIRINGKERNEL stra_semiringkernel,
1153 STRA_MICROKERNEL stra_microkernel,
1154 int nc,
int pack_nc,
// Split each dimension into an even core (md/kd/nd) and a 0-or-1 remainder.
1164 mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 );
1165 md = m - mr, kd = k - kr, nd = n - nr;
// Keep the core extents for the peeling step.
1168 ms=md, ks=kd, ns=nd;
// Quadrant pointers; assigned on elided lines.
1169 TA *A00, *A01, *A10, *A11;
1175 TB *B00, *B01, *B10, *B11;
1181 TV *C00, *C01, *C10, *C11;
// From here on md/kd/nd are the quadrant sizes.
1187 md = md / 2, kd = kd / 2, nd = nd / 2;
// STRAPRIM(A0, A1, gamma, B0, B1, delta, C0, C1, alpha0, alpha1): the same
// seven products as the map variant.
//   M1 = (A00+A11)(B00+B11) -> C00, C11
//   M2 = (A10+A11) B00      -> C10, -C11
//   M3 = A00 (B01-B11)      -> C01, C11
//   M4 = A11 (B10-B00)      -> C00, C10
//   M5 = (A00+A01) B11      -> -C00, C01
//   M6 = (A10-A00)(B00+B01) -> C11
//   M7 = (A01-A11)(B10+B11) -> C00
// (This assumes pack2D forms X0 + coef*X1 -- TODO confirm.)
1190 STRAPRIM( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 );
1206 STRAPRIM( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 )
1209 STRAPRIM( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 )
1211 STRAPRIM( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 )
1213 STRAPRIM( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 )
1215 STRAPRIM( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 )
1217 STRAPRIM( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 )
// The odd fringe is computed by OpenMP thread 0 only; any synchronization
// with the other threads is on elided lines.
1224 if ( omp_get_thread_num() == 0 ) {
1225 hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 );
// NOTE(review): doxygen extraction fragment; leading integers are original
// line numbers and many lines (including the function name) are elided.
//
// Top-level driver. It:
//   1. reads the per-loop thread counts from the environment;
//   2. sizes the nc block to an NR-aligned share of n per jc thread;
//   3. allocates aligned packing buffers;
//   4. opens an OpenMP parallel region that calls a routine taking
//      <MC, NC, KC, MR, NR, PACK_*, ALIGN_SIZE, ...> (presumably
//      strassen_internal -- TODO confirm).
1236 int MC,
int NC,
int KC,
int MR,
int NR,
1237 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
1239 typename STRA_SEMIRINGKERNEL,
typename STRA_MICROKERNEL,
1240 typename TA,
typename TB,
typename TC,
typename TV>
1243 hmlpOperation_t transA, hmlpOperation_t transB,
1244 int m,
int n,
int k,
1248 STRA_SEMIRINGKERNEL stra_semiringkernel,
1249 STRA_MICROKERNEL stra_microkernel
// pc_nt is not reassigned in the visible lines.
1252 int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
1253 int nc = NC, pack_nc = PACK_NC;
1256 TA *packA_buff = NULL;
1257 TB *packB_buff = NULL;
// Empty problem: nothing to compute (this also guarantees n > 0 below).
1260 if ( m == 0 || n == 0 || k == 0 )
return;
// Thread counts for the jc / ic / jr loops come from these variables.
1263 jc_nt = hmlp_read_nway_from_env(
"KS_JC_NT" );
1264 ic_nt = hmlp_read_nway_from_env(
"KS_IC_NT" );
1265 jr_nt = hmlp_read_nway_from_env(
"KS_JR_NT" );
// nc = NR * ceil( n / ( NR * jc_nt ) ): each jc thread gets an NR-aligned
// share of n, and pack_nc is the matching packed width. This replaces the
// template default NC/PACK_NC set above.
1270 nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
1271 pack_nc = ( nc / NR ) * PACK_NR;
// Buffers sized for one packA slice per (jc, ic) thread and one packB slice
// per jc thread, each with one extra panel (PACK_MC + 1, pack_nc + 1).
// NOTE(review): release happens on elided lines; confirm both buffers are
// freed on every path.
1275 packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,
sizeof(TA) );
1276 packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,
sizeof(TB) );
// Each OpenMP thread wraps the shared communicator in a Worker.
1281 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 1283 Worker thread( &my_comm );
1286 <MC, NC, KC, MR, NR,
1287 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
1289 STRA_SEMIRINGKERNEL, STRA_MICROKERNEL,
1298 stra_semiringkernel, stra_microkernel,
1312 #endif // define STRASSEN_HPP Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void hmlp_acquire_mpart(hmlpOperation_t transX, int m, int n, T *src_buff, int lda, int x, int y, int i, int j, T **dst_buff)
Split into m x n, get the subblock starting from ith row and jth column. (for STRASSEN) ...
Definition: util.hpp:143
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
Definition: thread.hpp:166