HMLP: High-performance Machine Learning Primitives
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Pages
gemm.hpp
1 
23 #ifndef GEMM_HPP
24 #define GEMM_HPP
25 
26 #include <hmlp.h>
27 #include <hmlp_base.hpp>
29 //#include <View.hpp>
30 
31 using namespace std;
32 using namespace hmlp;
33 
34 namespace hmlp
35 {
36 namespace gemm
37 {
38 
39 template<typename T>
40 class xgemmTask : public Task
41 {
42  public:
43 
44  T alpha = 0.0;
45 
46  View<T> A;
47 
48  View<T> B;
49 
50  T beta = 0.0;
51 
52  View<T> C;
53 
54  void Set(
55  T alpha, View<T> &A,
56  View<T> &B,
57  T beta, View<T> &C )
58  {
60  this->alpha = alpha;
61  this->A = A;
62  this->B = B;
63  this->beta = beta;
64  this->C = C;
65 
67  ostringstream ss;
68  name = string( "gemm" );
69 
71  double flops = 0.0, mops = 0.0;
72  cost = 2.0 * C.row() * C.col();
73  event.Set( name + label, flops, mops );
74  };
75 
77  {
79  A.DependencyAnalysis( R, this );
80  B.DependencyAnalysis( R, this );
81  C.DependencyAnalysis( RW, this );
82  assert( !this->TryEnqueue() );
83  };
84 
85  void Execute( Worker* user_worker )
86  {
87  string transA, transB;
88  if ( A.IsTransposed() ) transA = "Transpose";
89  else transA = "No transpose";
90  if ( B.IsTransposed() ) transB = "Transpose";
91  else transB = "No transpose";
92 
93  size_t m = C.row();
94  size_t n = C.col();
95  size_t k = A.col();
96 
97  assert( A.row() == m );
98  assert( B.row() == k );
99  assert( B.col() == n );
100 
101  //int rand_id = rand();
102 
103  //printf( "%d GEMM task %s %s %lu %lu %lu, %E, %E\n",
104  // rand_id, transA.data(), transB.data(), m, n, k, alpha, beta ); fflush( stdout );
105  //printf( "%d lda %lu ldb %lu ldc %lu\n", rand_id, A.ld(), B.ld(), C.ld() ); fflush( stdout );
106 
107  xgemm( transA.data(), transB.data(), m, n, k,
108  alpha, A.data(), A.ld(),
109  B.data(), B.ld(),
110  beta, C.data(), C.ld() );
111 
112  //printf( "%d end GEMM task %s %s %lu %lu %lu, %E, %E\n",
113  // rand_id, transA.data(), transB.data(), m, n, k, alpha, beta ); fflush( stdout );
114  };
115 
116 };
122 template<typename T>
123 class xgemmBarrierTask : public Task
124 {
125  public:
126 
127  View<T> C;
128 
129  void Set(
130  T alpha, View<T> &A,
131  View<T> &B,
132  T beta, View<T> &C )
133  {
135  this->C = C;
136  this->stealable = false;
137 
139  name = string( "gemmBarrier" );
140 
142  double flops, mops;
143  flops = 0.0;
144  mops = 0.0;
145  cost = 1.0;
146  event.Set( name + label, flops, mops );
147  };
148 
150  void DependencyAnalysis() { C.DependencyAnalysis( RW, this ); };
151 
152  void Execute( Worker* user_worker ) {};
153 
154 };
157 template<typename T>
158 void CreatexgemmTask( T alpha, View<T> &A, View<T> &B, T beta, View<T> &C )
159 {
160  auto *task = new xgemmTask<T>();
161  task->Set( alpha, A, B, beta, C );
162  task->Submit();
163  task->DependencyAnalysis();
164 };
170 template<size_t NB = 512, typename T>
171 void xgemm_var1( T alpha, View<T> &A, View<T> &B, T beta, View<T> &C )
172 {
174  View<T> AL, AR,
175  A0, A1, A2;
176  View<T> BT, BB,
177  B0, B1, B2;
179  A.Partition1x2( AL, AR, 0, LEFT );
181  B.Partition2x1( BT,
182  BB, 0, TOP );
183 
184  //printf( "AL.col() %lu AR.col() %lu A.col() %lu\n", AL.col(), AR.col(), A.col() );
185 
186  while ( AL.col() < A.col() )
187  {
188  //printf( "AL.col() %lu AR.col() %lu A.col() %lu\n", AL.col(), AR.col(), A.col() );
189  size_t b = std::min( AR.col(), NB );
190 
192  Repartition1x2To1x3( AL, AR,
194  A0, A1, A2, b, RIGHT );
196  Repartition2x1To3x1( BT, B0,
197  B1,
198  BB, B2, b, BOTTOM );
199 
201  CreatexgemmTask( alpha, A1, B1, beta, C );
202  beta = 1.0;
206  ContinueWith1x3To1x2( AL, AR,
208  A0, A1, A2, LEFT );
210  ContinueWith3x1To2x1( BT, B0,
211  B1,
212  BB, B2, TOP );
213 
214  }
215 };
219 template<size_t NB = 512, typename T>
220 void xgemm_var2( T alpha, View<T> &A, View<T> &B, T beta, View<T> &C )
221 {
223  View<T> CL, CR,
224  C0, C1, C2;
225  View<T> BL, BR,
226  B0, B1, B2;
227 
228  C.Partition1x2( CL, CR, 0, LEFT );
229  B.Partition1x2( BL, BR, 0, LEFT );
230 
231  while ( BL.col() < B.col() )
232  {
233  size_t b = std::min( BR.col(), NB );
234 
236  Repartition1x2To1x3( CL, CR,
238  C0, C1, C2, b, RIGHT );
240  Repartition1x2To1x3( BL, BR,
242  B0, B1, B2, b, RIGHT );
243 
245  xgemm_var1( alpha, A, B1, beta, C1 );
249  ContinueWith1x3To1x2( CL, CR,
251  C0, C1, C2, LEFT );
253  ContinueWith1x3To1x2( BL, BR,
255  B0, B1, B2, LEFT );
256 
257  }
258 };
262 template<size_t NB = 512, typename T>
263 void xgemm_var3( T alpha, View<T> &A, View<T> &B, T beta, View<T> &C )
264 {
266  View<T> AT, A0, CT, C0,
267  AB, A1, CB, C1,
268  A2, C2;
269 
270  A.Partition2x1( AT,
271  AB, 0, TOP );
272  C.Partition2x1( CT,
273  CB, 0, TOP );
274 
275  while ( AT.row() < A.row() )
276  {
277  size_t b = std::min( AB.row(), NB );
278 
280  Repartition2x1To3x1( AT, A0,
281  A1,
282  AB, A2, b, BOTTOM );
284  Repartition2x1To3x1( CT, C0,
285  C1,
286  CB, C2, b, BOTTOM );
287 
289  xgemm_var2( alpha, A1, B, beta, C1 );
293  ContinueWith3x1To2x1( AT, A0,
294  A1,
295  AB, A2, TOP );
297  ContinueWith3x1To2x1( CT, C0,
298  C1,
299  CB, C2, TOP );
300  };
301 };
305 template<size_t NB = 512, typename T>
306 void xgemm( T alpha, View<T> &A, View<T> &B, T beta, View<T> &C )
307 {
308  //string transA, transB;
309  //if ( A.IsTransposed() ) transA = "Transpose";
310  //else transA = "No transpose";
311  //if ( B.IsTransposed() ) transB = "Transpose";
312  //else transB = "No transpose";
313  //size_t m = C.row();
314  //size_t n = C.col();
315  //size_t k = A.col();
316  //assert( A.row() == m );
317  //assert( B.row() == k );
318  //assert( B.col() == n );
319  //hmlp::xgemm( transA.data(), transB.data(), m, n, k,
320  // alpha, A.data(), A.ld(),
321  // B.data(), B.ld(),
322  // beta, C.data(), C.ld() );
323  //return;
324 
325 
326 
327 
328 
330  A.CreateLeafMatrixBlocks( NB, NB );
331  B.CreateLeafMatrixBlocks( NB, NB );
332  C.CreateLeafMatrixBlocks( NB, NB );
333 
335  if ( hmlp_is_in_epoch_session() )
336  {
337  auto *begXGEMMtask = new xgemmBarrierTask<T>();
338  auto *endXGEMMtask = new xgemmBarrierTask<T>();
351  begXGEMMtask->Set( alpha, A, B, beta, C );
352  begXGEMMtask->Submit();
353  begXGEMMtask->DependencyAnalysis();
354 
360  xgemm_var3( alpha, A, B, beta, C );
361 
366  endXGEMMtask->Set( alpha, A, B, beta, C );
367  endXGEMMtask->Submit();
368  endXGEMMtask->DependencyAnalysis();
369 
371  begXGEMMtask->TryEnqueue();
372  endXGEMMtask->CallBackWhileWaiting();
373  }
374  else
375  {
376  xgemm_var3( alpha, A, B, beta, C );
377  }
378 };
381 template<typename T>
382 void xgemm( hmlpOperation_t transA, hmlpOperation_t transB,
383  T alpha, Data<T> &A, Data<T> &B, T beta, Data<T> &C )
384 {
385  const bool TRANS = true;
386  const bool NOTRANS = true;
387 
389  View<T> Aview, Bview, Cview;
390 
392  if ( transA == HMLP_OP_T ) Aview.Set( true, A );
393  else Aview.Set( false, A );
394  if ( transB == HMLP_OP_T ) Bview.Set( true, B );
395  else Bview.Set( false, B );
397  Cview.Set( C );
398 
399  xgemm( alpha, Aview, Bview, beta, Cview );
400 
401 };
404 template<typename T>
405 void xgemm( T alpha, Data<T> &A, Data<T> &B, T beta, Data<T> &C )
406 {
407  xgemm( HMLP_OP_N, HMLP_OP_N, alpha, A, B, beta, C );
408 };
413 };
414 };
417 #endif
void Repartition2x1To3x1(View< T > &AT, View< T > &A0, View< T > &A1, View< T > &AB, View< T > &A2, size_t mb, SideType side)
Definition: View.hpp:523
void Partition1x2(View< T > &A1, View< T > &A2, size_t nb, SideType side)
Definition: View.hpp:180
void CallBackWhileWaiting()
This is the callback function for the owner of the nested task.
Definition: runtime.cpp:417
This task is generated by the top level routine.
Definition: gemm.hpp:123
void DependencyAnalysis(ReadWriteType type, Task *task)
If leaf r/w blocks were created, then the r/w dependency applies to all leaf r/w blocks covered by th...
Definition: View.hpp:312
void ContinueWith3x1To2x1(View< T > &AT, View< T > &A0, View< T > &A1, View< T > &AB, View< T > &A2, SideType side)
Definition: View.hpp:557
void CreateLeafMatrixBlocks(size_t mb, size_t nb)
Definition: View.hpp:267
void Set(T alpha, View< T > &A, View< T > &B, T beta, View< T > &C)
Definition: gemm.hpp:129
size_t row()
Definition: View.hpp:345
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void Submit()
Ask the runtime to create a normal task in file.
Definition: runtime.cpp:264
size_t col()
Definition: View.hpp:348
T * data()
Definition: View.hpp:354
void DependencyAnalysis()
Definition: gemm.hpp:76
void Set(T alpha, View< T > &A, View< T > &B, T beta, View< T > &C)
Definition: gemm.hpp:54
void Repartition1x2To1x3(View< T > &AL, View< T > &AR, View< T > &A0, View< T > &A1, View< T > &A2, size_t nb, SideType side)
Definition: View.hpp:458
void Set(bool TRANS, Data< T > &buff)
Definition: View.hpp:60
Definition: View.hpp:43
size_t ld()
Definition: View.hpp:351
void DependencyAnalysis()
Definition: gemm.hpp:150
Definition: gemm.hpp:40
void ContinueWith1x3To1x2(View< T > &AL, View< T > &AR, View< T > &A0, View< T > &A1, View< T > &A2, SideType side)
Definition: View.hpp:490
Definition: gofmm.hpp:83
Definition: runtime.hpp:174
void Partition2x1(View< T > &A1, View< T > &A2, size_t mb, SideType side)
Definition: View.hpp:155
Definition: thread.hpp:166