/**
 * HMLP: High-performance Machine Learning Primitives
 *
 * gsknn.hpp
 */
#ifndef GSKNN_HXX
#define GSKNN_HXX

#include <math.h>
#include <stdlib.h>

#include <vector>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>

#include <primitives/strassen.hpp>
namespace hmlp
{
namespace gsknn
{

#define min( i, j ) ( (i)<(j) ? (i): (j) )
42 
46 template<
47  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
48  typename SEMIRINGKERNEL,
49  typename TA, typename TB, typename TC, typename TV>
51 (
52  Worker &thread,
53  int ic, int jc, int pc,
54  int m, int n, int k,
55  TA *packA,
56  TB *packB,
57  TC *packC, int ldc,
58  SEMIRINGKERNEL semiringkernel
59 )
60 {
61  thread_communicator &ic_comm = *thread.ic_comm;
62 
63  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
64  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
65  auto loop2nd = GetRange( 0, m, MR );
66  auto pack2nd = GetRange( 0, m, PACK_MR );
67 
68  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
69  j < loop3rd.end();
70  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
71  {
72  struct aux_s<TA, TB, TC, TV> aux;
73  aux.pc = pc;
74  aux.b_next = packB;
75  aux.do_packC = 0;
76  aux.jb = min( n - j, NR );
77 
78  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
79  i < loop2nd.end();
80  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
81  {
82  aux.ib = min( m - i, MR );
83  if ( i + MR >= m )
84  {
85  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
86  }
87 
88  if ( aux.jb == NR && aux.ib == MR )
89  {
90  semiringkernel
91  (
92  k,
93  &packA[ ip * k ],
94  &packB[ jp * k ],
95  &packC[ j * ldc + i ], 1, ldc,
96  &aux
97  );
98  }
99  else
100  {
101  double c[ MR * NR ] __attribute__((aligned(32)));
102  double *cbuff = c;
103  if ( pc ) {
104  for ( auto jj = 0; jj < aux.jb; jj ++ )
105  for ( auto ii = 0; ii < aux.ib; ii ++ )
106  cbuff[ jj * MR + ii ] = packC[ ( j + jj ) * ldc + i + ii ];
107  }
108  semiringkernel
109  (
110  k,
111  &packA[ ip * k ],
112  &packB[ jp * k ],
113  cbuff, 1, MR,
114  &aux
115  );
116  for ( auto jj = 0; jj < aux.jb; jj ++ )
117  for ( auto ii = 0; ii < aux.ib; ii ++ )
118  packC[ ( j + jj ) * ldc + i + ii ] = cbuff[ jj * MR + ii ];
119  }
120  } // end 2nd loop
121  } // end 3rd loop
122 } // end rank_k_macro_kernel
123 
127 template<
128  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
129  typename MICROKERNEL,
130  typename TA, typename TB, typename TC, typename TV>
131 void fused_macro_kernel
132 (
133  Worker &thread,
134  int pc,
135  int m, int n, int k, int r,
136  TA *packA, TA *packA2,
137  TB *packB, TB *packB2,
138  int *bmap,
139  TV *D, int *I, int ldr,
140  TC *packC, int ldc,
141  MICROKERNEL microkernel
142 )
143 {
144  double c[ MR * NR ] __attribute__((aligned(32)));
145  double *cbuff = c;
146  thread_communicator &ic_comm = *thread.ic_comm;
147 
148  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
149  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
150  auto loop2nd = GetRange( 0, m, MR );
151  auto pack2nd = GetRange( 0, m, PACK_MR );
152 
153  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
154  j < loop3rd.end();
155  j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
156  {
157  struct aux_s<TA, TB, TC, TV> aux;
158  aux.pc = pc;
159  aux.b_next = packB;
160  //aux.ldr = ldr;
161  aux.jb = min( n - j, NR );
162 
163  for ( int i = loop2nd.beg(), ip = pack2nd.beg();
164  i < loop2nd.end();
165  i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
166  {
167  aux.ib = min( m - i, MR );
168  //aux.I = I + i * ldr;
169  //aux.D = D + i * ldr;
170  if ( i + MR >= m )
171  {
172  aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
173  }
174  if ( pc ) {
175  for ( auto jj = 0; jj < aux.jb; jj ++ )
176  for ( auto ii = 0; ii < aux.ib; ii ++ )
177  cbuff[ jj * MR + ii ] = packC[ ( j + jj ) * ldc + i + ii ];
178  }
179  microkernel
180  (
181  k, r,
182  packA + ip * k, packA2 + ip,
183  packB + jp * k, packB2 + jp, bmap + j,
184  cbuff,
185  D + i * ldr, I + i * ldr, ldr,
186  &aux
187  );
188  if ( pc ) {
189  for ( auto jj = 0; jj < aux.jb; jj ++ )
190  for ( auto ii = 0; ii < aux.ib; ii ++ )
191  packC[ ( j + jj ) * ldc + i + ii ] = cbuff[ jj * MR + ii ];
192  }
193  } // end 2nd loop
194  } // end 3rd loop
195 } // end fused_macro_kernel
196 
197 
201 template<
202  int MC, int NC, int KC, int MR, int NR,
203  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
204  bool USE_STRASSEN,
205  typename SEMIRINGKERNEL, typename MICROKERNEL,
206  typename TA, typename TB, typename TC, typename TV>
207 void gsknn_internal
208 (
209  Worker &thread,
210  int m, int n, int k, int k_stra, int r,
211  TA *A, TA *A2, int *amap,
212  TB *B, TB *B2, int *bmap,
213  TV *D, int *I,
214  SEMIRINGKERNEL semiringkernel,
215  MICROKERNEL microkernel,
216  TA *packA, TA *packA2,
217  TB *packB, TB *packB2,
218  TC *packC, int ldpackc, int padn,
219  int ldr
220 )
221 {
222 
223  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
224  + ( thread.ic_id ) * PACK_MC * KC;
225  packA2 += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
226  packB += ( thread.jc_id ) * PACK_NC * KC;
227  packB2 += ( thread.jc_id ) * PACK_NC;
228 
229  auto loop6th = GetRange( 0, n, NC );
230  auto loop5th = GetRange( k_stra, k, KC );
231  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
232 
233  for ( int jc = loop6th.beg();
234  jc < loop6th.end();
235  jc += loop6th.inc() ) // beg 6th loop
236  {
237  auto jb = min( n - jc, NC );
238 
239  for ( int pc = loop5th.beg();
240  pc < loop5th.end();
241  pc += loop5th.inc() )
242  {
243  auto &pc_comm = *thread.pc_comm;
244  auto pb = min( k - pc, KC );
245  auto is_the_last_pc_iteration = ( pc + KC >= k );
246 
247  auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
248  auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
249 
250  for ( int j = looppkB.beg(), jp = packpkB.beg();
251  j < looppkB.end();
252  j += looppkB.inc(), jp += packpkB.inc() )
253  {
254  pack2D<true, PACK_NR> // packB
255  (
256  min( jb - j, NR ), pb,
257  &B[ pc ], k, &bmap[ jc + j ], &packB[ jp * pb ]
258  );
259 
260 
261  if ( is_the_last_pc_iteration )
262  {
263 
264  pack2D<true, PACK_NR> // packB2
265  (
266  min( jb - j, NR ), 1,
267  &B2[ 0 ], 1, &bmap[ jc + j ], &packB2[ jp * 1 ]
268  );
269 
270 
271  }
272  }
273  pc_comm.Barrier();
274 
275  for ( int ic = loop4th.beg();
276  ic < loop4th.end();
277  ic += loop4th.inc() ) // beg 4th loop
278  {
279  auto &ic_comm = *thread.ic_comm;
280  auto ib = min( m - ic, MC );
281 
282  auto looppkA = GetRange( 0, ib, MR, thread.jr_id, 1 );
283  auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, 1 );
284 
285  for ( int i = looppkA.beg(), ip = packpkA.beg();
286  i < looppkA.end();
287  i += looppkA.inc(), ip += packpkA.inc() )
288  {
289  pack2D<true, PACK_MR> // packA
290  (
291  min( ib - i, MR ), pb,
292  &A[ pc ], k, &amap[ ic + i ], &packA[ ip * pb ]
293  );
294 
295  if ( is_the_last_pc_iteration )
296  {
297  pack2D<true, PACK_MR> // packA2
298  (
299  min( ib - i, MR ), 1,
300  &A2[ 0 ], 1, &amap[ ic + i ], &packA2[ ip * 1 ]
301  );
302 
303  }
304  }
305 
306 
307  ic_comm.Barrier();
308  if ( pc + KC < k )
309  {
311  <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
312  (
313  thread,
314  ic, jc, pc,
315  ib, jb, pb,
316  packA,
317  packB,
318  packC + jc * ldpackc + ic,
319  ldpackc,
320  semiringkernel
321  );
322  }
323  else
324  {
325  fused_macro_kernel
326  <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
327  (
328  thread,
329  pc,
330  ib, jb, pb, r,
331  packA, packA2,
332  packB, packB2, bmap + jc,
333  D + ic * ldr, I + ic * ldr, ldr,
334  packC + jc * ldpackc + ic,
335  ldpackc,
336  microkernel
337  );
338  }
339 
340  ic_comm.Barrier(); // sync all jr_id!!
341 
342  } // end 4th loop
343  pc_comm.Barrier();
344  } // end 5th loop
345  } // end 6th loop
346 } // end gsknn_internal
347 
348 
349 
350 
351 
355 template<
356  int MC, int NC, int KC, int MR, int NR,
357  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
358  bool USE_STRASSEN,
359  typename SEMIRINGKERNEL, typename MICROKERNEL,
360  typename TA, typename TB, typename TC, typename TV>
361 void gsknn(
362  int m, int n, int k, int r,
363  TA *A, TA *A2, int *amap,
364  TB *B, TB *B2, int *bmap,
365  TV *D, int *I,
366  SEMIRINGKERNEL semiringkernel,
367  MICROKERNEL microkernel
368  )
369 {
370  int ic_nt = 1;
371  int k_stra = 0;
372  int ldpackc = 0, padn = 0;
373  int ldr = 0;
374  char *str;
375 
376  TA *packA_buff = NULL, *packA2_buff = NULL;
377  TB *packB_buff = NULL, *packB2_buff = NULL;
378  TC *packC_buff = NULL;
379 
380  // Early return if possible
381  if ( m == 0 || n == 0 || k == 0 ) return;
382 
383  // Check the environment variable.
384  str = getenv( "KS_IC_NT" );
385  if ( str ) ic_nt = (int)strtol( str, NULL, 10 );
386 
387  ldpackc = m;
388  ldr = r;
389 
390  // allocate packing memory
391  packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * ic_nt, sizeof(TA) );
392  packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( PACK_NC + 1 ), sizeof(TB) );
393  packA2_buff = hmlp_malloc<ALIGN_SIZE, TA>( 1, ( PACK_MC + 1 ) * ic_nt, sizeof(TA) );
394  packB2_buff = hmlp_malloc<ALIGN_SIZE, TB>( 1, ( PACK_NC + 1 ), sizeof(TB) );
395  if ( k > KC ) {
396  packC_buff = hmlp_malloc<ALIGN_SIZE, TC>( m, n, sizeof(TC) );
397  }
398 
399  // allocate tree communicator
400  thread_communicator my_comm( 1, 1, ic_nt, 1 );
401 
402  if ( USE_STRASSEN )
403  {
404  k_stra = k - k % KC;
405 
406  if ( k_stra == k ) k_stra -= KC;
407 
408  if ( k_stra )
409  {
410  #pragma omp parallel for
411  for ( int i = 0; i < m * n; i ++ ) packC_buff[ i ] = 0.0;
412  }
413 
414  }
415 
416  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
417  {
418  Worker thread( &my_comm );
419 
420  if ( USE_STRASSEN && k > KC )
421  {
422  strassen::strassen_internal
423  <MC, NC, KC, MR, NR,
424  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
425  USE_STRASSEN,
426  SEMIRINGKERNEL, SEMIRINGKERNEL,
427  TA, TB, TC, TV>
428  (
429  thread,
430  HMLP_OP_T, HMLP_OP_N,
431  m, n, k_stra,
432  A, k, amap,
433  B, k, bmap,
434  packC_buff, ldpackc,
435  semiringkernel, semiringkernel,
436  NC, PACK_NC,
437  packA_buff,
438  packB_buff
439  );
440  }
441 
442  gsknn_internal
443  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
444  USE_STRASSEN,
445  SEMIRINGKERNEL, MICROKERNEL,
446  TA, TB, TC, TB>
447  (
448  thread,
449  m, n, k, k_stra, r,
450  A, A2, amap,
451  B, B2, bmap,
452  D, I,
453  semiringkernel, microkernel,
454  packA_buff, packA2_buff,
455  packB_buff, packB2_buff,
456  packC_buff, ldpackc, padn,
457  ldr
458  );
459 
460  } // end omp region
461 
462  hmlp_free( packA_buff );
463  hmlp_free( packB_buff );
464  hmlp_free( packA2_buff );
465  hmlp_free( packB2_buff );
466  hmlp_free( packC_buff );
467 } // end gsknn
468 
469 
473 template<typename T>
474 void gsknn_ref
475 (
476  int m, int n, int k, int r,
477  T *A, T *A2, int *amap,
478  T *B, T *B2, int *bmap,
479  T *D, int *I
480 )
481 {
482  int i, j, p;
483  double beg, time_collect, time_dgemm, time_square, time_heap;
484  std::vector<T> packA, packB, C;
485  double fneg2 = -2.0, fzero = 0.0, fone = 1.0;
486 
487  // Early return if possible
488  if ( m == 0 || n == 0 || k == 0 ) return;
489 
490  packA.resize( k * m );
491  packB.resize( k * n );
492  C.resize( m * n );
493 
494  // Collect As from A and B.
495  beg = omp_get_wtime();
496  #pragma omp parallel for private( p )
497  for ( i = 0; i < m; i ++ ) {
498  for ( p = 0; p < k; p ++ ) {
499  packA[ i * k + p ] = A[ amap[ i ] * k + p ];
500  }
501  }
502  #pragma omp parallel for private( p )
503  for ( j = 0; j < n; j ++ ) {
504  for ( p = 0; p < k; p ++ ) {
505  packB[ j * k + p ] = B[ bmap[ j ] * k + p ];
506  }
507  }
508  time_collect = omp_get_wtime() - beg;
509 
510  // Compute the inner-product term.
511  beg = omp_get_wtime();
512  #ifdef USE_BLAS
513  xgemm
514  (
515  "T", "N",
516  m, n, k,
517  fone, packA.data(), k,
518  packB.data(), k,
519  fzero, C.data(), m
520  );
521  #else
522  #pragma omp parallel for private( i, p )
523  for ( j = 0; j < n; j ++ ) {
524  for ( i = 0; i < m; i ++ ) {
525  C[ j * m + i ] = 0.0;
526  for ( p = 0; p < k; p ++ ) {
527  C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
528  }
529  }
530  }
531  #endif
532  time_dgemm = omp_get_wtime() - beg;
533 
534  beg = omp_get_wtime();
535  #pragma omp parallel for private( i )
536  for ( j = 0; j < n; j ++ )
537  {
538  for ( i = 0; i < m; i ++ )
539  {
540  C[ j * m + i ] *= -2.0;
541  C[ j * m + i ] += A2[ amap[ i ] ];
542  C[ j * m + i ] += B2[ bmap[ j ] ];
543  }
544  }
545  time_square = omp_get_wtime() - beg;
546 
547  // Pure C Max Heap implementation.
548  beg = omp_get_wtime();
549  #pragma omp parallel for schedule( dynamic )
550  for ( j = 0; j < n; j ++ )
551  {
552  heap_select<T>( m, r, &C[ j * m ], amap, &D[ j * r ], &I[ j * r ] );
553  }
554  time_heap = omp_get_wtime() - beg;
555 
556 } // end void gsknn_ref

}; // end namespace gsknn
}; // end namespace hmlp

#endif // define GSKNN_HXX
/* NOTE(review): the lines below are cross-reference residue from a Doxygen
 * source listing, not part of the original header. Preserved as a comment
 * so the file remains compilable.
 *
 * Definition: thread.hpp:107
 * void rank_k_macro_kernel(tci::Comm &Comm3rd, int ic, int jc, int pc, int m, int n, int k, TA *packA, TB *packB, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
 * Macro kernel contains the 3rd and 2nd loops. Depending on the configuration of the communicator...
 * Definition: rank_k.hpp:51
 * void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
 * DGEMM wrapper.
 * Definition: blas_lapack.cpp:130
 * Definition: hmlp_internal.hpp:38
 * void Barrier()
 * OpenMP thread barrier from BLIS.
 * Definition: thread.cpp:227
 * void hmlp_free(T *ptr)
 * Free the aligned memory.
 * Definition: util.hpp:88
 * Definition: gofmm.hpp:83
 * Definition: thread.hpp:166
 */