/** HMLP: High-performance Machine Learning Primitives — rank_k.hpp */
23 #ifndef RANK_K_HPP
24 #define RANK_K_HPP
25 
#include <assert.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <type_traits>
#include <typeinfo>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>

#include <packing.hpp>
#include <semiring_mrxnr.hpp>
#include <fused_mrxnr.hpp>
38 
39 using namespace std;
40 using namespace hmlp;
41 
42 namespace hmlp
43 {
49 template<int KC, typename SEMIRINGKERNEL, typename TA, typename TB, typename TV>
50 void rank_k_macro_kernel
51 (
52  tci::Comm &Comm3rd,
53  int ic, int jc, int pc,
54  int m, int n, int k,
55  TA *packA,
56  TB *packB,
57  TV *V, int rs_v, int cs_v,
58  SEMIRINGKERNEL semiringkernel
59 )
60 {
62  const static int MR = SEMIRINGKERNEL::mr;
63  const static int NR = SEMIRINGKERNEL::nr;
64  const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
65  const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
67  auto Comm2nd = Comm3rd.Split( hmlp_read_nway_from_env( "KS_JR_NT" ) );
69  auto Loop3rd = Comm3rd.DistributeOver1DGangs( 0, n, NR );
70  auto Pack3rd = Comm3rd.DistributeOver1DGangs( 0, n, PACK_NR );
71  auto Loop2nd = Comm2nd.DistributeOver1DThreads( 0, m, MR );
72  auto Pack2nd = Comm2nd.DistributeOver1DThreads( 0, m, PACK_MR );
74  for ( int j = Loop3rd.beg(), jp = Pack3rd.beg();
75  j < Loop3rd.end();
76  j += Loop3rd.inc(), jp += Pack3rd.inc() )
77  {
78  struct aux_s<TA, TB, TV, TV> aux;
79  aux.pc = pc;
80  aux.b_next = packB;
81  aux.do_packC = 0;
82  aux.jb = std::min( n - j, NR );
84  for ( int i = Loop2nd.beg(), ip = Pack2nd.beg();
85  i < Loop2nd.end();
86  i += Loop2nd.inc(), ip += Pack2nd.inc() )
87  {
88  aux.ib = std::min( m - i, MR );
90  if ( i + MR >= m ) aux.b_next += Pack3rd.inc() * k;
91 
92  if ( aux.jb == NR && aux.ib == MR )
93  {
94  semiringkernel( k, &packA[ ip * k ], &packB[ jp * k ],
95  &V[ i * rs_v + j * cs_v ], rs_v, cs_v, &aux );
96  }
97  else
98  {
99  TV vtmp[ MR * NR ];
100 
101  if ( pc ) // initilize ctmp
102  {
103  for ( auto jj = 0; jj < aux.jb; jj ++ )
104  for ( auto ii = 0; ii < aux.ib; ii ++ )
105  vtmp[ jj * MR + ii ] =
106  V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
107  }
108 
109  semiringkernel( k, &packA[ ip * k ], &packB[ jp * k ],
110  vtmp, 1, MR, &aux );
111 
112  for ( auto jj = 0; jj < aux.jb; jj ++ )
113  for ( auto ii = 0; ii < aux.ib; ii ++ )
114  V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ]
115  = vtmp[ jj * MR + ii ];
116  }
117  }
118  }
119 };
130 template<
131  int MC, int NC, int KC,
132  typename TPACKA, typename TPACKB, typename TV,
133  typename TA, typename TB, typename TC,
134  typename SEMIRINGKERNEL>
135 void rank_k_internal
136 (
137  tci::Comm &Comm6th,
138  int batchId, int m, int n, int k, int k_stra,
139  TA& A,
140  TB& B,
141  TV* V, int rs_v, int cs_v,
142  SEMIRINGKERNEL semiringkernel
143 )
144 {
146  const static int MR = SEMIRINGKERNEL::mr;
147  const static int NR = SEMIRINGKERNEL::nr;
148  const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
149  const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
150  const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
151  const static int PACK_MC = ( MC / MR ) * PACK_MR;
152  const static int PACK_NC = ( NC / NR ) * PACK_NR;
154  auto Comm5th = Comm6th.Split( hmlp_read_nway_from_env( "KS_JC_NT" ) );
155  auto Comm4th = Comm5th.Split( 1 );
156  auto Comm3th = Comm4th.Split( hmlp_read_nway_from_env( "KS_IC_NT" ) );
158  int nc = Comm6th.BalanceOver1DGangs( n, NC, NR );
159  int pack_nc = ( nc / NR ) * PACK_NR;
161  auto *packB = Comm4th.AllocateSharedMemory<ALIGN_SIZE, TPACKB>( KC * ( pack_nc + 1 ) );
163  auto *packA = Comm3th.AllocateSharedMemory<ALIGN_SIZE, TPACKA>( KC * ( PACK_MC + 1 ) );
165  auto Loop6th = Comm6th.DistributeOver1DGangs( 0, n, nc );
167  auto Loop5th = Comm5th.DistributeOver1DGangs( k_stra, k, KC );
169  auto Loop4th = Comm4th.DistributeOver1DGangs( 0, m, MC );
171  for ( int jc = Loop6th.beg();
172  jc < Loop6th.end();
173  jc += Loop6th.inc() )
174  {
175  auto jb = std::min( n - jc, nc );
177  for ( int pc = Loop5th.beg();
178  pc < Loop5th.end();
179  pc += Loop5th.inc() )
180  {
181  auto pb = std::min( k - pc, KC );
183  auto LooppkB = Comm4th.DistributeOver1DThreads( 0, jb, NR );
184  auto PackpkB = Comm4th.DistributeOver1DThreads( 0, jb, PACK_NR );
186  for ( int j = LooppkB.beg(), jp = PackpkB.beg();
187  j < LooppkB.end();
188  j += LooppkB.inc(), jp += PackpkB.inc() )
189  {
190  B.Pack( k, pc, pb, n, jc + j, std::min( jb - j, NR ),
191  &packB[ jp * pb ] );
192  }
194  Comm4th.Barrier();
196  for ( int ic = Loop4th.beg();
197  ic < Loop4th.end();
198  ic += Loop4th.inc() )
199  {
200  auto ib = std::min( m - ic, MC );
202  auto LooppkA = Comm3th.DistributeOver1DThreads( 0, ib, MR );
203  auto PackpkA = Comm3th.DistributeOver1DThreads( 0, ib, PACK_MR );
205  for ( int i = LooppkA.beg(), ip = PackpkA.beg();
206  i < LooppkA.end();
207  i += LooppkA.inc(), ip += PackpkA.inc() )
208  {
209  A.Pack( m, ic + i, std::min( ib - i, MR ),
210  k, pc, pb, &packA[ ip * pb ] );
211  }
213  Comm3th.Barrier();
215  rank_k_macro_kernel<KC>( Comm3th,
216  ic, jc, pc, ib, jb, pb, packA, packB,
217  V + ic * rs_v + jc * cs_v, rs_v, cs_v,
218  semiringkernel );
220  Comm3th.Barrier();
221  }
222  Comm4th.Barrier();
223  }
224  Comm5th.Barrier();
225  }
226  Comm6th.Barrier();
228  Comm3th.FreeSharedMemory( packA );
229  Comm4th.FreeSharedMemory( packB );
230 };
242 template<
243  int MC, int NC, int KC,
244  typename TPACKA, typename TPACKB, typename TV,
245  typename TA, typename TB, typename TC,
246  typename SEMIRINGKERNEL>
247 void rank_k
248 (
249  int batchId, int m, int n, int k,
250  TA& A,
251  TB& B,
252  TC& C,
253  SEMIRINGKERNEL semiringkernel
254 )
255 {
256  const static int MR = SEMIRINGKERNEL::mr;
257  const static int NR = SEMIRINGKERNEL::nr;
258  const static int PACK_MR = SEMIRINGKERNEL::pack_mr;
259  const static int PACK_NR = SEMIRINGKERNEL::pack_nr;
260  const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
261  const static int PACK_MC = ( MC / MR ) * PACK_MR;
262  const static int PACK_NC = ( NC / NR ) * PACK_NR;
263  const static bool USE_STRASSEN = false;
264 
266  if ( m == 0 || n == 0 || k == 0 ) return;
268  if ( !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value )
269  {
270  exit( 1 );
271  }
273  auto *V = reinterpret_cast<TV*>( C.X );
274  auto rs_v = C.rs;
275  auto cs_v = C.cs;
276 
277 
278  int k_stra = 0;
279  if ( USE_STRASSEN )
280  {
281  assert( typeid(TPACKA) == typeid(TPACKB) );
282  assert( typeid(TC) == typeid(TV) );
283  k_stra = k - k % KC;
284 
285  if ( k_stra == k ) k_stra -= KC;
286  }
287 
288  tci::Parallelize( NULL, rank_k_internal<MC, NC, KC, TPACKA, TPACKB, TV,
289  TA, TB, TC, SEMIRINGKERNEL>,
290  batchId, m, n, k, k_stra, A, B, C, V, rs_v, cs_v,
291  semiringkernel );
292 };
294 };
296 #endif
Comm Split(int num_groups)
Definition: tci.cpp:149
void rank_k(int batchId, int m, int n, int k, TA &A, TB &B, TC &C, SEMIRINGKERNEL semiringkernel)
Definition: rank_k.hpp:248
Definition: tci.hpp:89
void rank_k_internal(tci::Comm &Comm6th, int batchId, int m, int n, int k, int k_stra, TA &A, TB &B, TV *V, int rs_v, int cs_v, SEMIRINGKERNEL semiringkernel)
Definition: rank_k.hpp:136
Definition: hmlp_internal.hpp:38
Definition: packing.hpp:198
Range DistributeOver1DGangs(int beg, int end, int nb)
Definition: tci.cpp:335
Definition: gofmm.hpp:83