HMLP: High-performance Machine Learning Primitives
conv2d.hpp
1 
24 #ifndef CNN_HPP
25 #define CNN_HPP
26 
27 #include <hmlp.h>
28 #include <hmlp_internal.hpp>
29 #include <hmlp_base.hpp>
30 
31 // #define DEBUG_CONV2D 1
32 
33 namespace hmlp
34 {
35 namespace cnn
36 {
37 
41 template<
42  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
43  typename SEMIRINGKERNEL,
44  typename TA, typename TB, typename TC, typename TV>
46 (
47  Worker &thread,
48  int ic, int jc, int pc,
49  int m, int n, int k,
50  TA *packA,
51  TB *packB,
52  TV *C, int ldc,
53  SEMIRINGKERNEL semiringkernel
54 )
55 {
56  thread_communicator &ic_comm = *thread.ic_comm;
57 
58  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
59  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
60  auto loop2nd = GetRange( 0, m, MR );
61  auto pack2nd = GetRange( 0, m, PACK_MR );
62 
63  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
64  j < loop3rd.end();
65  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
66  {
67  struct aux_s<TA, TB, TC, TV> aux;
68  aux.pc = pc;
69  aux.b_next = packB;
70  aux.do_packC = 0;
71  aux.jb = std::min( n - j, NR );
72 
73  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
74  i < loop2nd.end();
75  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
76  {
77  aux.ib = std::min( m - i, MR );
78  if ( aux.ib != MR )
79  {
80  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
81  }
82 
83  if ( aux.jb == NR && aux.ib == MR )
84  {
85  semiringkernel
86  (
87  k,
88  &packA[ ip * k ],
89  &packB[ jp * k ],
90  &C[ j * ldc + i ], 1, ldc,
91  &aux
92  );
93  }
94  else // corner case
95  {
96  // TODO: this should be initC.
97  TV ctmp[ MR * NR ] = { (TV)0.0 };
98  semiringkernel
99  (
100  k,
101  &packA[ ip * k ],
102  &packB[ jp * k ],
103  ctmp, 1, MR,
104  &aux
105  );
106  if ( pc )
107  {
108  for ( auto jj = 0; jj < aux.jb; jj ++ )
109  {
110  for ( auto ii = 0; ii < aux.ib; ii ++ )
111  {
112  C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
113  }
114  }
115  }
116  else
117  {
118  for ( auto jj = 0; jj < aux.jb; jj ++ )
119  {
120  for ( auto ii = 0; ii < aux.ib; ii ++ )
121  {
122  C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
123  }
124  }
125  }
126  }
127  } // end 2nd loop
128  } // end 3rd loop
129 } // end rank_k_macro_kernel
130 
134 template<
135  int KC,
136  int MR,
137  int NR,
138  int PACK_MR,
139  int PACK_NR,
140  typename MICROKERNEL,
141  typename TA, typename TB, typename TC, typename TV>
142 void fused_macro_kernel
143 (
144  Worker &thread,
145  int ic, int jc, int pc,
146  int m, int n, int k,
147  TA *packA,
148  TB *packB,
149  TV *C, int ldc,
150  MICROKERNEL microkernel
151 )
152 {
153  thread_communicator &ic_comm = *thread.ic_comm;
154 
155  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
156  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
157  auto loop2nd = GetRange( 0, m, MR );
158  auto pack2nd = GetRange( 0, m, PACK_MR );
159 
160  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
161  j < loop3rd.end();
162  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
163  {
164  struct aux_s<TA, TB, TC, TV> aux;
165  aux.pc = pc;
166  aux.b_next = packB;
167  aux.do_packC = 0;
168  aux.jb = std::min( n - j, NR );
169 
170  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
171  i < loop2nd.end();
172  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
173  {
174  aux.ib = std::min( m - i, MR );
175  if ( aux.ib != MR )
176  {
177  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
178  }
179 
180  if ( aux.jb == NR && aux.ib == MR )
181  {
182  microkernel
183  (
184  k,
185  &packA[ ip * k ],
186  &packB[ jp * k ],
187  &C[ j * ldc + i ], 1, ldc,
188  &aux
189  );
190  }
191  else // corner case
192  {
193  TV ctmp[ MR * NR ] = { (TV)0.0 };
194  microkernel
195  (
196  k,
197  &packA[ ip * k ],
198  &packB[ jp * k ],
199  ctmp, 1, MR,
200  &aux
201  );
202 
203  if ( pc )
204  {
205  for ( auto jj = 0; jj < aux.jb; jj ++ )
206  {
207  for ( auto ii = 0; ii < aux.ib; ii ++ )
208  {
209  C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
210  }
211  }
212  }
213  else
214  {
215  for ( auto jj = 0; jj < aux.jb; jj ++ )
216  {
217  for ( auto ii = 0; ii < aux.ib; ii ++ )
218  {
219  C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
220  }
221  }
222  }
223  }
224  } // end 2nd loop
225  } // end 3rd loop
226 }; // end fused_macro_kernel
227 
228 
229 
230 /*
231  *
232  */
233 template<
234  int MC, int NC, int KC, int MR, int NR,
235  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
236  bool USE_STRASSEN,
237  typename SEMIRINGKERNEL, typename MICROKERNEL,
238  typename TA, typename TB, typename TC, typename TV>
239 void conv2d_internal
240 (
241  Worker &thread,
242  int w0, int h0, int d0, int s, int p,
243  TB *B,
244  int w1, int h1, int d1,
245  TA *A,
246  TC *C,
247  SEMIRINGKERNEL semiringkernel,
248  MICROKERNEL microkernel,
249  int nc, int pack_nc,
250  TA *packA,
251  TB *packB
252 )
253 {
254  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
255  + ( thread.ic_id ) * PACK_MC * KC;
256  packB += ( thread.jc_id ) * pack_nc * KC;
257 
258 
259  // Now compute parameters such that I can transform the problem into GEMM.
260  int m = d1;
261  int nx = ( w0 - w1 + 2 * p ) / s + 1;
262  int ny = ( h0 - h1 + 2 * p ) / s + 1;
263  int n = nx * ny;
264  int k = w1 * h1 * d0;
265 
266  //auto loop6th = GetRange( HMLP_SCHEDULE_HEFT, 0, n, nc, thread.jc_id, thread.jc_nt );
267  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
268  auto loop5th = GetRange( 0, k, KC );
269  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
270 
271  //printf( "tid %d beg %d end %d inc %d\n", thread.jc_id, loop6th.beg(), loop6th.end(), loop6th.inc() );
272 
273  //double my_beg = omp_get_wtime();
274  /*
275  * @CHENHAN: loop over your filters.
276  */
277  for ( int jc = loop6th.beg();
278  jc < loop6th.end();
279  jc += loop6th.inc() ) // beg 6th loop
280  {
281  auto &jc_comm = *thread.jc_comm;
282  auto jb = std::min( n - jc, nc );
283 
284  /*
285  * @CHENHAN: loop over your window size ( w1 * h1 * d0 ).
286  */
287  for ( int pc = loop5th.beg();
288  pc < loop5th.end();
289  pc += loop5th.inc() )
290  {
291  auto &pc_comm = *thread.pc_comm;
292  auto pb = std::min( k - pc, KC );
293  auto is_the_last_pc_iteration = ( pc + KC >= k );
294 
295  /*
296  * @CHENHAN: pack image into packB.
297  */
298  auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
299  auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
300 
301  for ( int j = looppkB.beg(), jp = packpkB.beg();
302  j < looppkB.end();
303  j += looppkB.inc(), jp += packpkB.inc() )
304  {
305  auto x0 = ( ( jc + j ) % nx ) * s - p; // top-left
306  auto y0 = ( ( jc + j ) / nx ) * s - p; // top-left
307 
308 #ifdef DEBUG_CONV2D
309  printf( "x0 %4d y0 %4d\n", x0, y0 );
310 #endif
311 
312  pack2Dimg<PACK_NR> // packB
313  (
314  std::min( jb - j, NR ), pb,
315  &packB[ jp * pb ],
316  x0, y0, pc,
317  B,
318  w0, h0, d0, s, p,
319  w1, h1
320  );
321  }
322  pc_comm.Barrier();
323 
324 
325 #ifdef DEBUG_CONV2D
326  for ( int i = 0; i < pb; i ++ )
327  {
328  for ( int jj = 0; jj < jb; jj += NR )
329  {
330  for ( int j = 0; j < NR; j ++ )
331  {
332  printf( "%5.2lf ", packB[ jj * pb + i * NR + j ] );
333  }
334  printf( " " );
335  }
336  printf( "\n" );
337  }
338  printf( "\n" );
339 #endif
340 
341 
342  for ( int ic = loop4th.beg();
343  ic < loop4th.end();
344  ic += loop4th.inc() ) // beg 4th loop
345  {
346  auto &ic_comm = *thread.ic_comm;
347  auto ib = std::min( m - ic, MC );
348 
349  auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
350  auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
351 
352  /*
353  * @CHENHAN: assume filters were already packed format.
354  */
355  for ( int i = looppkA.beg(), ip = packpkA.beg();
356  i < looppkA.end();
357  i += looppkA.inc(), ip += packpkA.inc() )
358  {
359  pack2D<true, PACK_MR> // packA (transA)
360  (
361  std::min( ib - i, MR ), pb,
362  &A[ ( ic + i ) * k + pc ], k, &packA[ ip * pb ]
363  );
364  }
365 
366  if ( is_the_last_pc_iteration ) // fused_macro_kernel
367  {
368  fused_macro_kernel
369  <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
370  (
371  thread,
372  ic, jc, pc,
373  ib, jb, pb,
374  packA,
375  packB,
376  C + jc * m + ic, m,
377  microkernel
378  );
379  }
380  else // semiring rank-k update
381  {
383  <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
384  (
385  thread,
386  ic, jc, pc,
387  ib, jb, pb,
388  packA,
389  packB,
390  C + jc * m + ic, m,
391  semiringkernel
392  );
393  }
394  ic_comm.Barrier(); // sync all jr_id!!
395  } // end 4th loop
396  pc_comm.Barrier();
397  } // end 5th loop
398  } // end 6th loop
399  //double my_time = omp_get_wtime() - my_beg;
400  //double my_flop = ( ( loop6th.end() - loop6th.beg() ) / 1e+9 ) * 2 * m * k;
402  //printf( "tid %d GFLOPS %5.2lf\n", thread.jc_id, my_time );
403 }; // end cnn_internal
404 
405 
406 
407 
408 
446 template<
447  int MC, int NC, int KC, int MR, int NR,
448  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
449  bool USE_STRASSEN,
450  typename SEMIRINGKERNEL, typename MICROKERNEL,
451  typename TA, typename TB, typename TC, typename TV>
452 void conv2d
453 (
454  int w0, int h0, int d0, int s, int p,
455  TA *B,
456  int w1, int h1, int d1,
457  TB *A,
458  TC *C,
459  SEMIRINGKERNEL semiringkernel,
460  MICROKERNEL microkernel
461 )
462 {
463  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
464  int nc = NC, pack_nc = PACK_NC;
465  char *str;
466 
467  int m = d1;
468  int nx = ( w0 - w1 + 2 * p ) / s + 1;
469  int ny = ( h0 - h1 + 2 * p ) / s + 1;
470  int n = nx * ny;
471  int k = w1 * h1 * d0;
472 
473 
474  //printf( "m %4d n %4d k %4d\n", m, n, k );
475 
476  TA *packA_buff = NULL;
477  TB *packB_buff = NULL;
478 
479  // Early return if possible
480 
481  // Check the environment variable.
482  if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
483  {
484  jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
485  ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
486  jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
487  }
488 
489 
490  if ( jc_nt > 1 )
491  {
492  nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
493  //if ( nc > NC ) nc = NC;
494  pack_nc = ( nc / NR ) * PACK_NR;
495  }
496 
497  // allocate packing memory
498  packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) );
499  packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt, sizeof(TB) );
500 
501  //#pragma omp parallel for
502  //for ( int i = 0; i < KC * ( PACK_MC + 1 ) * jc_nt * ic_nt; i ++ ) packA_buff[ i ] = 1.0;
503 
504 
505  // allocate tree communicator
506  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
507 
508 
509  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
510  {
511  Worker thread( &my_comm );
512 
513  if ( USE_STRASSEN )
514  {
515  printf( "cnn: strassen algorithms haven't been implemented." );
516  exit( 1 );
517  }
518 
519  conv2d_internal
520  <MC, NC, KC, MR, NR,
521  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
522  USE_STRASSEN,
523  SEMIRINGKERNEL, MICROKERNEL,
524  TA, TB, TC, TB>
525  (
526  thread,
527  w0, h0, d0, s, p,
528  B,
529  w1, h1, d1,
530  A,
531  C,
532  semiringkernel, microkernel,
533  nc, pack_nc,
534  packA_buff,
535  packB_buff
536  );
537  } // end omp
538 
539 #ifdef DEBUG_CONV2D
540  for ( int j = 0; j < ny; j ++ )
541  {
542  for ( int i = 0; i < nx; i ++ )
543  {
544  printf( "%5.2lf ", C[ j * nx + i ] );
545  }
546  printf( "\n" );
547  }
548 #endif
549 
550 }; // end cnn
551 
552 
553 //template<
554 // int MC, int NC, int KC, int MR, int NR,
555 // int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
556 // bool USE_STRASSEN,
557 // typename SEMIRINGKERNEL, typename MICROKERNEL,
558 // typename TA, typename TB, typename TC, typename TV>
559 //void conv2d
560 //(
561 // int w0, int h0, int d0,
562 // TA *B,
563 // int w1, int h1, int d1,
564 // TB *A,
565 // TC *C,
566 // SEMIRINGKERNEL semiringkernel,
567 // MICROKERNEL microkernel
568 //)
569 //{
570 // // Deciding s and p given the output size is also (w0, h0).
571 // // w0 = ( w0 - w1 + 2 * p ) / s + 1
572 // // h0 = ( h0 - h1 + 2 * p ) / s + 1
573 // // if s = 1, then p = ( w1 - 1 ) / 2
574 // // p = ( h1 - 1 ) / 2
575 // // that is w1 and h1 must be odd.
576 //
577 // assert( w1 == h1 );
578 //
579 // conv2d
580 // <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
581 // USE_STRASSEN,
582 // SEMIRINGKERNEL, MICROKERNEL,
583 // TA, TB, TC, TV>
584 // (
585 // w0, h0, d0, 1, ( w1 - 1 ) / 2,
586 // B,
587 // w1, h1, d1,
588 // A,
589 // C,
590 // semiringkernel,
591 // microkernel
592 // );
593 //};
594 
/**
 *  @brief  Batched 2D convolution: applies the single-image conv2d to each of
 *          batchSize images, in parallel over the batch.
 *
 *  Image b starts at B + b * w0 * h0 * d0 and its output at
 *  C + b * nx * ny * d1, where nx = ( w0 - w1 + 2 * p ) / s + 1 and
 *  ny = ( h0 - h1 + 2 * p ) / s + 1. All images share the filters A.
 *  Only square filters are accepted (asserted below).
 *  For a same-size output with s == 1 and odd w1, use p = ( w1 - 1 ) / 2.
 */
template<
  int MC, int NC, int KC, int MR, int NR,
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_STRASSEN,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void conv2d
(
  int w0, int h0, int d0, int s, int p, int batchSize,
  TB *B,
  int w1, int h1, int d1,
  TA *A,
  TC *C,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
  int ny = ( h0 - h1 + 2 * p ) / s + 1;

  // NOTE(review): the single-image overload does not require square filters.
  assert( w1 == h1 );

  // Nothing to compute; this also keeps the per-image strides non-negative.
  if ( batchSize <= 0 || nx <= 0 || ny <= 0 ) return;

  // Per-image strides in size_t so that b * stride cannot overflow int.
  const size_t image_size  = static_cast<size_t>( w0 ) * h0 * d0;
  const size_t output_size = static_cast<size_t>( nx ) * ny * d1;

  #pragma omp parallel for
  for ( int b = 0; b < batchSize; b ++ )
  {
    conv2d
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    USE_STRASSEN,
    SEMIRINGKERNEL, MICROKERNEL,
    TA, TB, TC, TV>
    (
      w0, h0, d0, s, p,
      B + b * image_size,
      w1, h1, d1,
      A,
      C + b * output_size,
      semiringkernel,
      microkernel
    );
  }
} // end conv2d (batched)
644 
645 
650 template<typename T>
651 void conv2d_ref
652 (
653  int w0, int h0, int d0, int s, int p,
654  T *B,
655  int w1, int h1, int d1,
656  T *A,
657  T *C
658 )
659 {
660  int m = d1;
661  int nx = ( w0 - w1 + 2 * p ) / s + 1;
662  int ny = ( h0 - h1 + 2 * p ) / s + 1;
663  int n = nx * ny;
664  int k = w1 * h1 * d0;
665 
666  T *packA = A;
667  T *packB = hmlp_malloc<16, T>( k, n, sizeof(T) );
668 
669  double beg = omp_get_wtime();
670  im2col<T>
671  (
672  n, k,
673  packB, B,
674  w0, h0, d0, s, p,
675  w1, h1
676  );
677  double im2col_t = omp_get_wtime() - beg;
678  printf( "im2col( B ) %3.1Es\n", im2col_t ); fflush( stdout );
679 
680 #ifdef DEBUG_CONV2D
681  printf( "packB\n" );
682  for ( int p = 0; p < k; p ++ )
683  {
684  for ( int j = 0; j < n; j ++ )
685  {
686  printf( "%5.2lf ", packB[ j * k + p ] );
687  }
688  printf( "\n" );
689  }
690 #endif
691 
692 
693 #ifdef USE_BLAS
694  xgemm
695  (
696  "T", "N",
697  m, n, k,
698  1.0, packA, k,
699  packB, k,
700  0.0, C, m
701  );
702 #else
703  #pragma omp parallel for
704  for ( int j = 0; j < n; j ++ )
705  {
706  for ( int i = 0; i < m; i ++ )
707  {
708  C[ j * m + i ] = 0.0;
709  for ( int p = 0; p < k; p ++ )
710  {
711  C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
712  }
713  }
714  }
715 #endif
716 }; // end void conv2d_ref
717 
/**
 *  @brief  Batched reference convolution: runs the single-image conv2d_ref on
 *          each of batchSize images, in parallel over the batch.
 *
 *  Image b starts at B + b * w0 * h0 * d0 and its output at
 *  C + b * nx * ny * d1; all images share the filters A.
 */
template<typename T>
void conv2d_ref
(
  int w0, int h0, int d0, int s, int p, int batchSize,
  T *B,
  int w1, int h1, int d1,
  T *A,
  T *C
)
{
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
  int ny = ( h0 - h1 + 2 * p ) / s + 1;

  // Nothing to compute; this also keeps the per-image strides non-negative.
  if ( batchSize <= 0 || nx <= 0 || ny <= 0 ) return;

  // Per-image strides in size_t so that b * stride cannot overflow int.
  const size_t image_size  = static_cast<size_t>( w0 ) * h0 * d0;
  const size_t output_size = static_cast<size_t>( nx ) * ny * d1;

  #pragma omp parallel for
  for ( int b = 0; b < batchSize; b ++ )
  {
    conv2d_ref<T>
    (
      w0, h0, d0, s, p,
      B + b * image_size,
      w1, h1, d1,
      A,
      C + b * output_size
    );
  }
} // end conv2d_ref (batched)
744 
745 }; // end namespace conv2d
746 }; // end namespace hmlp
747 
748 #endif // define GKMX_HPP
Definition: thread.hpp:107
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
Definition: hmlp_internal.hpp:38
void Barrier()
OpenMP thread barrier from BLIS.
Definition: thread.cpp:227
Definition: gofmm.hpp:83
Definition: thread.hpp:166