GCC Code Coverage Report
Directory: .
File:     frame/primitives/gnbx.hpp
Date:     2019-01-14
Coverage: Lines    0 / 154 (0.0 %)
          Branches 0 / 186 (0.0 %)

Source:
/**
 *  HMLP (High-Performance Machine Learning Primitives)
 *
 *  Copyright (C) 2014-2017, The University of Texas at Austin
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see the LICENSE file.
 *
 **/



#ifndef GNBX_HPP
#define GNBX_HPP

#include <assert.h>
#include <typeinfo>
#include <algorithm>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>

/** for USE_STRASSEN */
//#include <primitives/strassen.hpp>

/** reference microkernels */
#include <packing.hpp>
#include <semiring_mrxnr.hpp>
#include <fused_mrxnr.hpp>

using namespace std;


namespace hmlp
{
namespace gnbx
{

/**
 *  @brief Macro kernel contains the 3rd and 2nd loops. Depending on the
 *         configuration of the communicator, the 3rd loop may be parallelized.
 *         b_next is the prefetch pointer.
 */
template<int KC, typename SEMIRINGKERNEL, typename TA, typename TB, typename TV>
void rank_k_macro_kernel
(
  Worker &Comm4th,
  int ic, int jc, int pc,
  int  m, int  n, int  k,
  TA *packA,
  TB *packB,
  TV *V, int rs_v, int cs_v,
  SEMIRINGKERNEL semiringkernel
)
{
  /** Get all block sizes */
  const static int MR         = SEMIRINGKERNEL::mr;
  const static int NR         = SEMIRINGKERNEL::nr;
  const static int PACK_MR    = SEMIRINGKERNEL::pack_mr;
  const static int PACK_NR    = SEMIRINGKERNEL::pack_nr;

  /** Get ic loop communicator */
  thread_communicator &ic_comm = *Comm4th.comm;

  /** Compute loop ranges for each thread */
  auto Loop3rd = Comm4th.DistributeOver1DGangs(        0, n,      NR );
  auto Pack3rd = Comm4th.DistributeOver1DGangs(        0, n, PACK_NR );
  auto Loop2nd = Comm4th.DistributeOver1DThreads(      0, m,      MR );
  auto Pack2nd = Comm4th.DistributeOver1DThreads(      0, m, PACK_MR );
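  /**
   *  [Note] Each Distribute* call above appears to return a (begin, end, step)
   *  tuple that the loops below consume through get<0>/get<1>/get<2>; gangs
   *  split the jr dimension while threads within a gang split ir.
   */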

  /** Loop 3rd (jr loop) */
  for ( int j  = get<0>( Loop3rd ), jp  = get<0>( Pack3rd );
            j  < get<1>( Loop3rd );
            j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) )
  {
    struct aux_s<TA, TB, TV, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;
    aux.do_packC = 0;
    aux.jb       = std::min( n - j, NR );

    /** Loop 2nd (ir loop) */
    for ( int i  = get<0>( Loop2nd ), ip  = get<0>( Pack2nd );
              i  < get<1>( Loop2nd );
              i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) )
    {
      aux.ib = std::min( m - i, MR );
      if ( i + MR >= m )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        semiringkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          &V[ i * rs_v + j * cs_v ], rs_v, cs_v,
          &aux
        );
      }
      else                                                 // corner case
      {
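        /**
         *  [Note] Corner case: compute into a full MR-by-NR scratch tile and
         *  copy only the valid ib-by-jb entries back to V afterwards.
         */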
        TV vtmp[ MR * NR ];

        if ( pc ) // initialize vtmp from the partial results already in V
        {
          for ( auto jj = 0; jj < aux.jb; jj ++ )
            for ( auto ii = 0; ii < aux.ib; ii ++ )
              vtmp[ jj * MR + ii ] =
                V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
        }

        semiringkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          vtmp, 1, MR,
          &aux
        );

        for ( auto jj = 0; jj < aux.jb; jj ++ )
          for ( auto ii = 0; ii < aux.ib; ii ++ )
            V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ] = vtmp[ jj * MR + ii ];
      }
    }                                                      // end 2nd loop
  }                                                        // end 3rd loop
};                                                         // end rank_k_macro_kernel




/**
 *  @brief fused_macro_kernel contains the 3rd, 2nd loops and the fused micro
 *         kernel. Notice that here C has type TC, which is different from the
 *         one in rank_k_macro_kernel. The scratch tile vtmp used in the corner
 *         case has type TV.
 */
template<int KC, typename FUSEDKERNEL, typename TA, typename TB, typename TC, typename TV>
void fused_macro_kernel
(
  Worker &Comm4th,
  int m, int n,
  int ic, int jc, int pc,
  int mc, int nc, int kc,
  TA *packA,
  TB *packB,
  TC *C,
  TV *V, int rs_v, int cs_v,
  int batchId,
  FUSEDKERNEL fusedkernel
)
{
  /** Get all block sizes */
  const static int MR         = FUSEDKERNEL::mr;
  const static int NR         = FUSEDKERNEL::nr;
  const static int PACK_MR    = FUSEDKERNEL::pack_mr;
  const static int PACK_NR    = FUSEDKERNEL::pack_nr;

  /** Get ic loop communicator */
  thread_communicator &ic_comm = *Comm4th.comm;

  /** Compute loop ranges for each thread */
  auto Loop3rd = Comm4th.DistributeOver1DGangs(        0, nc,      NR );
  auto Pack3rd = Comm4th.DistributeOver1DGangs(        0, nc, PACK_NR );
  auto Loop2nd = Comm4th.DistributeOver1DThreads(      0, mc,      MR );
  auto Pack2nd = Comm4th.DistributeOver1DThreads(      0, mc, PACK_MR );

  /** Loop 3rd (jr loop) */
  for ( int j  = get<0>( Loop3rd ), jp  = get<0>( Pack3rd );
            j  < get<1>( Loop3rd );
            j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) )
  {
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;
    aux.do_packC = 0;

    /** Loop 2nd (ir loop) */
    for ( int i  = get<0>( Loop2nd ), ip  = get<0>( Pack2nd );
              i  < get<1>( Loop2nd );
              i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) )
    {
      /**
       *  This auxiliary info is used to access data in the closures of
       *  opkernel and opreduce.
       */
      aux.m = m;
      aux.n = n;
      aux.i = ic + i;
      aux.j = jc + j;
      aux.b = batchId;

      /**
       *  Encapsulate edge case information.
       */
      aux.ib = std::min( mc - i, MR );
      aux.jb = std::min( nc - j, NR );

      /**
       *  Prepare the intermediate semiring rank-k update.
       */
      aux.V = V + i * rs_v + j * cs_v;
      aux.ldv = cs_v;
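      /**
       *  [Note] aux.V and aux.ldv expose the current rank-k partial sums to
       *  the fused kernel; in the corner case below they are redirected to
       *  the vtmp scratch tile instead.
       */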

      if ( i + MR >= mc )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * kc;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        fusedkernel
        (
          kc,
          &packA[ ip * kc ],
          &packB[ jp * kc ],
          C,
          &V[ i * rs_v + j * cs_v ], rs_v, cs_v,
          &aux
        );
      }
      else
      {
        TV vtmp[ MR * NR ];
        if ( pc ) // initialize vtmp from the partial results already in V
        {
          for ( auto jj = 0; jj < aux.jb; jj ++ )
            for ( auto ii = 0; ii < aux.ib; ii ++ )
              vtmp[ jj * MR + ii ] =
                V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ];
          aux.V = vtmp;
          aux.ldv = MR;
        }
        fusedkernel
        (
          kc,
          &packA[ ip * kc ],
          &packB[ jp * kc ],
          C,
          vtmp, 1, MR,
          &aux
        );
      }
    }
  }

}; /** end fused_macro_kernel() */




/**
 *  @brief This function contains the loop body of the 6th to 4th loops,
 *         including all packing and unpacking routines. Notice that this
 *         function is executed by all threads in the root communicator.
 *         To access each thread in different levels of communicators, use
 *         their ids.
 */
template<
  int MC, int NC, int KC,
  typename TPACKA, typename TPACKB, typename TV,
  typename     TA, typename     TB, typename TC,
  typename SEMIRINGKERNEL, typename MICROKERNEL>
void gnbx_internal
(
  Worker &thread,
  int batchId, int m, int n, int k, int k_stra,
  TA& A,
  TB& B,
  TC& C,
  TV* V, int rs_v, int cs_v,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  /** Get all block sizes */
  const static int MR         = SEMIRINGKERNEL::mr;
  const static int NR         = SEMIRINGKERNEL::nr;
  const static int PACK_MR    = SEMIRINGKERNEL::pack_mr;
  const static int PACK_NR    = SEMIRINGKERNEL::pack_nr;
  const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
  const static int PACK_MC    = ( MC / MR ) * PACK_MR;
  const static int PACK_NC    = ( NC / NR ) * PACK_NR;

  /** Create subcommunicators for each loop */
  auto CommGLB = thread.Split();
  auto Comm6th = CommGLB.Split();
  auto Comm5th = Comm6th.Split();
  auto Comm4th = Comm5th.Split();
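  /**
   *  [Note] The nested splits appear to mirror the thread_communicator built
   *  in gnbx() with ( jc_nt, pc_nt, ic_nt, jr_nt ): CommGLB gangs share the
   *  jc loop, Comm6th the pc loop, Comm5th the ic loop, and Comm4th the
   *  jr/ir loops inside the macro kernels.
   */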


  /** Adjust nc and pack_nc if the 6th loop is parallelized */
  int nc = CommGLB.BalanceOver1DGangs( n, NC, NR );
  int pack_nc = ( nc / NR ) * PACK_NR;



  //printf( "CommGLB %s tid %d gid %d ngangs %d\n", CommGLB.comm->name.data(), CommGLB.tid, CommGLB.gid, CommGLB.comm->GetNumGroups() );
  //printf( "Comm6th %s tid %d gid %d ngangs %d\n", Comm6th.comm->name.data(), Comm6th.tid, Comm6th.gid, Comm6th.comm->GetNumGroups() );
  //printf( "Comm5th %s tid %d gid %d ngangs %d\n", Comm5th.comm->name.data(), Comm5th.tid, Comm5th.gid, Comm5th.comm->GetNumGroups() );
  //printf( "Comm4th %s tid %d gid %d ngangs %d\n", Comm4th.comm->name.data(), Comm4th.tid, Comm4th.gid, Comm4th.comm->GetNumGroups() );
  //fflush( stdout );

  /**
   *  Allocate packing buffers:
   *
   *  packA is shared over Comm4th
   *  packB is shared over Comm5th
   */
  auto *packA = Comm4th.AllocateSharedMemory<ALIGN_SIZE, TPACKA>( KC * ( PACK_MC + 1 ) );
  auto *packB = Comm5th.AllocateSharedMemory<ALIGN_SIZE, TPACKB>( KC * ( pack_nc + 1 ) );

  /** Compute loop ranges for each thread */
  auto Loop6th = CommGLB.DistributeOver1DGangs(      0, n, nc );
  auto Loop5th = Comm6th.DistributeOver1DGangs( k_stra, k, KC );
  auto Loop4th = Comm5th.DistributeOver1DGangs(      0, m, MC );
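  /**
   *  [Note] Blocking scheme of the outer loops: jc partitions n into panels
   *  of at most nc columns, pc partitions k (starting at k_stra) into
   *  KC-deep slabs, and ic partitions m into MC-tall panels.
   */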

  /** Comm6th is used inside the 6th loop (i.e. jc loop) */
  for ( int jc  = get<0>( Loop6th );
            jc  < get<1>( Loop6th );
            jc += get<2>( Loop6th ) )
  {
    auto jb = std::min( n - jc, nc );


    /** Comm5th is used inside the 5th loop (i.e. pc loop) */
    for ( int pc  = get<0>( Loop5th );
              pc  < get<1>( Loop5th );
              pc += get<2>( Loop5th ) )
    {
      auto pb = std::min( k - pc, KC );
      auto is_the_last_pc_iteration = ( pc + KC >= k );
      auto LooppkB = Comm5th.DistributeOver1DThreads( 0, jb,      NR );
      auto PackpkB = Comm5th.DistributeOver1DThreads( 0, jb, PACK_NR );

      for ( int j  = get<0>( LooppkB ), jp  = get<0>( PackpkB );
                j  < get<1>( LooppkB );
                j += get<2>( LooppkB ), jp += get<2>( PackpkB ) )
      {
        /** packB and typecast from TB to TPACKB  */
        B.Pack(
            k, pc, pb,
            n, jc + j, std::min( jb - j, NR ),
            &packB[ jp * pb ] );
      }
      Comm5th.Barrier();
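      /**
       *  [Note] Every thread in Comm5th packs a slice of the B panel above;
       *  the barrier ensures packB is complete before any gang starts the
       *  macro kernels that read it.
       */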


      /** Comm4th is used inside the 4th loop (i.e. ic loop) */
      for ( int ic  = get<0>( Loop4th );
                ic  < get<1>( Loop4th );
                ic += get<2>( Loop4th ) )
      {
        auto &ic_comm = *thread.ic_comm;
        auto ib = std::min( m - ic, MC );
        auto LooppkA = Comm4th.DistributeOver1DThreads( 0, ib, MR );
        auto PackpkA = Comm4th.DistributeOver1DThreads( 0, ib, PACK_MR );

        for ( int i  = get<0>( LooppkA ), ip  = get<0>( PackpkA );
                  i  < get<1>( LooppkA );
                  i += get<2>( LooppkA ), ip += get<2>( PackpkA ) )
        {
          /** packA and typecast from TA to TPACKA  */
          A.Pack(
              m, ic + i, std::min( ib - i, MR ),
              k, pc, pb,
              &packA[ ip * pb ] );
        }
        Comm4th.Barrier();
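        /**
         *  [Note] Same pattern as packB: Comm4th threads jointly pack the A
         *  panel, then synchronize before entering the macro kernel.
         */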

        if ( is_the_last_pc_iteration )                    // fused_macro_kernel
        {
          fused_macro_kernel<KC>
          (
            Comm4th,
            m, n,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            &C,
            V + ic * rs_v + jc * cs_v, rs_v, cs_v,
            batchId,
            microkernel
          );

        }
        else                                               // semiring rank-k update
        {
          rank_k_macro_kernel<KC>
          (
            Comm4th,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            V + ic * rs_v + jc * cs_v, rs_v, cs_v,
            semiringkernel
          );
        }
        Comm4th.Barrier();
      }                                                    // end 4th loop
      Comm5th.Barrier();
    }                                                      // end 5th loop
    Comm6th.Barrier();
  }                                                        // end 6th loop
  CommGLB.Barrier();

  /** Free packing buffer */
  Comm4th.FreeSharedMemory( packA );
  Comm5th.FreeSharedMemory( packB );

}; /** end gnbx_internal() */




/**
 *  @brief This is the main routine of gkmx. All packing buffers are
 *         managed here. The communicator and the parallel section
 *         start here.
 *
 */
template<
  int MC, int NC, int KC,
  typename TPACKA, typename TPACKB, typename TV,
  typename     TA, typename     TB, typename TC,
  typename SEMIRINGKERNEL, typename MICROKERNEL>
void gnbx
(
  int batchId, int m, int n, int k,
  TA& A,
  TB& B,
  TC& C,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  const static int MR         = SEMIRINGKERNEL::mr;
  const static int NR         = SEMIRINGKERNEL::nr;
  const static int PACK_MR    = SEMIRINGKERNEL::pack_mr;
  const static int PACK_NR    = SEMIRINGKERNEL::pack_nr;
  const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size;
  const static int PACK_MC    = ( MC / MR ) * PACK_MR;
  const static int PACK_NC    = ( NC / NR ) * PACK_NR;
  const static bool USE_STRASSEN = false;

  /** Early return if possible */
  if ( m == 0 || n == 0 || k == 0 ) return;


  TV *V = NULL;
  int rs_v = 0;
  int cs_v = 0;


  if ( k > KC && !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value )
  {
    //printf( "here m %d n %d\n", m, n );
    V = hmlp_malloc<ALIGN_SIZE, TV>( m * n );
    rs_v = 1;
    cs_v = m;
  }
  else
  {
    /** Directly use C for intermediate semiring rank-k update */
    V = reinterpret_cast<TV*>( C.X );
    rs_v = C.rs;
    cs_v = C.cs;
  }
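  /**
   *  [Note] A separate V buffer is only needed when the pc loop runs more
   *  than once ( k > KC ) and C cannot hold TV partial sums directly;
   *  otherwise the rank-k updates accumulate straight into C.
   */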


  int k_stra = 0;
  if ( USE_STRASSEN )
  {
    assert( typeid(TPACKA) == typeid(TPACKB) );
    assert( typeid(TC) == typeid(TV) );
    k_stra = k - k % KC;

    if ( k_stra == k ) k_stra -= KC;
  }

  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
  if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
  {
    /** Check the environment variable. */
    jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
    ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
    jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
  }
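  /**
   *  [Note] When called outside an existing parallel region, the jc/ic/jr
   *  parallelism is read from the KS_JC_NT, KS_IC_NT, and KS_JR_NT
   *  environment variables; pc_nt stays at 1.
   */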

  /** allocate tree communicator */
  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );

  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
  {
    Worker thread( &my_comm );

    /** TODO:  */
    thread.InitWithCommunicator( &my_comm, omp_get_thread_num(), 0 );

    //if ( USE_STRASSEN )
    //{
    //  strassen::strassen_internal
    //  <MC, NC, KC, MR, NR,
    //  PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    //  USE_STRASSEN,
    //  SEMIRINGKERNEL, SEMIRINGKERNEL,
    //  TA, TPACKA, TB, TPACKB, TC, TV>
    //  (
    //    thread,
    //    m, n, k_stra,
    //    A, packakernel,
    //    B, packbkernel,
    //    V, ldv,
    //    semiringkernel, semiringkernel,
    //    nc, pack_nc,
    //    packA_buff,
    //    packB_buff
    //  );
    //}

    gnbx_internal<MC, NC, KC, TPACKA, TPACKB>
    (
      thread,
      batchId, m, n, k, k_stra,
      A,
      B,
      C,
      V, rs_v, cs_v,
      semiringkernel, microkernel
    );
  }                                                        // end omp parallel

  if ( k > KC && !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value )
  {
    hmlp_free( V );
  }
};                                                         // end gnbx()




/**
 *  @brief Simplified gnbx interface: builds the semiring micro-kernel from
 *         op1 / op2 / initV and the fused micro-kernel from opkernel, then
 *         calls the main gnbx routine above.
 */
template<
  int MR, int NR, int MC, int NC, int KC,
  typename TPACKA, typename TPACKB, typename TPACKC, typename TV,
  typename     TA, typename     TB, typename     TC,
  typename OPKERNEL, typename OP1, typename OP2>
void gnbx
(
  int batchId, int m, int n, int k,
  TA& A,
  TB& B,
  TC& C,
  OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV
)
{
  semiring_mrxnr<MR, NR, OP1, OP2, TPACKA, TPACKB, TV, TV> semiringkernel;
  gnbx_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TPACKA, TPACKB, TC, TPACKC, TV> gkrmkernel;

  semiringkernel.op1 = op1;
  semiringkernel.op2 = op2;
  semiringkernel.initV = initV;

  gkrmkernel.op1 = op1;
  gkrmkernel.op2 = op2;
  gkrmkernel.opkernel = opkernel;
  gkrmkernel.initV = initV;
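  /**
   *  [Note] semiringkernel performs the op1/op2 rank-k updates for the
   *  intermediate pc iterations, while gkrmkernel is the fused micro-kernel
   *  that presumably also applies opkernel on the final pc iteration.
   */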

  gnbx<MC, NC, KC, TPACKA, TPACKB, TV>
    ( batchId, m, n, k, A, B, C, semiringkernel, gkrmkernel );

}; /** end gnbx() */

}; /** end namespace gnbx */
}; /** end namespace hmlp */

#endif /** define GNBX_HPP */