GCC Code Coverage Report
Directory: .
File:      frame/primitives/gkmx.hpp
Date:      2019-01-14

              Exec    Total    Coverage
Lines:           0      153       0.0 %
Branches:        0      144       0.0 %

Source:
/**
 *  HMLP (High-Performance Machine Learning Primitives)
 *
 *  Copyright (C) 2014-2017, The University of Texas at Austin
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see the LICENSE file.
 *
 **/



#ifndef GKMX_HPP
#define GKMX_HPP

#include <assert.h>
#include <typeinfo>
#include <algorithm>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>

/** for USE_STRASSEN */
#include <primitives/strassen.hpp>

/** reference microkernels */
#include <semiring_mrxnr.hpp>
#include <fused_mrxnr.hpp>

//#define GKMX_CONFIG \


namespace hmlp
{
namespace gkmx
{

/**
 *  @brief The macro kernel contains the 3rd and 2nd loops. Depending on the
 *         configuration of the communicator, the 3rd loop may be parallelized.
 *         b_next is the prefetch pointer.
 */
template<
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
  typename SEMIRINGKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void rank_k_macro_kernel
(
  Worker &thread,
  int ic, int jc, int pc,
  int  m, int n,  int  k,
  TA *packA,
  TB *packB,
  TV *V, int ldv,
  SEMIRINGKERNEL semiringkernel
)
{
  thread_communicator &ic_comm = *thread.ic_comm;

  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto loop2nd = GetRange( 0, m,      MR );
  auto pack2nd = GetRange( 0, m, PACK_MR );

  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
            j   < loop3rd.end();
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
  {
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;
    aux.do_packC = 0;
    aux.jb       = std::min( n - j, NR );

    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
              i  < loop2nd.end();
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
    {
      aux.ib = std::min( m - i, MR );
      if ( i + MR >= m )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        semiringkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          &V[ j * ldv + i ], 1, ldv,
          &aux
        );
      }
      else                                                 // corner case
      {
        TV vtmp[ MR * NR ];

        if ( pc ) // initialize vtmp with the current entries of V
        {
          for ( auto jj = 0; jj < aux.jb; jj ++ )
            for ( auto ii = 0; ii < aux.ib; ii ++ )
              vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
        }

        semiringkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          vtmp, 1, MR,
          &aux
        );

        for ( auto jj = 0; jj < aux.jb; jj ++ )
          for ( auto ii = 0; ii < aux.ib; ii ++ )
            V[ ( j + jj ) * ldv + i + ii ] = vtmp[ jj * MR + ii ];
      }
    }                                                      // end 2nd loop
  }                                                        // end 3rd loop
}                                                          // end rank_k_macro_kernel
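
/**
 *  A minimal reference sketch of a functor with the calling convention that
 *  semiringkernel is invoked with above: ( k, packed A micro-panel, packed B
 *  micro-panel, V, row stride, column stride, aux ). It assumes PACK_MR == MR
 *  and PACK_NR == NR, i.e. element ( i, p ) of the A micro-panel is stored at
 *  a[ p * MR + i ] and element ( j, p ) of the B micro-panel at
 *  b[ p * NR + j ]. The production kernels live in semiring_mrxnr.hpp and
 *  fused_mrxnr.hpp; the struct and member names below are placeholders for
 *  illustration only.
 */
template<int MR, int NR, typename OP1, typename OP2,
  typename TA, typename TB, typename TC, typename TV>
struct reference_semiring_kernel_sketch
{
  OP1 op1;   /** additive-like update, e.g. plus or min */
  OP2 op2;   /** multiplicative-like combine, e.g. multiplies */
  TV initV;  /** identity element of op1 */

  inline void operator()
  (
    int k, TA *a, TB *b, TV *v, int rs_v, int cs_v,
    aux_s<TA, TB, TC, TV> *aux
  ) const
  {
    for ( int j = 0; j < aux->jb; j ++ )
      for ( int i = 0; i < aux->ib; i ++ )
      {
        /** continue the rank-k update on later pc panels, restart otherwise */
        TV acc = aux->pc ? v[ j * cs_v + i * rs_v ] : initV;
        for ( int p = 0; p < k; p ++ )
          acc = op1( acc, op2( a[ p * MR + i ], b[ p * NR + j ] ) );
        v[ j * cs_v + i * rs_v ] = acc;
      }
  }
};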

/**
 *  @brief fused_macro_kernel contains the 3rd and 2nd loops and the fused
 *         micro-kernel. Notice that here C has type TC, which is different
 *         from the one in rank_k_macro_kernel. The ctmp buffer used in the
 *         corner case also has type TC.
 */
template<
int KC, int MR, int NR, int PACK_MR, int PACK_NR,
bool REUSE_C,
typename FUSEDKERNEL,
typename TA, typename TB, typename TC, typename TV>
void fused_macro_kernel
(
  Worker &thread,
  int ic, int jc, int pc,
  int  m,  int n,  int k,
  TA *packA,
  TB *packB,
  TC *C, int ldc,
  TV *V, int ldv,
  int batchId,
  FUSEDKERNEL fusedkernel
)
{
  thread_communicator &ic_comm = *thread.ic_comm;

  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto loop2nd = GetRange( 0, m,      MR );
  auto pack2nd = GetRange( 0, m, PACK_MR );

  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
            j   < loop3rd.end();
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
  {
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;
    aux.do_packC = 0;

    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
              i  < loop2nd.end();
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
    {
      // This auxiliary info is used to access data in the closure of
      // opkernel and opreduce.
      aux.i = ic + i;
      aux.j = jc + j;
      aux.b = batchId;

      aux.ib = std::min( m - i, MR );
      aux.jb = std::min( n - j, NR );

      aux.V = V + j * ldv + i;
      aux.ldv = ldv;

      if ( i + MR >= m )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        fusedkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          &C[ j * ldc + i ], 1, ldc,
          //&C[ ( j / NR ) * ldc + i ], ldc, // for conv_relu_pool
          &aux
        );
      }
      else                                                 // corner case
      {
        TC ctmp[ MR * NR ];
        TV vtmp[ MR * NR ];

        if ( pc ) // initialize ctmp (or vtmp when C is not reused)
        {
          if ( REUSE_C )
          {
            for ( auto jj = 0; jj < aux.jb; jj ++ )
              for ( auto ii = 0; ii < aux.ib; ii ++ )
                ctmp[ jj * MR + ii ] = C[ ( j + jj ) * ldc + i + ii ];
          }
          else
          {
            for ( auto jj = 0; jj < aux.jb; jj ++ )
              for ( auto ii = 0; ii < aux.ib; ii ++ )
                vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
            aux.V = vtmp;
            aux.ldv = MR;
          }
        }

        fusedkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          ctmp, 1, MR,
          &aux
        );

        for ( auto jj = 0; jj < aux.jb; jj ++ )
          for ( auto ii = 0; ii < aux.ib; ii ++ )
            C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];

      }
    }                                                      // end 2nd loop
  }                                                        // end 3rd loop
};                                                         // end fused_macro_kernel




/**
 *  @brief This function contains the loop body of the 6th to 4th loops,
 *         including all packing and unpacking routines. Notice that this
 *         function is executed by all threads in the root communicator.
 *         To access each thread in a different level of communicators, use
 *         their ids.
 */
template<
  int MC,
  int NC,
  int KC,
  int MR,
  int NR,
  int PACK_MC,
  int PACK_NC,
  int PACK_MR,
  int PACK_NR,
  int ALIGN_SIZE,
  bool USE_STRASSEN,
  bool REUSE_C,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void gkmx_internal
(
  Worker &thread,
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k, int k_stra,
  TA *A, int lda,
  TB *B, int ldb,
  TC *C, int ldc,
  TV *V, int ldv,
  int batchId,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel,
  int nc, int pack_nc,
  TA *packA,
  TB *packB
)
{
  packA  += ( thread.jc_id * thread.ic_nt                ) * PACK_MC * KC
          + ( thread.ic_id                               ) * PACK_MC * KC;
  packB  += ( thread.jc_id                               ) * pack_nc * KC;

  auto loop6th = GetRange( 0,      n, nc, thread.jc_id, thread.jc_nt );
  auto loop5th = GetRange( k_stra, k, KC );
  auto loop4th = GetRange( 0,      m, MC, thread.ic_id, thread.ic_nt );

  for ( int jc  = loop6th.beg();
            jc  < loop6th.end();
            jc += loop6th.inc() )                          // beg 6th loop
  {
    auto &jc_comm = *thread.jc_comm;
    auto jb = std::min( n - jc, nc );

    for ( int pc  = loop5th.beg();
              pc  < loop5th.end();
              pc += loop5th.inc() )
    {
      auto &pc_comm = *thread.pc_comm;
      auto pb = std::min( k - pc, KC );
      auto is_the_last_pc_iteration = ( pc + KC >= k );
      auto looppkB = GetRange( 0, jb,      NR, thread.ic_jr, pc_comm.GetNumThreads() );
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );

      for ( int j   = looppkB.beg(), jp  = packpkB.beg();
                j   < looppkB.end();
                j  += looppkB.inc(), jp += packpkB.inc() )
      {
        if ( transB == HMLP_OP_N )
        {
          pack2D<true, PACK_NR>                            // packB
          (
            std::min( jb - j, NR ), pb,
            &B[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
          );
        }
        else
        {
          pack2D<false, PACK_NR>                           // packB (transB)
          (
            std::min( jb - j, NR ), pb,
            &B[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
          );
        }
      }
      pc_comm.Barrier();

      for ( int ic  = loop4th.beg();
                ic  < loop4th.end();
                ic += loop4th.inc() )                      // beg 4th loop
      {
        auto &ic_comm = *thread.ic_comm;
        auto ib = std::min( m - ic, MC );
        auto looppkA = GetRange( 0, ib,      MR, thread.jr_id, thread.jr_nt );
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );

        for ( int i   = looppkA.beg(), ip  = packpkA.beg();
                  i   < looppkA.end();
                  i  += looppkA.inc(), ip += packpkA.inc() )
        {
          if ( transA == HMLP_OP_N )
          {
            pack2D<false, PACK_MR>                         // packA
            (
              std::min( ib - i, MR ), pb,
              &A[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
            );
          }
          else
          {
            pack2D<true, PACK_MR>                          // packA (transA)
            (
              std::min( ib - i, MR ), pb,
              &A[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
            );
          }
        }
        ic_comm.Barrier();

        if ( is_the_last_pc_iteration )                    // fused_macro_kernel
        {
          fused_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, REUSE_C, MICROKERNEL, TA, TB, TC, TV>
          (
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            C + jc * ldc + ic, ldc,
            V + jc * ldv + ic, ldv, // if REUSE_C, then V = C.
            batchId,
            microkernel
          );
        }
        else                                               // semiring rank-k update
        {
          rank_k_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
          (
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            //C + jc * ldc + ic, ldc,
            V + jc * ldv + ic, ldv,
            semiringkernel
          );
        }
        ic_comm.Barrier();                                 // sync all jr_id!!
      }                                                    // end 4th loop
      pc_comm.Barrier();
    }                                                      // end 5th loop
  }                                                        // end 6th loop
}                                                          // end gkmx_internal
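
/**
 *  Schematic view of the blocking implemented by gkmx_internal together with
 *  the two macro kernels above (single thread, no Strassen, k_stra = 0):
 *
 *    for jc = 0 : nc : n          // 6th loop, partitions columns of B and C
 *      for pc = 0 : KC : k        // 5th loop, packs B( pc:pc+KC, jc:jc+nc )
 *        for ic = 0 : MC : m      // 4th loop, packs A( ic:ic+MC, pc:pc+KC )
 *          for jr = 0 : NR : jb   // 3rd loop  \
 *            for ir = 0 : MR : ib // 2nd loop   > macro kernel
 *              micro-kernel over the KC panel  /
 *
 *  Only the last pc iteration runs the fused micro-kernel (which applies
 *  opkernel and writes C); earlier pc panels accumulate into V through the
 *  semiring rank-k update.
 */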




/**
 *  @brief This is the main routine of gkmx. All packing buffers are
 *         managed here. The communicator and the parallel section
 *         start here.
 *
 */
template<
  int MC,
  int NC,
  int KC,
  int MR,
  int NR,
  int PACK_MC,
  int PACK_NC,
  int PACK_MR,
  int PACK_NR,
  int ALIGN_SIZE,
  bool USE_STRASSEN = false,
  bool REUSE_C,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV = TC>
void gkmx
(
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k,
  TA *A, int lda,
  TB *B, int ldb,
  TC *C, int ldc,
  int batchId,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
  int k_stra = 0;
  int ldv = 0;
  int nc = NC, pack_nc = PACK_NC;
  char *str;

  TA *packA_buff = NULL;
  TB *packB_buff = NULL;
  TV *V = NULL;

  // Early return if possible
  if ( m == 0 || n == 0 || k == 0 ) return;

  // type checking (currently assume TC == TV)
  if ( typeid(TC) != typeid(TV) && k > KC )
  {
    printf( "gkmx: currently k(%d) must not exceed %d when TC != TV\n", k, KC );
    exit( 1 );
  }

  if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
  {
    // Check the environment variables.
    jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
    ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
    jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
  }

  if ( jc_nt > 1 )
  {
    nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
    pack_nc = ( nc / NR ) * PACK_NR;
  }

  // allocate packing memory
  packA_buff  = hmlp_malloc<ALIGN_SIZE, TA>( KC * ( PACK_MC + 1 ) * jc_nt * ic_nt );
  packB_buff  = hmlp_malloc<ALIGN_SIZE, TB>( KC * ( pack_nc + 1 ) * jc_nt         );


  // allocate V if k > KC
  if ( k > KC && !std::is_same<TC, TV>::value && !REUSE_C )
  {
    V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
    ldv = m;
  }
  else // TODO: do not free V in this case.
  {
    V = reinterpret_cast<TV*>( C );
    ldv = ldc;
  }

  // allocate tree communicator
  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );


  if ( USE_STRASSEN )
  {
    assert( typeid(TA) == typeid(TB) );
    assert( typeid(TC) == typeid(TV) );
    k_stra = k - k % KC;

    if ( k_stra == k ) k_stra -= KC;

    if ( k_stra )
    {
      #pragma omp parallel for
      for ( int i = 0; i < n * ldv; i ++ ) V[ i ] = 0.0;
    }
  }


  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
  {
    Worker thread( &my_comm );

    if ( USE_STRASSEN )
    {
      strassen::strassen_internal
      <MC, NC, KC, MR, NR,
      PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
      USE_STRASSEN,
      SEMIRINGKERNEL, SEMIRINGKERNEL,
      TA, TB, TC, TV>
      (
        thread,
        transA, transB,
        m, n, k_stra,
        A, lda,
        B, ldb,
        V, ldv,
        semiringkernel, semiringkernel,
        nc, pack_nc,
        packA_buff,
        packB_buff
      );
    }

    gkmx_internal
    <MC, NC, KC, MR, NR,
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    USE_STRASSEN, REUSE_C,
    SEMIRINGKERNEL, MICROKERNEL,
    TA, TB, TC, TV>
    (
      thread,
      transA, transB,
      m, n, k, k_stra,
      A, lda,
      B, ldb,
      C, ldc,
      V, ldv,
      batchId,
      semiringkernel, microkernel,
      nc, pack_nc,
      packA_buff,
      packB_buff
    );
  }                                                        // end omp parallel

  hmlp_free( packA_buff );
  hmlp_free( packB_buff );
  //hmlp_free( V );
};                                                         // end gkmx




/**
 *  @brief Wrapper for gkmx: builds the semiring and fused (gkmm) micro-kernels
 *         from op1, op2, and opkernel, then invokes gkmx.
 */
template<
  int MC            = 104,
  int NC            = 1024,
  int KC            = 256,
  int MR            = 8,
  int NR            = 4,
  int PACK_MC       = 104,
  int PACK_NC       = 1024,
  int PACK_MR       = 8,
  int PACK_NR       = 4,
  int ALIGN_SIZE    = 32,
  bool USE_STRASSEN = false,
  bool REUSE_C = false,
  typename OPKERNEL, typename OP1, typename OP2,
  typename TA, typename TB, typename TC, typename TV>
void gkmm
(
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k,
  TA *A, int lda,
  TB *B, int ldb,
  TC *C, int ldc,
  int batchId,
  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
)
{
  semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV> semiringkernel;
  gkmm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TA, TB, TC, TV> gkmmkernel;

  semiringkernel.op1 = op1;
  semiringkernel.op2 = op2;
  semiringkernel.initV = initV;

  gkmmkernel.op1 = op1;
  gkmmkernel.op2 = op2;
  gkmmkernel.opkernel = opkernel;
  gkmmkernel.initV = initV;

  gkmx
  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
  USE_STRASSEN, REUSE_C,
  semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV>,
  gkmm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TA, TB, TC, TV>,
  TA, TB, TC, TV>
  (
    transA, transB,
    m, n, k,
    A, lda,
    B, ldb,
    C, ldc,
    batchId,
    semiringkernel, gkmmkernel
  );
};
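
/**
 *  In the notation of the reference implementation gkmm_ref below, gkmm
 *  computes, for every ( i, j ),
 *
 *    C( i, j ) = opkernel( op1_{p = 0..k-1} op2( A( i, p ), B( p, j ) ) ),
 *
 *  where the op1 reduction starts from initV. With op1 = +, op2 = * and an
 *  identity opkernel this is an ordinary matrix-matrix product
 *  C = op( A ) * op( B ).
 */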


/**
 *  @brief batched interface with an array of pointers
 *
 *  TODO: the problem is how to manage threads here. Do we want to use
 *  nested OpenMP parallelism, or is there a better way to deal with this?
 *
 */
template<
  int MC, int NC, int KC, int MR, int NR,
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_STRASSEN, bool REUSE_C,
  typename OPKERNEL, typename OP1, typename OP2,
  typename TA, typename TB, typename TC, typename TV>
void gkmm
(
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k,
  TA *Aarray[], int lda,
  TB *Barray[], int ldb,
  TC *Carray[], int ldc,
  int batchSize,
  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
)
{
  #pragma omp parallel for
  for ( auto b = 0; b < batchSize; b ++ )
  {
    gkmm
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    USE_STRASSEN, REUSE_C,
    OPKERNEL, OP1, OP2,
    TA, TB, TC, TV>
    (
      transA, transB,
      m, n, k,
      Aarray[ b ], lda,
      Barray[ b ], ldb,
      Carray[ b ], ldc,
      b,
      opkernel, op1, op2, initV
    );
  }
}; // end gkmm


/**
 *  @brief batched interface with strides
 *
 *  TODO: the problem is how to manage threads here. Do we want to use
 *  nested OpenMP parallelism, or is there a better way to deal with this?
 *
 */
template<
  int MC,
  int NC,
  int KC, int MR, int NR,
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_STRASSEN, bool REUSE_C,
  typename OPKERNEL, typename OP1, typename OP2,
  typename TA, typename TB, typename TC, typename TV>
void gkmm
(
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k,
  TA *Aarray, int lda, int loa,
  TB *Barray, int ldb, int lob,
  TC *Carray, int ldc, int loc,
  int batchSize,
  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
)
{
  #pragma omp parallel for
  for ( auto b = 0; b < batchSize; b ++ )
  {
    gkmm
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    USE_STRASSEN, REUSE_C,
    OPKERNEL, OP1, OP2,
    TA, TB, TC, TV>
    (
      transA, transB,
      m, n, k,
      Aarray + b * loa, lda,
      Barray + b * lob, ldb,
      Carray + b * loc, ldc,
      b,
      opkernel, op1, op2, initV
    );
  }
}; // end gkmm




/**
 *  @brief Implement GKRM with the GKMX template. Notice that OPREDUCE is
 *         handled inside the fused kernel. The updating micro-kernel has to
 *         perform atomic updates if jc_nt or jr_nt is not 1.
 *
 */
template<
  int MC            = 104,
  int NC            = 1024,
  int KC            = 256,
  int MR            = 8,
  int NR            = 4,
  int PACK_MC       = 104,
  int PACK_NC       = 1024,
  int PACK_MR       = 8,
  int PACK_NR       = 4,
  int ALIGN_SIZE    = 32,
  bool USE_STRASSEN = false,
  typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
  typename TA, typename TB, typename TC, typename TV = TC>
void gkrm
(
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k,
  TA *A, int lda,
  TB *B, int ldb,
  TC *C, int ldc,
  int batchId,
  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV,
  OPREDUCE opreduce, TC initC
)
{
  semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV> semiringkernel;
  gkrm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, OPREDUCE, TA, TB, TC, TV> gkrmkernel;

  semiringkernel.op1 = op1;
  semiringkernel.op2 = op2;
  semiringkernel.initV = initV;

  gkrmkernel.op1 = op1;
  gkrmkernel.op2 = op2;
  gkrmkernel.opkernel = opkernel;
  gkrmkernel.initV = initV;
  gkrmkernel.opreduce = opreduce;
  gkrmkernel.initC = initC;

  gkmx
  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
  USE_STRASSEN, /* REUSE_C = */ false,
  semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV>,
  gkrm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, OPREDUCE, TA, TB, TC, TV>,
  TA, TB, TC, TV>
  (
    transA, transB,
    m, n, k,
    A, lda,
    B, ldb,
    C, 0, // TODO: is there a better way to do this?
    batchId,
    semiringkernel, gkrmkernel
  );
}; // end gkrm
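
/**
 *  In the notation of the reference implementation gkrm_ref below, gkrm
 *  computes, for every row i,
 *
 *    C( i ) = opreduce_{j = 0..n-1} opkernel( op1_{p = 0..k-1} op2( A( i, p ), B( p, j ) ) ),
 *
 *  where the op1 reduction starts from initV and the opreduce reduction
 *  starts from initC, i.e. each row of the opkernel output is immediately
 *  reduced to a single value.
 */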




/**
 *  @brief A simple triple-loop reference implementation of gkmm.
 */
template<
  typename OPKERNEL, typename OP1, typename OP2,
  typename TA, typename TB, typename TC, typename TV = TC>
void gkmm_ref
(
 hmlpOperation_t transA, hmlpOperation_t transB,
 int m, int n, int k,
 TA *A, int lda,
 TB *B, int ldb,
 TC *C, int ldc,
 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
)
{
  for ( int i = 0; i < m; i ++ )
  {
    for ( int j = 0; j < n; j ++ )
    {
      auto v = initV;
      for ( int p = 0; p < k; p ++ )
      {
        TA a;
        TB b;
        if ( transA == HMLP_OP_N ) a = A[ p * lda + i ];
        else                       a = A[ i * lda + p ];
        if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ];
        else                       b = B[ p * ldb + j ];
        v = op1( v, op2( a, b ) );
      }
      C[ j * ldc + i ] = opkernel( v );
    }
  }
}; // end gkmm_ref
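
/**
 *  A minimal usage sketch of gkmm_ref: with op1 = +, op2 = * and an identity
 *  opkernel it computes a plain column-major GEMM, C = A * B. The functor and
 *  function names below are placeholders introduced only for this example.
 */
struct example_plus       { inline double operator()( double x, double y ) const { return x + y; } };
struct example_multiplies { inline double operator()( double a, double b ) const { return a * b; } };
struct example_identity   { inline double operator()( double v )           const { return v;     } };

inline void gkmm_ref_gemm_example( int m, int n, int k,
    double *A, double *B, double *C )
{
  /** A is m x k (lda = m), B is k x n (ldb = k), C is m x n (ldc = m). */
  gkmm_ref
  (
    HMLP_OP_N, HMLP_OP_N,
    m, n, k,
    A, m,
    B, k,
    C, m,
    example_identity(), example_plus(), example_multiplies(), 0.0
  );
}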


/**
 *  @brief A simple triple-loop reference implementation of gkrm.
 *
 *  TODO: ldc is strange here; C is assumed to be a vector of length m.
 */
template<
  typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
  typename TA, typename TB, typename TC, typename TV = TC>
void gkrm_ref
(
 hmlpOperation_t transA, hmlpOperation_t transB,
 int m, int n, int k,
 TA *A, int lda,
 TB *B, int ldb,
 TC *C, int ldc,
 int batchId,
 OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV,
 OPREDUCE opreduce, TC initC
 )
{
  for ( int i = 0; i < m; i ++ )
  {
    auto c = initC;
    for ( int j = 0; j < n; j ++ )
    {
      auto v = initV;
      for ( int p = 0; p < k; p ++ )
      {
        TA a;
        TB b;
        if ( transA == HMLP_OP_N ) a = A[ p * lda + i ];
        else                       a = A[ i * lda + p ];
        if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ];
        else                       b = B[ p * ldb + j ];
        v = op1( v, op2( a, b ) );
      }
      c = opreduce( c, opkernel( v ) );
    }
    C[ i ] = c;
  }
}; // end gkrm_ref
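
/**
 *  A usage sketch of gkrm_ref: for each row i of A, compute the smallest
 *  squared Euclidean distance to any column of B. Here op2 produces squared
 *  coordinate differences, op1 = +, opkernel is the identity, and opreduce
 *  takes the minimum over j. The functors example_plus and example_identity
 *  come from the sketch after gkmm_ref above; all names are placeholders
 *  introduced only for this example.
 */
struct example_sqdiff { inline double operator()( double a, double b ) const { double d = a - b; return d * d; } };
struct example_min    { inline double operator()( double c, double x ) const { return x < c ? x : c; } };

inline void gkrm_ref_nearest_example( int m, int n, int k,
    double *A, double *B, double *C )
{
  /** A is m x k (lda = m), B is k x n (ldb = k), C is a vector of length m. */
  gkrm_ref
  (
    HMLP_OP_N, HMLP_OP_N,
    m, n, k,
    A, m,
    B, k,
    C, 0,             // ldc is unused by the reference implementation
    0,                // batchId
    example_identity(), example_plus(), example_sqdiff(), 0.0,
    example_min(), 1e300 /** effectively +infinity for double */
  );
}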


}; // end namespace gkmx
}; // end namespace hmlp

#endif // define GKMX_HPP