GCC Code Coverage Report
Directory: . Exec Total Coverage
File: frame/primitives/gemm.hpp Lines: 0 105 0.0 %
Date: 2019-01-14 Branches: 0 284 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
23
#ifndef GEMM_HPP
24
#define GEMM_HPP
25
26
#include <hmlp.h>
27
#include <hmlp_base.hpp>
28
/** Use matrix view to employ SuperMatrix style task parallelism. */
29
//#include <View.hpp>
30
31
using namespace std;
32
using namespace hmlp;
33
34
namespace hmlp
35
{
36
namespace gemm
37
{
38
39
template<typename T>
40
class xgemmTask : public Task
41
{
42
  public:
43
44
    T alpha = 0.0;
45
46
    View<T> A;
47
48
    View<T> B;
49
50
    T beta = 0.0;
51
52
    View<T> C;
53
54
    void Set(
55
        T alpha, View<T> &A,
56
                 View<T> &B,
57
        T beta,  View<T> &C )
58
    {
59
      /** main arguments  */
60
      this->alpha = alpha;
61
      this->A = A;
62
      this->B = B;
63
      this->beta = beta;
64
      this->C = C;
65
66
      /** name and label */
67
      ostringstream ss;
68
      name = string( "gemm" );
69
70
      /** Flops, mops, cost and event */
71
      double flops = 0.0, mops = 0.0;
72
      cost  = 2.0 * C.row() * C.col();
73
      event.Set( name + label, flops, mops );
74
    };
75
76
    void DependencyAnalysis()
77
    {
78
      /** Read A and B, read/write C */
79
      A.DependencyAnalysis( R,  this );
80
      B.DependencyAnalysis( R,  this );
81
      C.DependencyAnalysis( RW, this );
82
      assert( !this->TryEnqueue() );
83
    };
84
85
    void Execute( Worker* user_worker )
86
    {
87
      string transA, transB;
88
      if ( A.IsTransposed() ) transA = "Transpose";
89
      else                    transA = "No transpose";
90
      if ( B.IsTransposed() ) transB = "Transpose";
91
      else                    transB = "No transpose";
92
93
      size_t m = C.row();
94
      size_t n = C.col();
95
      size_t k = A.col();
96
97
      assert( A.row() == m );
98
      assert( B.row() == k );
99
      assert( B.col() == n );
100
101
      //int rand_id = rand();
102
103
      //printf( "%d GEMM task %s %s %lu %lu %lu, %E, %E\n",
104
      //    rand_id, transA.data(), transB.data(), m, n, k, alpha, beta ); fflush( stdout );
105
      //printf( "%d lda %lu ldb %lu ldc %lu\n", rand_id, A.ld(), B.ld(), C.ld() ); fflush( stdout );
106
107
      xgemm( transA.data(), transB.data(), m, n, k,
108
        alpha, A.data(), A.ld(),
109
               B.data(), B.ld(),
110
        beta,  C.data(), C.ld() );
111
112
      //printf( "%d end GEMM task %s %s %lu %lu %lu, %E, %E\n",
113
      //    rand_id, transA.data(), transB.data(), m, n, k, alpha, beta ); fflush( stdout );
114
    };
115
116
}; /** end class xgemmTask */
117
118
119
/**
120
 *  @brief  This task is generated by the top level routine.
121
 **/
122
template<typename T>
123
class xgemmBarrierTask : public Task
124
{
125
  public:
126
127
    View<T> C;
128
129
    void Set(
130
        T alpha, View<T> &A,
131
                 View<T> &B,
132
        T beta,  View<T> &C )
133
    {
134
      /** Main arguments */
135
      this->C = C;
136
      this->stealable = false;
137
138
      /** Name and label */
139
      name = string( "gemmBarrier" );
140
141
      /** Flops, Mops, cost and event */
142
      double flops, mops;
143
      flops = 0.0;
144
      mops  = 0.0;
145
      cost  = 1.0;
146
      event.Set( name + label, flops, mops );
147
    };
148
149
    /** Create RAW dependencies on all submatrices C. */
150
    void DependencyAnalysis() { C.DependencyAnalysis( RW, this ); };
151
152
    void Execute( Worker* user_worker ) {};
153
154
}; /** end class xgemmBarrierTask */
155
156
157
template<typename T>
158
void CreatexgemmTask( T alpha, View<T> &A, View<T> &B, T beta,  View<T> &C )
159
{
160
  auto *task = new xgemmTask<T>();
161
  task->Set( alpha, A, B, beta, C );
162
  task->Submit();
163
  task->DependencyAnalysis();
164
}; /** end xgemmTask() */
165
166
167
/**
168
 *  @brief
169
 */
170
template<size_t NB = 512, typename T>
171
void xgemm_var1( T alpha, View<T> &A, View<T> &B, T beta,  View<T> &C )
172
{
173
  /** All subviews */
174
  View<T> AL, AR,
175
          A0, A1, A2;
176
  View<T> BT, BB,
177
          B0, B1, B2;
178
  /** A = [ AL, AR ] */
179
  A.Partition1x2( AL, AR, 0, LEFT );
180
  /** B = [ BT; BB ] */
181
  B.Partition2x1( BT,
182
                  BB,     0, TOP  );
183
184
  //printf( "AL.col() %lu AR.col() %lu A.col() %lu\n", AL.col(), AR.col(), A.col() );
185
186
  while ( AL.col() < A.col() )
187
  {
188
    //printf( "AL.col() %lu AR.col() %lu A.col() %lu\n", AL.col(), AR.col(), A.col() );
189
    size_t b = std::min( AR.col(), NB );
190
191
    /** Repartition A */
192
    Repartition1x2To1x3( AL,      AR,
193
                         /** **** */
194
                         A0,  A1, A2, b, RIGHT );
195
    /** Repartition B */
196
    Repartition2x1To3x1( BT, /**/ B0,
197
                             /**/ B1,
198
                         BB, /**/ B2, b, BOTTOM );
199
200
    /** --------------------------------------------------- */
201
    CreatexgemmTask( alpha, A1, B1, beta, C );
202
    beta = 1.0;
203
    /** --------------------------------------------------- */
204
205
    /** Merge A */
206
    ContinueWith1x3To1x2( AL,      AR,
207
                          /** **** */
208
                          A0,  A1, A2, LEFT );
209
    /** Merge B */
210
    ContinueWith3x1To2x1( BT, /**/ B0,
211
                              /**/ B1,
212
                          BB, /**/ B2,  TOP );
213
214
  } /** end while */
215
}; /** end xgemm_var1() */
216
217
218
/** @brief [ A * BL + CL, A * BR + CR ] */
219
template<size_t NB = 512, typename T>
220
void xgemm_var2( T alpha, View<T> &A, View<T> &B, T beta,  View<T> &C )
221
{
222
  /** All subviews */
223
  View<T> CL, CR,
224
          C0, C1, C2;
225
  View<T> BL, BR,
226
          B0, B1, B2;
227
228
  C.Partition1x2( CL, CR, 0, LEFT );
229
  B.Partition1x2( BL, BR, 0, LEFT );
230
231
  while ( BL.col() < B.col() )
232
  {
233
    size_t b = std::min( BR.col(), NB );
234
235
    /** Repartition C */
236
    Repartition1x2To1x3( CL,      CR,
237
                         /** **** */
238
                         C0,  C1, C2, b, RIGHT );
239
    /** Repartition B */
240
    Repartition1x2To1x3( BL,      BR,
241
                         /** **** */
242
                         B0,  B1, B2, b, RIGHT );
243
244
    /** --------------------------------------------------- */
245
    xgemm_var1( alpha, A, B1, beta, C1 );
246
    /** --------------------------------------------------- */
247
248
    /** Merge C */
249
    ContinueWith1x3To1x2( CL,      CR,
250
                          /** **** */
251
                          C0,  C1, C2, LEFT );
252
    /** Merge B */
253
    ContinueWith1x3To1x2( BL,      BR,
254
                          /** **** */
255
                          B0,  B1, B2, LEFT );
256
257
  } /** end while */
258
}; /** end xgemm_var2() */
259
260
261
/** @brief [ AT * B + CT; AB * B + CB ] */
262
template<size_t NB = 512, typename T>
263
void xgemm_var3( T alpha, View<T> &A, View<T> &B, T beta,  View<T> &C )
264
{
265
  /** All subviews */
266
  View<T> AT, A0, CT, C0,
267
          AB, A1, CB, C1,
268
              A2,     C2;
269
270
  A.Partition2x1( AT,
271
                  AB,     0, TOP  );
272
  C.Partition2x1( CT,
273
                  CB,     0, TOP  );
274
275
  while ( AT.row() < A.row() )
276
  {
277
    size_t b = std::min( AB.row(), NB );
278
279
    /** Repartition A */
280
    Repartition2x1To3x1( AT, /**/ A0,
281
                             /**/ A1,
282
                         AB, /**/ A2, b, BOTTOM );
283
    /** Repartition B */
284
    Repartition2x1To3x1( CT, /**/ C0,
285
                             /**/ C1,
286
                         CB, /**/ C2, b, BOTTOM );
287
288
    /** --------------------------------------------------- */
289
    xgemm_var2( alpha, A1, B, beta, C1 );
290
    /** --------------------------------------------------- */
291
292
    /** Merge A */
293
    ContinueWith3x1To2x1( AT, /**/ A0,
294
                              /**/ A1,
295
                          AB, /**/ A2,  TOP );
296
    /** Merge C */
297
    ContinueWith3x1To2x1( CT, /**/ C0,
298
                              /**/ C1,
299
                          CB, /**/ C2,  TOP );
300
  }; /** end while */
301
}; /** end xgemm_var3() */
302
303
304
/** @breif  Interface for automatic task-bsed parallelism. */
305
template<size_t NB = 512, typename T>
306
void xgemm( T alpha, View<T> &A, View<T> &B, T beta,  View<T> &C )
307
{
308
  //string transA, transB;
309
  //if ( A.IsTransposed() ) transA = "Transpose";
310
  //else                    transA = "No transpose";
311
  //if ( B.IsTransposed() ) transB = "Transpose";
312
  //else                    transB = "No transpose";
313
  //size_t m = C.row();
314
  //size_t n = C.col();
315
  //size_t k = A.col();
316
  //assert( A.row() == m );
317
  //assert( B.row() == k );
318
  //assert( B.col() == n );
319
  //hmlp::xgemm( transA.data(), transB.data(), m, n, k,
320
  //  alpha, A.data(), A.ld(),
321
  //         B.data(), B.ld(),
322
  //  beta,  C.data(), C.ld() );
323
  //return;
324
325
326
327
328
329
  /** try to  */
330
  A.CreateLeafMatrixBlocks( NB, NB );
331
  B.CreateLeafMatrixBlocks( NB, NB );
332
  C.CreateLeafMatrixBlocks( NB, NB );
333
334
  /** Call back */
335
  if ( hmlp_is_in_epoch_session() )
336
  {
337
    auto *begXGEMMtask = new xgemmBarrierTask<T>();
338
    auto *endXGEMMtask = new xgemmBarrierTask<T>();
339
    /**
340
     *  The reason why we need the begin barrier
341
     *  task is to ensure the whole DAG will be
342
     *  inserted at once. Otherwise, the dependent
343
     *  task may not be created while the traversal
344
     *  has already reached by other workers.
345
     *
346
     *  The solution is to create a beginning barrier
347
     *  such that all the following tasks depend on it.
348
     *  Only enqueue the beginning barrier while all
349
     *  dependent tasks have been created.
350
     */
351
    begXGEMMtask->Set( alpha, A, B, beta, C );
352
    begXGEMMtask->Submit();
353
    begXGEMMtask->DependencyAnalysis();
354
355
    /**
356
     *  Now we create all dependent tasks. Since they
357
     *  all dependent on begXGEMMtask. They all have
358
     *  STATUS=NOTREADY. Thus, no one will be enqueued.
359
     */
360
    xgemm_var3( alpha, A, B, beta, C );
361
362
    /**
363
     *  Create a termination barrier such that it depends
364
     *  on all tasks.
365
     */
366
    endXGEMMtask->Set( alpha, A, B, beta, C );
367
    endXGEMMtask->Submit();
368
    endXGEMMtask->DependencyAnalysis();
369
370
    /** Now enqueue begXGEMMtask and callback with endXGEMMtask. */
371
    begXGEMMtask->TryEnqueue();
372
    endXGEMMtask->CallBackWhileWaiting();
373
  }
374
  else
375
  {
376
    xgemm_var3( alpha, A, B, beta, C );
377
  }
378
}; /** xgemm() */
379
380
381
template<typename T>
382
void xgemm( hmlpOperation_t transA, hmlpOperation_t transB,
383
    T alpha, Data<T> &A, Data<T> &B, T beta,  Data<T> &C )
384
{
385
  const bool TRANS = true;
386
  const bool NOTRANS = true;
387
388
  /** Matrix views of A, B and C */
389
  View<T> Aview, Bview, Cview;
390
391
  /** A and B may be in tranpose view */
392
  if ( transA == HMLP_OP_T ) Aview.Set(  true, A );
393
  else                       Aview.Set( false, A );
394
  if ( transB == HMLP_OP_T ) Bview.Set(  true, B );
395
  else                       Bview.Set( false, B );
396
  /** C is always not transpose */
397
  Cview.Set( C );
398
399
  xgemm( alpha, Aview, Bview, beta, Cview );
400
401
}; /** end xgemm() */
402
403
404
template<typename T>
405
void xgemm( T alpha, Data<T> &A, Data<T> &B, T beta,  Data<T> &C )
406
{
407
  xgemm( HMLP_OP_N, HMLP_OP_N, alpha, A, B, beta, C );
408
}; /** end xgemm() */
409
410
411
412
413
}; /** end namespace gemm */
414
}; /** end namespace hmlp */
415
416
417
#endif /** define GEMM_HPP */