28 #include <hmlp_internal.hpp> 29 #include <hmlp_base.hpp> 42 int KC,
int MR,
int NR,
int PACK_MR,
int PACK_NR,
43 typename SEMIRINGKERNEL,
44 typename TA,
typename TB,
typename TC,
typename TV>
48 int ic,
int jc,
int pc,
53 SEMIRINGKERNEL semiringkernel
58 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
59 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
60 auto loop2nd = GetRange( 0, m, MR );
61 auto pack2nd = GetRange( 0, m, PACK_MR );
63 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
65 j += loop3rd.inc(), jp += pack3rd.inc() )
67 struct aux_s<TA, TB, TC, TV> aux;
71 aux.jb = std::min( n - j, NR );
73 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
75 i += loop2nd.inc(), ip += pack2nd.inc() )
77 aux.ib = std::min( m - i, MR );
80 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
83 if ( aux.jb == NR && aux.ib == MR )
90 &C[ j * ldc + i ], 1, ldc,
97 TV ctmp[ MR * NR ] = { (TV)0.0 };
108 for (
auto jj = 0; jj < aux.jb; jj ++ )
110 for (
auto ii = 0; ii < aux.ib; ii ++ )
112 C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
118 for (
auto jj = 0; jj < aux.jb; jj ++ )
120 for (
auto ii = 0; ii < aux.ib; ii ++ )
122 C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
140 typename MICROKERNEL,
141 typename TA,
typename TB,
typename TC,
typename TV>
142 void fused_macro_kernel
145 int ic,
int jc,
int pc,
150 MICROKERNEL microkernel
155 auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
156 auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
157 auto loop2nd = GetRange( 0, m, MR );
158 auto pack2nd = GetRange( 0, m, PACK_MR );
160 for (
int j = loop3rd.beg(), jp = pack3rd.beg();
162 j += loop3rd.inc(), jp += pack3rd.inc() )
164 struct aux_s<TA, TB, TC, TV> aux;
168 aux.jb = std::min( n - j, NR );
170 for (
int i = loop2nd.beg(), ip = pack2nd.beg();
172 i += loop2nd.inc(), ip += pack2nd.inc() )
174 aux.ib = std::min( m - i, MR );
177 aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
180 if ( aux.jb == NR && aux.ib == MR )
187 &C[ j * ldc + i ], 1, ldc,
193 TV ctmp[ MR * NR ] = { (TV)0.0 };
205 for (
auto jj = 0; jj < aux.jb; jj ++ )
207 for (
auto ii = 0; ii < aux.ib; ii ++ )
209 C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
215 for (
auto jj = 0; jj < aux.jb; jj ++ )
217 for (
auto ii = 0; ii < aux.ib; ii ++ )
219 C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
234 int MC,
int NC,
int KC,
int MR,
int NR,
235 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
237 typename SEMIRINGKERNEL,
typename MICROKERNEL,
238 typename TA,
typename TB,
typename TC,
typename TV>
242 int w0,
int h0,
int d0,
int s,
int p,
244 int w1,
int h1,
int d1,
247 SEMIRINGKERNEL semiringkernel,
248 MICROKERNEL microkernel,
254 packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
255 + ( thread.ic_id ) * PACK_MC * KC;
256 packB += ( thread.jc_id ) * pack_nc * KC;
261 int nx = ( w0 - w1 + 2 * p ) / s + 1;
262 int ny = ( h0 - h1 + 2 * p ) / s + 1;
264 int k = w1 * h1 * d0;
267 auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
268 auto loop5th = GetRange( 0, k, KC );
269 auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
277 for (
int jc = loop6th.beg();
279 jc += loop6th.inc() )
281 auto &jc_comm = *thread.jc_comm;
282 auto jb = std::min( n - jc, nc );
287 for (
int pc = loop5th.beg();
289 pc += loop5th.inc() )
291 auto &pc_comm = *thread.pc_comm;
292 auto pb = std::min( k - pc, KC );
293 auto is_the_last_pc_iteration = ( pc + KC >= k );
298 auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
299 auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
301 for (
int j = looppkB.beg(), jp = packpkB.beg();
303 j += looppkB.inc(), jp += packpkB.inc() )
305 auto x0 = ( ( jc + j ) % nx ) * s - p;
306 auto y0 = ( ( jc + j ) / nx ) * s - p;
309 printf(
"x0 %4d y0 %4d\n", x0, y0 );
314 std::min( jb - j, NR ), pb,
326 for (
int i = 0; i < pb; i ++ )
328 for (
int jj = 0; jj < jb; jj += NR )
330 for (
int j = 0; j < NR; j ++ )
332 printf(
"%5.2lf ", packB[ jj * pb + i * NR + j ] );
342 for (
int ic = loop4th.beg();
344 ic += loop4th.inc() )
346 auto &ic_comm = *thread.ic_comm;
347 auto ib = std::min( m - ic, MC );
349 auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
350 auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
355 for (
int i = looppkA.beg(), ip = packpkA.beg();
357 i += looppkA.inc(), ip += packpkA.inc() )
359 pack2D<true, PACK_MR>
361 std::min( ib - i, MR ), pb,
362 &A[ ( ic + i ) * k + pc ], k, &packA[ ip * pb ]
366 if ( is_the_last_pc_iteration )
369 <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
383 <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
447 int MC,
int NC,
int KC,
int MR,
int NR,
448 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
450 typename SEMIRINGKERNEL,
typename MICROKERNEL,
451 typename TA,
typename TB,
typename TC,
typename TV>
454 int w0,
int h0,
int d0,
int s,
int p,
456 int w1,
int h1,
int d1,
459 SEMIRINGKERNEL semiringkernel,
460 MICROKERNEL microkernel
463 int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
464 int nc = NC, pack_nc = PACK_NC;
468 int nx = ( w0 - w1 + 2 * p ) / s + 1;
469 int ny = ( h0 - h1 + 2 * p ) / s + 1;
471 int k = w1 * h1 * d0;
476 TA *packA_buff = NULL;
477 TB *packB_buff = NULL;
482 if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
484 jc_nt = hmlp_read_nway_from_env(
"KS_JC_NT" );
485 ic_nt = hmlp_read_nway_from_env(
"KS_IC_NT" );
486 jr_nt = hmlp_read_nway_from_env(
"KS_JR_NT" );
492 nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
494 pack_nc = ( nc / NR ) * PACK_NR;
498 packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,
sizeof(TA) );
499 packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,
sizeof(TB) );
509 #pragma omp parallel num_threads( my_comm.GetNumThreads() ) 511 Worker thread( &my_comm );
515 printf(
"cnn: strassen algorithms haven't been implemented." );
521 PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
523 SEMIRINGKERNEL, MICROKERNEL,
532 semiringkernel, microkernel,
540 for (
int j = 0; j < ny; j ++ )
542 for (
int i = 0; i < nx; i ++ )
544 printf(
"%5.2lf ", C[ j * nx + i ] );
596 int MC,
int NC,
int KC,
int MR,
int NR,
597 int PACK_MC,
int PACK_NC,
int PACK_MR,
int PACK_NR,
int ALIGN_SIZE,
599 typename SEMIRINGKERNEL,
typename MICROKERNEL,
600 typename TA,
typename TB,
typename TC,
typename TV>
603 int w0,
int h0,
int d0,
int s,
int p,
int batchSize,
605 int w1,
int h1,
int d1,
608 SEMIRINGKERNEL semiringkernel,
609 MICROKERNEL microkernel
619 int nx = ( w0 - w1 + 2 * p ) / s + 1;
620 int ny = ( h0 - h1 + 2 * p ) / s + 1;
625 #pragma omp parallel for 626 for (
int b = 0; b < batchSize; b ++ )
629 <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
631 SEMIRINGKERNEL, MICROKERNEL,
635 B + b * w0 * h0 * d0,
638 C + b * nx * ny * d1,
653 int w0,
int h0,
int d0,
int s,
int p,
655 int w1,
int h1,
int d1,
661 int nx = ( w0 - w1 + 2 * p ) / s + 1;
662 int ny = ( h0 - h1 + 2 * p ) / s + 1;
664 int k = w1 * h1 * d0;
667 T *packB = hmlp_malloc<16, T>( k, n,
sizeof(T) );
669 double beg = omp_get_wtime();
677 double im2col_t = omp_get_wtime() - beg;
678 printf(
"im2col( B ) %3.1Es\n", im2col_t ); fflush( stdout );
682 for (
int p = 0; p < k; p ++ )
684 for (
int j = 0; j < n; j ++ )
686 printf(
"%5.2lf ", packB[ j * k + p ] );
703 #pragma omp parallel for 704 for (
int j = 0; j < n; j ++ )
706 for (
int i = 0; i < m; i ++ )
708 C[ j * m + i ] = 0.0;
709 for (
int p = 0; p < k; p ++ )
711 C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
721 int w0,
int h0,
int d0,
int s,
int p,
int batchSize,
723 int w1,
int h1,
int d1,
728 int nx = ( w0 - w1 + 2 * p ) / s + 1;
729 int ny = ( h0 - h1 + 2 * p ) / s + 1;
731 #pragma omp parallel for 732 for (
int b = 0; b < batchSize; b ++ )
737 B + b * w0 * h0 * d0,
748 #endif // define GKMX_HPP Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
Definition: thread.hpp:166