GCC Code Coverage Report
Directory: . Exec Total Coverage
File: frame/primitives/gsknn.hpp Lines: 0 31 0.0 %
Date: 2019-01-14 Branches: 0 50 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
23
#ifndef GSKNN_HXX
24
#define GSKNN_HXX
25
26
#include <math.h>
27
#include <vector>
28
29
#include <hmlp.h>
30
#include <hmlp_internal.hpp>
31
#include <hmlp_base.hpp>
32
33
/** for USE_STRASSEN */
34
#include <primitives/strassen.hpp>
35
36
namespace hmlp
37
{
38
namespace gsknn
39
{
40
41
/**
 *  Function-like minimum used by the kernels in this header.
 *
 *  NOTE: being a macro, the smaller argument is evaluated twice (once in the
 *  comparison and once as the result), so arguments must be free of side
 *  effects. The macro is not #undef'd anywhere in this header, so it stays
 *  defined for code that includes this header.
 */
#define min( i, j ) ( (i)<(j) ? (i): (j) )
42
43
/**
44
 *
45
 */
46
template<
47
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
48
  typename SEMIRINGKERNEL,
49
  typename TA, typename TB, typename TC, typename TV>
50
void rank_k_macro_kernel
51
(
52
  Worker &thread,
53
  int ic, int jc, int pc,
54
  int  m, int n,  int  k,
55
  TA *packA,
56
  TB *packB,
57
  TC *packC, int ldc,
58
  SEMIRINGKERNEL semiringkernel
59
)
60
{
61
  thread_communicator &ic_comm = *thread.ic_comm;
62
63
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
64
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
65
  auto loop2nd = GetRange( 0, m,      MR );
66
  auto pack2nd = GetRange( 0, m, PACK_MR );
67
68
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
69
            j   < loop3rd.end();
70
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
71
  {
72
    struct aux_s<TA, TB, TC, TV> aux;
73
    aux.pc       = pc;
74
    aux.b_next   = packB;
75
    aux.do_packC = 0;
76
    aux.jb       = min( n - j, NR );
77
78
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
79
              i  < loop2nd.end();
80
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
81
    {
82
      aux.ib = min( m - i, MR );
83
      if ( i + MR >= m )
84
      {
85
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
86
      }
87
88
      if ( aux.jb == NR && aux.ib == MR )
89
      {
90
        semiringkernel
91
        (
92
          k,
93
          &packA[ ip * k ],
94
          &packB[ jp * k ],
95
          &packC[ j * ldc + i ], 1, ldc,
96
          &aux
97
        );
98
      }
99
      else
100
      {
101
        double c[ MR * NR ] __attribute__((aligned(32)));
102
        double *cbuff = c;
103
        if ( pc ) {
104
          for ( auto jj = 0; jj < aux.jb; jj ++ )
105
            for ( auto ii = 0; ii < aux.ib; ii ++ )
106
              cbuff[ jj * MR + ii ] = packC[ ( j + jj ) * ldc + i + ii ];
107
        }
108
        semiringkernel
109
        (
110
          k,
111
          &packA[ ip * k ],
112
          &packB[ jp * k ],
113
          cbuff, 1, MR,
114
          &aux
115
        );
116
        for ( auto jj = 0; jj < aux.jb; jj ++ )
117
          for ( auto ii = 0; ii < aux.ib; ii ++ )
118
            packC[ ( j + jj ) * ldc + i + ii ] = cbuff[ jj * MR + ii ];
119
      }
120
    }                                                      // end 2nd loop
121
  }                                                        // end 3rd loop
122
}                                                          // end rank_k_macro_kernel
123
124
/**
125
 *
126
 */
127
template<
128
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
129
  typename MICROKERNEL,
130
  typename TA, typename TB, typename TC, typename TV>
131
void fused_macro_kernel
132
(
133
  Worker &thread,
134
  int pc,
135
  int  m,  int n,  int k,  int r,
136
  TA *packA, TA *packA2,
137
  TB *packB, TB *packB2,
138
  int *bmap,
139
  TV *D,  int *I,  int ldr,
140
  TC *packC, int ldc,
141
  MICROKERNEL microkernel
142
)
143
{
144
  double c[ MR * NR ] __attribute__((aligned(32)));
145
  double *cbuff = c;
146
  thread_communicator &ic_comm = *thread.ic_comm;
147
148
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
149
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
150
  auto loop2nd = GetRange( 0, m,      MR );
151
  auto pack2nd = GetRange( 0, m, PACK_MR );
152
153
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
154
            j   < loop3rd.end();
155
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
156
  {
157
    struct aux_s<TA, TB, TC, TV> aux;
158
    aux.pc       = pc;
159
    aux.b_next   = packB;
160
    //aux.ldr      = ldr;
161
    aux.jb       = min( n - j, NR );
162
163
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
164
              i  < loop2nd.end();
165
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
166
    {
167
      aux.ib = min( m - i, MR );
168
      //aux.I  = I + i * ldr;
169
      //aux.D  = D + i * ldr;
170
      if ( i + MR >= m )
171
      {
172
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
173
      }
174
      if ( pc ) {
175
        for ( auto jj = 0; jj < aux.jb; jj ++ )
176
          for ( auto ii = 0; ii < aux.ib; ii ++ )
177
            cbuff[ jj * MR + ii ] = packC[ ( j + jj ) * ldc + i + ii ];
178
      }
179
      microkernel
180
      (
181
        k, r,
182
        packA  + ip * k, packA2 + ip,
183
        packB  + jp * k, packB2 + jp, bmap + j,
184
        cbuff,
185
        D + i * ldr, I + i * ldr, ldr,
186
        &aux
187
      );
188
      if ( pc ) {
189
        for ( auto jj = 0; jj < aux.jb; jj ++ )
190
          for ( auto ii = 0; ii < aux.ib; ii ++ )
191
            packC[ ( j + jj ) * ldc + i + ii ] = cbuff[ jj * MR + ii ];
192
      }
193
    }                                                      // end 2nd loop
194
  }                                                        // end 3rd loop
195
}                                                          // end fused_macro_kernel
196
197
198
/**
199
 *
200
 */
201
template<
202
  int MC, int NC, int KC, int MR, int NR,
203
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
204
  bool USE_STRASSEN,
205
  typename SEMIRINGKERNEL, typename MICROKERNEL,
206
  typename TA, typename TB, typename TC, typename TV>
207
void gsknn_internal
208
(
209
  Worker &thread,
210
  int m, int n, int k, int k_stra, int r,
211
  TA *A, TA *A2, int *amap,
212
  TB *B, TB *B2, int *bmap,
213
  TV *D,         int *I,
214
  SEMIRINGKERNEL semiringkernel,
215
  MICROKERNEL microkernel,
216
  TA *packA, TA *packA2,
217
  TB *packB, TB *packB2,
218
  TC *packC, int ldpackc, int padn,
219
  int ldr
220
)
221
{
222
223
  packA  += ( thread.jc_id * thread.ic_nt                ) * PACK_MC * KC
224
          + ( thread.ic_id                               ) * PACK_MC * KC;
225
  packA2 += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
226
  packB  += ( thread.jc_id                               ) * PACK_NC * KC;
227
  packB2 += ( thread.jc_id                               ) * PACK_NC;
228
229
  auto loop6th = GetRange( 0, n, NC );
230
  auto loop5th = GetRange( k_stra, k, KC );
231
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
232
233
  for ( int jc  = loop6th.beg();
234
            jc  < loop6th.end();
235
            jc += loop6th.inc() )                          // beg 6th loop
236
  {
237
    auto jb = min( n - jc, NC );
238
239
    for ( int pc  = loop5th.beg();
240
              pc  < loop5th.end();
241
              pc += loop5th.inc() )
242
    {
243
      auto &pc_comm = *thread.pc_comm;
244
      auto pb = min( k - pc, KC );
245
      auto is_the_last_pc_iteration = ( pc + KC >= k );
246
247
      auto looppkB = GetRange( 0, jb,      NR, thread.ic_jr, pc_comm.GetNumThreads() );
248
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
249
250
      for ( int j   = looppkB.beg(), jp  = packpkB.beg();
251
                j   < looppkB.end();
252
                j  += looppkB.inc(), jp += packpkB.inc() )
253
      {
254
        pack2D<true, PACK_NR>                              // packB
255
        (
256
          min( jb - j, NR ), pb,
257
          &B[ pc ], k, &bmap[ jc + j ], &packB[ jp * pb ]
258
        );
259
260
261
        if ( is_the_last_pc_iteration )
262
        {
263
264
          pack2D<true, PACK_NR>                           // packB2
265
          (
266
            min( jb - j, NR ), 1,
267
            &B2[ 0 ], 1, &bmap[ jc + j ], &packB2[ jp * 1 ]
268
          );
269
270
271
        }
272
      }
273
      pc_comm.Barrier();
274
275
      for ( int ic  = loop4th.beg();
276
                ic  < loop4th.end();
277
                ic += loop4th.inc() )                      // beg 4th loop
278
      {
279
        auto &ic_comm = *thread.ic_comm;
280
        auto ib = min( m - ic, MC );
281
282
        auto looppkA = GetRange( 0, ib,      MR, thread.jr_id, 1 );
283
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, 1 );
284
285
        for ( int i   = looppkA.beg(), ip  = packpkA.beg();
286
                  i   < looppkA.end();
287
                  i  += looppkA.inc(), ip += packpkA.inc() )
288
        {
289
          pack2D<true, PACK_MR>                            // packA
290
          (
291
            min( ib - i, MR ), pb,
292
            &A[ pc ], k, &amap[ ic + i ], &packA[ ip * pb ]
293
          );
294
295
          if ( is_the_last_pc_iteration )
296
          {
297
            pack2D<true, PACK_MR>                        // packA2
298
            (
299
              min( ib - i, MR ), 1,
300
              &A2[ 0 ], 1, &amap[ ic + i ], &packA2[ ip * 1 ]
301
            );
302
303
          }
304
        }
305
306
307
        ic_comm.Barrier();
308
        if ( pc + KC  < k )
309
        {
310
          rank_k_macro_kernel
311
          <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
312
          (
313
            thread,
314
            ic, jc, pc,
315
            ib, jb, pb,
316
            packA,
317
            packB,
318
            packC + jc * ldpackc + ic,
319
            ldpackc,
320
            semiringkernel
321
          );
322
        }
323
        else
324
        {
325
          fused_macro_kernel
326
          <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
327
          (
328
            thread,
329
            pc,
330
            ib, jb, pb, r,
331
            packA, packA2,
332
            packB, packB2, bmap + jc,
333
            D + ic * ldr,  I + ic * ldr,  ldr,
334
            packC + jc * ldpackc + ic,
335
            ldpackc,
336
            microkernel
337
          );
338
        }
339
340
        ic_comm.Barrier();                                 // sync all jr_id!!
341
342
      }                                                    // end 4th loop
343
      pc_comm.Barrier();
344
    }                                                      // end 5th loop
345
  }                                                        // end 6th loop
346
}                                                          // end gsknn_internal
347
348
349
350
351
352
/**
353
 *
354
 */
355
template<
356
  int MC, int NC, int KC, int MR, int NR,
357
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
358
  bool USE_STRASSEN,
359
  typename SEMIRINGKERNEL, typename MICROKERNEL,
360
  typename TA, typename TB, typename TC, typename TV>
361
void gsknn(
362
    int m, int n, int k, int r,
363
    TA *A, TA *A2, int *amap,
364
    TB *B, TB *B2, int *bmap,
365
    TV *D,         int *I,
366
    SEMIRINGKERNEL semiringkernel,
367
    MICROKERNEL microkernel
368
    )
369
{
370
  int ic_nt = 1;
371
  int k_stra = 0;
372
  int ldpackc = 0, padn = 0;
373
  int ldr = 0;
374
  char *str;
375
376
  TA *packA_buff = NULL, *packA2_buff = NULL;
377
  TB *packB_buff = NULL, *packB2_buff = NULL;
378
  TC *packC_buff = NULL;
379
380
  // Early return if possible
381
  if ( m == 0 || n == 0 || k == 0 ) return;
382
383
  // Check the environment variable.
384
  str = getenv( "KS_IC_NT" );
385
  if ( str ) ic_nt = (int)strtol( str, NULL, 10 );
386
387
  ldpackc = m;
388
  ldr = r;
389
390
  // allocate packing memory
391
  packA_buff  = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * ic_nt,         sizeof(TA) );
392
  packB_buff  = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( PACK_NC + 1 ),                 sizeof(TB) );
393
  packA2_buff = hmlp_malloc<ALIGN_SIZE, TA>(  1, ( PACK_MC + 1 ) * ic_nt,         sizeof(TA) );
394
  packB2_buff = hmlp_malloc<ALIGN_SIZE, TB>(  1, ( PACK_NC + 1 ),                 sizeof(TB) );
395
  if ( k > KC ) {
396
    packC_buff = hmlp_malloc<ALIGN_SIZE, TC>(  m, n, sizeof(TC) );
397
  }
398
399
  // allocate tree communicator
400
  thread_communicator my_comm( 1, 1, ic_nt, 1 );
401
402
  if ( USE_STRASSEN )
403
  {
404
    k_stra = k - k % KC;
405
406
    if ( k_stra == k ) k_stra -= KC;
407
408
    if ( k_stra )
409
    {
410
      #pragma omp parallel for
411
      for ( int i = 0; i < m * n; i ++ ) packC_buff[ i ] = 0.0;
412
    }
413
414
  }
415
416
  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
417
  {
418
    Worker thread( &my_comm );
419
420
    if ( USE_STRASSEN && k > KC )
421
    {
422
      strassen::strassen_internal
423
      <MC, NC, KC, MR, NR,
424
      PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
425
      USE_STRASSEN,
426
      SEMIRINGKERNEL, SEMIRINGKERNEL,
427
      TA, TB, TC, TV>
428
      (
429
        thread,
430
        HMLP_OP_T, HMLP_OP_N,
431
        m, n, k_stra,
432
        A, k, amap,
433
        B, k, bmap,
434
        packC_buff, ldpackc,
435
        semiringkernel, semiringkernel,
436
        NC, PACK_NC,
437
        packA_buff,
438
        packB_buff
439
      );
440
    }
441
442
    gsknn_internal
443
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
444
    USE_STRASSEN,
445
    SEMIRINGKERNEL, MICROKERNEL,
446
    TA, TB, TC, TB>
447
    (
448
      thread,
449
      m, n, k, k_stra, r,
450
      A, A2, amap,
451
      B, B2, bmap,
452
      D,     I,
453
      semiringkernel, microkernel,
454
      packA_buff, packA2_buff,
455
      packB_buff, packB2_buff,
456
      packC_buff, ldpackc, padn,
457
      ldr
458
    );
459
460
  }                                                        // end omp region
461
462
  hmlp_free( packA_buff );
463
  hmlp_free( packB_buff );
464
  hmlp_free( packA2_buff );
465
  hmlp_free( packB2_buff );
466
  hmlp_free( packC_buff );
467
}                                                          // end gsknn
468
469
470
/**
471
 *
472
 */
473
template<typename T>
474
void gsknn_ref
475
(
476
  int m, int n, int k, int r,
477
  T *A, T *A2, int *amap,
478
  T *B, T *B2, int *bmap,
479
  T *D,        int *I
480
)
481
{
482
  int    i, j, p;
483
  double beg, time_collect, time_dgemm, time_square, time_heap;
484
  std::vector<T> packA, packB, C;
485
  double fneg2 = -2.0, fzero = 0.0, fone = 1.0;
486
487
  // Early return if possible
488
  if ( m == 0 || n == 0 || k == 0 ) return;
489
490
  packA.resize( k * m );
491
  packB.resize( k * n );
492
  C.resize( m * n );
493
494
  // Collect As from A and B.
495
  beg = omp_get_wtime();
496
  #pragma omp parallel for private( p )
497
  for ( i = 0; i < m; i ++ ) {
498
    for ( p = 0; p < k; p ++ ) {
499
      packA[ i * k + p ] = A[ amap[ i ] * k + p ];
500
    }
501
  }
502
  #pragma omp parallel for private( p )
503
  for ( j = 0; j < n; j ++ ) {
504
    for ( p = 0; p < k; p ++ ) {
505
      packB[ j * k + p ] = B[ bmap[ j ] * k + p ];
506
    }
507
  }
508
  time_collect = omp_get_wtime() - beg;
509
510
  // Compute the inner-product term.
511
  beg = omp_get_wtime();
512
  #ifdef USE_BLAS
513
    xgemm
514
    (
515
      "T", "N",
516
      m, n, k,
517
      fone,         packA.data(), k,
518
                    packB.data(), k,
519
      fzero,        C.data(),     m
520
    );
521
  #else
522
    #pragma omp parallel for private( i, p )
523
    for ( j = 0; j < n; j ++ ) {
524
      for ( i = 0; i < m; i ++ ) {
525
        C[ j * m + i ] = 0.0;
526
        for ( p = 0; p < k; p ++ ) {
527
          C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
528
        }
529
      }
530
    }
531
  #endif
532
  time_dgemm = omp_get_wtime() - beg;
533
534
  beg = omp_get_wtime();
535
  #pragma omp parallel for private( i )
536
  for ( j = 0; j < n; j ++ )
537
  {
538
    for ( i = 0; i < m; i ++ )
539
    {
540
      C[ j * m + i ] *= -2.0;
541
      C[ j * m + i ] += A2[ amap[ i ] ];
542
      C[ j * m + i ] += B2[ bmap[ j ] ];
543
    }
544
  }
545
  time_square = omp_get_wtime() - beg;
546
547
  // Pure C Max Heap implementation.
548
  beg = omp_get_wtime();
549
  #pragma omp parallel for schedule( dynamic )
550
  for ( j = 0; j < n; j ++ )
551
  {
552
    heap_select<T>( m, r, &C[ j * m ], amap, &D[ j * r ], &I[ j * r ] );
553
  }
554
  time_heap = omp_get_wtime() - beg;
555
556
} // end void gsknn_ref
557
558
559
}; // end namespace gsknn
560
}; // end namespace hmlp
561
562
#endif // define GSKNN_HXX