HMLP: High-performance Machine Learning Primitives
cnn.hpp
1 
22 #ifndef CNN_HPP
23 #define CNN_HPP
24 
25 #include <hmlp.h>
26 #include <hmlp_internal.hpp>
27 #include <hmlp_packing.hpp>
28 #include <base/util.hpp>
29 #include <hmlp_thread.hpp>
30 #include <hmlp_runtime.hpp>
31 
32 namespace hmlp
33 {
34 namespace cnn
35 {
36 
37 #define min( i, j ) ( (i)<(j) ? (i): (j) )
38 
/**
 *  @brief Placeholder for a custom packing routine for the A operand.
 *
 *  The routine is an empty stub: it takes no arguments and does nothing.
 *  None of the template parameters can be deduced, so callers must name
 *  them explicitly, e.g. my_packA<FOLD, ZEROPAD, T>().
 *
 *  @tparam FOLD     Not used by the stub.
 *  @tparam ZEROPAD  Not used by the stub (defaults to false).
 *  @tparam T        Not used by the stub.
 */
template<int FOLD, bool ZEROPAD = false, typename T>
void my_packA( /* Define what parameters you need. */ )
{
}
48 
/**
 *  @brief Placeholder for a custom packing routine for the B operand.
 *
 *  The routine is an empty stub: it takes no arguments and does nothing.
 *  None of the template parameters can be deduced, so callers must name
 *  them explicitly, e.g. my_packB<FOLD, ZEROPAD, T>().
 *
 *  @tparam FOLD     Not used by the stub.
 *  @tparam ZEROPAD  Not used by the stub (defaults to false).
 *  @tparam T        Not used by the stub.
 */
template<int FOLD, bool ZEROPAD = false, typename T>
void my_packB( /* Define what parameters you need. */ )
{
}
53 
54 
55 
59 template<
60  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
61  typename SEMIRINGKERNEL,
62  typename TA, typename TB, typename TC, typename TV>
64 (
65  Worker &thread,
66  int ic, int jc, int pc,
67  int m, int n, int k,
68  TA *packA,
69  TB *packB,
70  TV *C, int ldc,
71  SEMIRINGKERNEL semiringkernel
72 )
73 {
74  thread_communicator &ic_comm = *thread.ic_comm;
75 
76  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
77  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
78  auto loop2nd = GetRange( 0, m, MR );
79  auto pack2nd = GetRange( 0, m, PACK_MR );
80 
81  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
82  j < loop3rd.end();
83  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
84  {
85  struct aux_s<TA, TB, TC, TV> aux;
86  aux.pc = pc;
87  aux.b_next = packB;
88  aux.do_packC = 0;
89  aux.jb = min( n - j, NR );
90 
91  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
92  i < loop2nd.end();
93  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
94  {
95  aux.ib = min( m - i, MR );
96  if ( aux.ib != MR )
97  {
98  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
99  }
100 
101  if ( aux.jb == NR && aux.ib == MR )
102  {
103  semiringkernel
104  (
105  k,
106  &packA[ ip * k ],
107  &packB[ jp * k ],
108  &C[ j * ldc + i ], ldc,
109  &aux
110  );
111  }
112  else // corner case
113  {
114  // TODO: this should be initC.
115  TV ctmp[ MR * NR ] = { (TV)0.0 };
116  semiringkernel
117  (
118  k,
119  &packA[ ip * k ],
120  &packB[ jp * k ],
121  ctmp, MR,
122  &aux
123  );
124  if ( pc )
125  {
126  for ( auto jj = 0; jj < aux.jb; jj ++ )
127  {
128  for ( auto ii = 0; ii < aux.ib; ii ++ )
129  {
130  C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
131  }
132  }
133  }
134  else
135  {
136  for ( auto jj = 0; jj < aux.jb; jj ++ )
137  {
138  for ( auto ii = 0; ii < aux.ib; ii ++ )
139  {
140  C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
141  }
142  }
143  }
144  }
145  } // end 2nd loop
146  } // end 3rd loop
147 } // end rank_k_macro_kernel
148 
152 template<int KC, int MR, int NR, int PACK_MR, int PACK_NR,
153  typename MICROKERNEL,
154  typename TA, typename TB, typename TC, typename TV>
155 void fused_macro_kernel
156 (
157  Worker &thread,
158  int ic, int jc, int pc,
159  int m, int n, int k,
160  TA *packA,
161  TB *packB,
162  TV *C, int ldc,
163  MICROKERNEL microkernel
164 )
165 {
166  thread_communicator &ic_comm = *thread.ic_comm;
167 
168  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
169  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
170  auto loop2nd = GetRange( 0, m, MR );
171  auto pack2nd = GetRange( 0, m, PACK_MR );
172 
173  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
174  j < loop3rd.end();
175  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
176  {
177  struct aux_s<TA, TB, TC, TV> aux;
178  aux.pc = pc;
179  aux.b_next = packB;
180  aux.do_packC = 0;
181  aux.jb = min( n - j, NR );
182 
183  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
184  i < loop2nd.end();
185  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
186  {
187  aux.ib = min( m - i, MR );
188  if ( aux.ib != MR )
189  {
190  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
191  }
192 
193  if ( aux.jb == NR && aux.ib == MR )
194  {
195  microkernel
196  (
197  k,
198  &packA[ ip * k ],
199  &packB[ jp * k ],
200  &C[ j * ldc + i ], ldc,
201  &aux
202  );
203  }
204  else // corner case
205  {
206  // TODO: this should be initC.
207  TV ctmp[ MR * NR ] = { (TV)0.0 };
208  microkernel
209  (
210  k,
211  &packA[ ip * k ],
212  &packB[ jp * k ],
213  ctmp, MR,
214  &aux
215  );
216 
217  if ( pc )
218  {
219  for ( auto jj = 0; jj < aux.jb; jj ++ )
220  {
221  for ( auto ii = 0; ii < aux.ib; ii ++ )
222  {
223  C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
224  }
225  }
226  }
227  else
228  {
229  for ( auto jj = 0; jj < aux.jb; jj ++ )
230  {
231  for ( auto ii = 0; ii < aux.ib; ii ++ )
232  {
233  C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
234  }
235  }
236  }
237  }
238  } // end 2nd loop
239  } // end 3rd loop
240 } // end fused_macro_kernel
241 
242 
243 
244 /*
245  *
246  */
247 template<
248  int MC, int NC, int KC, int MR, int NR,
249  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
250  bool USE_STRASSEN,
251  typename SEMIRINGKERNEL, typename MICROKERNEL,
252  typename TA, typename TB, typename TC, typename TV>
253 void cnn_internal
254 (
255  Worker &thread,
256  hmlpOperation_t transA, hmlpOperation_t transB,
257  int m, int n, int k,
258  TA *A, int lda,
259  TB *B, int ldb,
260  TC *C, int ldc,
261  SEMIRINGKERNEL semiringkernel,
262  MICROKERNEL microkernel,
263  int nc, int pack_nc,
264  TA *packA,
265  TB *packB
266 )
267 {
268  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
269  + ( thread.ic_id ) * PACK_MC * KC;
270  packB += ( thread.jc_id ) * pack_nc * KC;
271 
272  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
273  auto loop5th = GetRange( 0, k, KC );
274  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
275 
276 
277  /*
278  * @CHENHAN: loop over your filters.
279  */
280  for ( int jc = loop6th.beg();
281  jc < loop6th.end();
282  jc += loop6th.inc() ) // beg 6th loop
283  {
284  auto &jc_comm = *thread.jc_comm;
285  auto jb = min( n - jc, nc );
286 
287  /*
288  * @CHENHAN: loop over your window size (width*length).
289  */
290  for ( int pc = loop5th.beg();
291  pc < loop5th.end();
292  pc += loop5th.inc() )
293  {
294  auto &pc_comm = *thread.pc_comm;
295  auto pb = min( k - pc, KC );
296  auto is_the_last_pc_iteration = ( pc + KC >= k );
297 
298  /*
299  * @CHENHAN: pack your filters into packB.
300  */
301  auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
302  auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
303 
304  for ( int j = looppkB.beg(), jp = packpkB.beg();
305  j < looppkB.end();
306  j += looppkB.inc(), jp += packpkB.inc() )
307  {
308  if ( transB == HMLP_OP_N )
309  {
310  pack2D<true, PACK_NR> // packB
311  (
312  min( jb - j, NR ), pb,
313  &B[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
314  );
315  }
316  else
317  {
318  pack2D<false, PACK_NR> // packB (transB)
319  (
320  min( jb - j, NR ), pb,
321  &B[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
322  );
323  }
324  }
325  pc_comm.Barrier();
326 
327  /*
328  * @CHENHAN: loop over windows of your image.
329  */
330  for ( int ic = loop4th.beg();
331  ic < loop4th.end();
332  ic += loop4th.inc() ) // beg 4th loop
333  {
334  auto &ic_comm = *thread.ic_comm;
335  auto ib = min( m - ic, MC );
336 
337  /*
338  * @CHENHAN: pack your windows into packA.
339  */
340  auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
341  auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
342 
343  for ( int i = looppkA.beg(), ip = packpkA.beg();
344  i < looppkA.end();
345  i += looppkA.inc(), ip += packpkA.inc() )
346  {
347  if ( transA == HMLP_OP_N )
348  {
349  pack2D<false, PACK_MR> // packA
350  (
351  min( ib - i, MR ), pb,
352  &A[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
353  );
354  }
355  else
356  {
357  pack2D<true, PACK_MR> // packA (transA)
358  (
359  min( ib - i, MR ), pb,
360  &A[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
361  );
362  }
363  }
364  ic_comm.Barrier();
365 
366  if ( is_the_last_pc_iteration ) // fused_macro_kernel
367  {
368  fused_macro_kernel
369  <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
370  (
371  thread,
372  ic, jc, pc,
373  ib, jb, pb,
374  packA,
375  packB,
376  C + jc * ldc + ic, ldc,
377  microkernel
378  );
379  }
380  else // semiring rank-k update
381  {
383  <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
384  (
385  thread,
386  ic, jc, pc,
387  ib, jb, pb,
388  packA,
389  packB,
390  C + jc * ldc + ic, ldc,
391  semiringkernel
392  );
393  }
394  ic_comm.Barrier(); // sync all jr_id!!
395  } // end 4th loop
396  pc_comm.Barrier();
397  } // end 5th loop
398  } // end 6th loop
399 } // end cnn_internal
400 
401 
402 
403 
404 
442 template<
443  int MC, int NC, int KC, int MR, int NR,
444  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
445  bool USE_STRASSEN,
446  typename SEMIRINGKERNEL, typename MICROKERNEL,
447  typename TA, typename TB, typename TC, typename TV>
448 void cnn
449 (
450  /*
451  * @CHENHAN: define what parameters you need.
452  */
453  hmlpOperation_t transA, hmlpOperation_t transB,
454  int m, int n, int k,
455  TA *A, int lda,
456  TB *B, int ldb,
457  TC *C, int ldc,
458  SEMIRINGKERNEL semiringkernel,
459  MICROKERNEL microkernel
460 )
461 {
462  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
463  int nc = NC, pack_nc = PACK_NC;
464  char *str;
465 
466  TA *packA_buff = NULL;
467  TB *packB_buff = NULL;
468 
469  // Early return if possible
470  if ( m == 0 || n == 0 || k == 0 ) return;
471 
472  // Check the environment variable.
473  jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
474  ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
475  jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
476 
477 
478  if ( jc_nt > 1 )
479  {
480  nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
481  pack_nc = ( nc / NR ) * PACK_NR;
482  }
483 
484  // allocate packing memory
485  packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) );
486  packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt, sizeof(TB) );
487 
488  // allocate tree communicator
489  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
490 
491 
492  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
493  {
494  Worker thread( &my_comm );
495 
496  if ( USE_STRASSEN )
497  {
498  printf( "cnn: strassen algorithms haven't been implemented." );
499  exit( 1 );
500  }
501 
502  cnn_internal
503  <MC, NC, KC, MR, NR,
504  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
505  USE_STRASSEN,
506  SEMIRINGKERNEL, MICROKERNEL,
507  TA, TB, TC, TB>
508  (
509  /*
510  * @CHENHAN: change these parameters according to your interface.
511  */
512  thread,
513  transA, transB,
514  m, n, k,
515  A, lda,
516  B, ldb,
517  C, ldc,
518  semiringkernel, microkernel,
519  nc, pack_nc,
520  packA_buff,
521  packB_buff
522  );
523  } // end omp
524 } // end cnn
525 
526 
527 }; // end namespace cnn
528 }; // end namespace hmlp
529 
530 
531 
/**
 *  @brief Placeholder for a reference implementation of cnn().
 *
 *  The routine is an empty stub: it takes no arguments and does nothing.
 *  T cannot be deduced, so callers must name it explicitly, e.g.
 *  cnn_ref<double>().
 *
 *  @tparam T  Not used by the stub.
 */
template<typename T>
void cnn_ref( /* Use the same interface as cnn(). */ )
{
}
540 
541 #endif // define CNN_HPP
void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
Definition: rank_k.hpp:51
Definition: hmlp_internal.hpp:38
Definition: gofmm.hpp:83