GCC Code Coverage Report
Directory: . Exec Total Coverage
File: frame/primitives/strassen.hpp Lines: 0 163 0.0 %
Date: 2019-01-14 Branches: 0 68 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
#ifndef STRASSEN_HPP
23
#define STRASSEN_HPP
24
25
/**
 *  STRAPRIM: expands to one call of the dense straprim() overload.
 *
 *  Only the operand pointers and scaling factors are macro parameters.
 *  Everything else is picked up from the scope of the expansion site and
 *  must be visible there: the blocking constants (MC, NC, KC, MR, NR,
 *  PACK_*, ALIGN_SIZE, USE_STRASSEN), the kernel types and objects
 *  (STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, stra_semiringkernel,
 *  stra_microkernel), the element types (TA, TB, TC), and the variables
 *  thread, transA, transB, md, nd, kd, lda, ldb, ldc, nc, pack_nc,
 *  packA_buff and packB_buff.
 *
 *  The expansion already ends with ';', so it forms a complete statement.
 *  The definition deliberately has no line continuation after that ';',
 *  otherwise the line following the macro would be spliced into it.
 *
 *  NOTE(review): the last template argument is TB, while the matching
 *  template parameter of straprim() is named TV -- confirm this is intended.
 */
#define STRAPRIM( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \
    straprim \
    <MC, NC, KC, MR, NR,  \
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \
    USE_STRASSEN, \
    STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \
    TA, TB, TC, TB> \
    ( \
      thread, \
      transA, transB, \
      md, nd, kd, \
      A0, A1, lda, gamma, \
      B0, B1, ldb, delta, \
      C0, C1, ldc, alpha0, alpha1, \
      stra_semiringkernel, stra_microkernel, \
      nc, pack_nc, \
      packA_buff, \
      packB_buff \
    );
44
45
/**
 *  STRAPRIM_MAP: expands to one call of the straprim() overload that takes
 *  index maps (amap, bmap) in addition to the operand pointers.
 *
 *  Only the operand pointers and scaling factors are macro parameters.
 *  Everything else is picked up from the scope of the expansion site and
 *  must be visible there: the blocking constants (MC, NC, KC, MR, NR,
 *  PACK_*, ALIGN_SIZE, USE_STRASSEN), the kernel types and objects
 *  (STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, stra_semiringkernel,
 *  stra_microkernel), the element types (TA, TB, TC), and the variables
 *  thread, transA, transB, md, nd, kd, lda, ldb, ldc, amap, bmap, nc,
 *  pack_nc, packA_buff and packB_buff.
 *
 *  The expansion already ends with ';', so it forms a complete statement.
 *  The definition deliberately has no line continuation after that ';',
 *  otherwise the line following the macro would be spliced into it.
 *
 *  NOTE(review): the last template argument is TB, while the matching
 *  template parameter of straprim() is named TV -- confirm this is intended.
 */
#define STRAPRIM_MAP( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \
    straprim \
    <MC, NC, KC, MR, NR,  \
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \
    USE_STRASSEN, \
    STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \
    TA, TB, TC, TB> \
    ( \
      thread, \
      transA, transB, \
      md, nd, kd, \
      A0, A1, lda, gamma, amap, \
      B0, B1, ldb, delta, bmap, \
      C0, C1, ldc, alpha0, alpha1, \
      stra_semiringkernel, stra_microkernel, \
      nc, pack_nc, \
      packA_buff, \
      packB_buff \
    );
64
65
#include <hmlp.h>
66
#include <hmlp_internal.hpp>
67
#include <hmlp_base.hpp>
68
69
namespace hmlp
70
{
71
namespace strassen
72
{
73
74
//#define min( i, j ) ( (i)<(j) ? (i): (j) )
75
76
/**
77
 *
78
 */
79
template<
80
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
81
  typename SEMIRINGKERNEL,
82
  typename TA, typename TB, typename TC, typename TV>
83
void rank_k_macro_kernel
84
(
85
  Worker &thread,
86
  int ic, int jc, int pc,
87
  int  m, int n,  int  k,
88
  TA *packA,
89
  TB *packB,
90
  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
91
  SEMIRINGKERNEL semiringkernel
92
)
93
{
94
  thread_communicator &ic_comm = *thread.ic_comm;
95
96
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
97
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
98
  auto loop2nd = GetRange( 0, m,      MR );
99
  auto pack2nd = GetRange( 0, m, PACK_MR );
100
101
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
102
            j   < loop3rd.end();
103
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
104
  {
105
    struct aux_s<TA, TB, TC, TV> aux;
106
    aux.pc       = pc;
107
    aux.b_next   = packB;
108
    aux.do_packC = 0;
109
    aux.jb       = std::min( n - j, NR );
110
111
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
112
              i  < loop2nd.end();
113
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
114
    {
115
      aux.ib = std::min( m - i, MR );
116
      if ( aux.ib != MR )
117
      {
118
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
119
      }
120
121
      if ( aux.jb == NR && aux.ib == MR )
122
      {
123
124
        if ( alpha1 == 0 || C1 == NULL ) {
125
          TV *c_list[1], alpha_list[1];
126
          c_list[0] = &C0[ j * ldc + i ];
127
          alpha_list[0] = alpha0;
128
129
          semiringkernel
130
          (
131
            k,
132
            &packA[ ip * k ],
133
            &packB[ jp * k ],
134
            1, c_list, ldc, alpha_list,
135
            &aux
136
          );
137
138
        } else {
139
140
          TV *c_list[2], alpha_list[2];
141
          c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ];
142
          alpha_list[0] = alpha0; alpha_list[1] = alpha1;
143
          semiringkernel
144
          (
145
            k,
146
            &packA[ ip * k ],
147
            &packB[ jp * k ],
148
            2, c_list, ldc, alpha_list,
149
            &aux
150
          );
151
152
        }
153
154
        //semiringkernel
155
        //(
156
        //  k,
157
        //  &packA[ ip * k ],
158
        //  &packB[ jp * k ],
159
        //  &C0[ j * ldc + i ], &C1[ j * ldc + i ], ldc, alpha0, alpha1,
160
        //  &aux
161
        //);
162
163
164
      }
165
      else                                                 // corner case
166
      {
167
168
        //printf( "Enter corner case!\n" );
169
        // TODO: this should be initC.
170
        TV ctmp[ MR * NR ] = { (TV)0.0 };
171
172
        TV *c_list[1], alpha_list[1];
173
        c_list[0] = ctmp;
174
        alpha_list[0] = 1;
175
176
        semiringkernel
177
        (
178
          k,
179
          &packA[ ip * k ],
180
          &packB[ jp * k ],
181
          //ctmp, MR,
182
          1, c_list, MR, alpha_list,
183
          &aux
184
        );
185
186
187
        ////rank_k_int_d8x4 rankk_semiringkernel;
188
        ////rankk_semiringkernel
189
        //semiringkernel
190
        //(
191
        //  k,
192
        //  &packA[ ip * k ],
193
        //  &packB[ jp * k ],
194
        //  //ctmp, MR,
195
        //  ctmp, NULL, MR, 1, 0,
196
        //  &aux
197
        //);
198
        //if ( pc )
199
        {
200
          for ( auto jj = 0; jj < aux.jb; jj ++ )
201
          {
202
            for ( auto ii = 0; ii < aux.ib; ii ++ )
203
            {
204
              C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ];
205
206
              if ( alpha1 != 0 && C1 != NULL ) {
207
                C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ];
208
              }
209
            }
210
          }
211
        }
212
        //else
213
        //{
214
        //  for ( auto jj = 0; jj < aux.jb; jj ++ )
215
        //  {
216
        //    for ( auto ii = 0; ii < aux.ib; ii ++ )
217
        //    {
218
        //      C0[ ( j + jj ) * ldc + i + ii ] = alpha0 * ctmp[ jj * MR + ii ];
219
220
        //      if ( alpha1 != 0 && C1 != NULL ) {
221
        //        C1[ ( j + jj ) * ldc + i + ii ] = alpha1 * ctmp[ jj * MR + ii ];
222
        //      }
223
        //    }
224
        //  }
225
        //}
226
      }
227
    }                                                      // end 2nd loop
228
  }                                                        // end 3rd loop
229
}                                                          // end rank_k_macro_kernel
230
231
/**
232
 *
233
 */
234
//template<int KC, int MR, int NR, int PACK_MR, int PACK_NR,
235
//    typename MICROKERNEL,
236
//    typename TA, typename TB, typename TC, typename TV>
237
//void fused_macro_kernel
238
//(
239
//  Worker &thread,
240
//  int ic, int jc, int pc,
241
//  int  m,  int n,  int k,
242
//  TA *packA,
243
//  TB *packB,
244
//  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
245
//  MICROKERNEL microkernel
246
//)
247
//{
248
//  thread_communicator &ic_comm = *thread.ic_comm;
249
//
250
//  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
251
//  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
252
//  auto loop2nd = GetRange( 0, m,      MR );
253
//  auto pack2nd = GetRange( 0, m, PACK_MR );
254
//
255
//  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
256
//            j   < loop3rd.end();
257
//            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
258
//  {
259
//    struct aux_s<TA, TB, TC, TV> aux;
260
//    aux.pc       = pc;
261
//    aux.b_next   = packB;
262
//    aux.do_packC = 0;
263
//    aux.jb       = std::min( n - j, NR );
264
//
265
//    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
266
//              i  < loop2nd.end();
267
//              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
268
//    {
269
//      aux.ib = std::min( m - i, MR );
270
//      if ( aux.ib != MR )
271
//      {
272
//        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
273
//      }
274
//
275
//      if ( aux.jb == NR && aux.ib == MR )
276
//      {
277
//
278
//        if ( alpha1 == 0 || C1 == NULL ) {
279
//
280
//          double *c_list[1], alpha_list[1];
281
//          c_list[0] = &C0[ j * ldc + i ];
282
//          alpha_list[0] = alpha0;
283
//
284
//          microkernel
285
//          (
286
//            k,
287
//            &packA[ ip * k ],
288
//            &packB[ jp * k ],
289
//            1, c_list, ldc, alpha_list,
290
//            &aux
291
//          );
292
//        } else {
293
//
294
//          double *c_list[2], alpha_list[2];
295
//          c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ];
296
//          alpha_list[0] = alpha0; alpha_list[1] = alpha1;
297
//
298
//          microkernel
299
//          (
300
//            k,
301
//            &packA[ ip * k ],
302
//            &packB[ jp * k ],
303
//            2, c_list, ldc, alpha_list,
304
//            &aux
305
//          );
306
//
307
//        }
308
//
309
//
310
//        //microkernel
311
//        //(
312
//        //  k,
313
//        //  &packA[ ip * k ],
314
//        //  &packB[ jp * k ],
315
//        //  &C0[ j * ldc + i ], &C1[ j * ldc + i ], ldc, alpha0, alpha1,
316
//        //  &aux
317
//        //);
318
//      }
319
//      else                                                 // corner case
320
//      {
321
//        //printf( "Enter corner case!\n" );
322
//        // TODO: this should be initC.
323
//        TV ctmp[ MR * NR ] = { (TV)0.0 };
324
//
325
//        double *c_list[1], alpha_list[1];
326
//        c_list[0] = ctmp;
327
//        alpha_list[0] = 1;
328
//
329
//        microkernel
330
//        (
331
//          k,
332
//          &packA[ ip * k ],
333
//          &packB[ jp * k ],
334
//          //ctmp, MR,
335
//          1, c_list, MR, alpha_list,
336
//          &aux
337
//        );
338
//
339
//        ////rank_k_int_d8x4 rankk_microkernel;
340
//        ////rankk_microkernel
341
//        //microkernel
342
//        //(
343
//        //  k,
344
//        //  &packA[ ip * k ],
345
//        //  &packB[ jp * k ],
346
//        //  //ctmp, MR,
347
//        //  ctmp, NULL, MR, 1, 0,
348
//        //  &aux
349
//        //);
350
//
351
//        //if ( pc )
352
//        {
353
//          for ( auto jj = 0; jj < aux.jb; jj ++ )
354
//          {
355
//            for ( auto ii = 0; ii < aux.ib; ii ++ )
356
//            {
357
//              C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ];
358
//
359
//              if ( alpha1 != 0 && C1 != NULL ) {
360
//                C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ];
361
//              }
362
//            }
363
//          }
364
//        }
365
//        //else
366
//        //{
367
//        //  for ( auto jj = 0; jj < aux.jb; jj ++ )
368
//        //  {
369
//        //    for ( auto ii = 0; ii < aux.ib; ii ++ )
370
//        //    {
371
//        //      C0[ ( j + jj ) * ldc + i + ii ] = alpha0 * ctmp[ jj * MR + ii ];
372
//
373
//        //      if ( alpha1 != 0 && C1 != NULL ) {
374
//        //        C1[ ( j + jj ) * ldc + i + ii ] = alpha1 * ctmp[ jj * MR + ii ];
375
//        //      }
376
//        //    }
377
//        //  }
378
//        //}
379
//      }
380
//    }                                                      // end 2nd loop
381
//  }                                                        // end 3rd loop
382
//}                                                          // end fused_macro_kernel
383
384
385
/*
386
 *
387
 */
388
template<
389
  int MC, int NC, int KC, int MR, int NR,
390
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
391
  bool USE_STRASSEN,
392
  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
393
  typename TA, typename TB, typename TC, typename TV>
394
void straprim
395
(
396
  Worker &thread,
397
  hmlpOperation_t transA, hmlpOperation_t transB,
398
  int m, int n, int k,
399
  TA *A0, TA *A1, int lda, TA gamma,
400
  TB *B0, TB *B1, int ldb, TB delta,
401
  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
402
  STRA_SEMIRINGKERNEL stra_semiringkernel,
403
  STRA_MICROKERNEL stra_microkernel,
404
  int nc, int pack_nc,
405
  TA *packA,
406
  TB *packB
407
)
408
{
409
  //printf( "m: %d, n: %d, k: %d\n", m, n, k );
410
411
  packA  += ( thread.jc_id * thread.ic_nt                ) * PACK_MC * KC
412
          + ( thread.ic_id                               ) * PACK_MC * KC;
413
  packB  += ( thread.jc_id                               ) * pack_nc * KC;
414
415
  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
416
  auto loop5th = GetRange( 0, k, KC );
417
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
418
419
  for ( int jc  = loop6th.beg();
420
            jc  < loop6th.end();
421
            jc += loop6th.inc() )                          // beg 6th loop
422
  {
423
    auto &jc_comm = *thread.jc_comm;
424
    auto jb = std::min( n - jc, nc );
425
426
    for ( int pc  = loop5th.beg();
427
              pc  < loop5th.end();
428
              pc += loop5th.inc() )
429
    {
430
      auto &pc_comm = *thread.pc_comm;
431
      auto pb = std::min( k - pc, KC );
432
      auto is_the_last_pc_iteration = ( pc + KC >= k );
433
      auto looppkB = GetRange( 0, jb,      NR, thread.ic_jr, pc_comm.GetNumThreads() );
434
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
435
436
      for ( int j   = looppkB.beg(), jp  = packpkB.beg();
437
                j   < looppkB.end();
438
                j  += looppkB.inc(), jp += packpkB.inc() )
439
      {
440
441
        //printf( "before packB\n" );
442
        if ( transB == HMLP_OP_N )
443
        {
444
445
          if ( delta == 0 || B1 == NULL ) {
446
            pack2D<true, PACK_NR>                            // packB
447
            (
448
              std::min( jb - j, NR ), pb,
449
              &B0[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
450
            );
451
          } else {
452
453
            pack2D<true, PACK_NR>                            // packB
454
            (
455
              std::min( jb - j, NR ), pb,
456
              &B0[ ( jc + j ) * ldb + pc ], &B1[ ( jc + j ) * ldb + pc ], ldb, delta, &packB[ jp * pb ]
457
            );
458
459
          }
460
461
        }
462
        else
463
        {
464
          if ( delta == 0 || B1 == NULL ) {
465
            pack2D<false, PACK_NR>                           // packB (transB)
466
            (
467
              std::min( jb - j, NR ), pb,
468
              &B0[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
469
            );
470
          } else {
471
472
            //printf( "before pack2D\n" );
473
            //printf( "B1[%d]=%lf\n", pc * ldb + ( jc + j ), B1[ pc * ldb + ( jc + j ) ] );
474
475
            pack2D<false, PACK_NR>                           // packB (transB)
476
            (
477
              std::min( jb - j, NR ), pb,
478
              &B0[ pc * ldb + ( jc + j ) ], &B1[ pc * ldb + ( jc + j ) ], ldb, delta, &packB[ jp * pb ]
479
            );
480
            //printf( "after pack2D\n" );
481
482
          }
483
484
        }
485
        //printf( "After packB\n" );
486
      }
487
      pc_comm.Barrier();
488
489
    //printf( "packB:\n" );
490
    //hmlp_printmatrix( 4, 1, packB, PACK_NR );
491
492
493
494
      for ( int ic  = loop4th.beg();
495
                ic  < loop4th.end();
496
                ic += loop4th.inc() )                      // beg 4th loop
497
      {
498
        auto &ic_comm = *thread.ic_comm;
499
        auto ib = std::min( m - ic, MC );
500
        auto looppkA = GetRange( 0, ib,      MR, thread.jr_id, thread.jr_nt );
501
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
502
503
        for ( int i   = looppkA.beg(), ip  = packpkA.beg();
504
                  i   < looppkA.end();
505
                  i  += looppkA.inc(), ip += packpkA.inc() )
506
        {
507
508
          //printf( "Before packA\n" );
509
510
          if ( transA == HMLP_OP_N )
511
          {
512
513
            if ( gamma == 0 || A1 == NULL ) {
514
              pack2D<false, PACK_MR>                         // packA
515
              (
516
                std::min( ib - i, MR ), pb,
517
                &A0[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
518
              );
519
            } else {
520
521
            //printf( "flag1\n" );
522
              pack2D<false, PACK_MR>                         // packA
523
              (
524
                std::min( ib - i, MR ), pb,
525
                &A0[ pc * lda + ( ic + i ) ], &A1[ pc * lda + ( ic + i ) ], lda, gamma, &packA[ ip * pb ]
526
              );
527
            //printf( "flag2\n" );
528
            }
529
530
          }
531
          else
532
          {
533
534
            if ( gamma == 0 || A1 == NULL ) {
535
              pack2D<true, PACK_MR>                          // packA (transA)
536
              (
537
                std::min( ib - i, MR ), pb,
538
                &A0[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
539
              );
540
            } else {
541
              pack2D<true, PACK_MR>                          // packA (transA)
542
              (
543
                std::min( ib - i, MR ), pb,
544
                &A0[ ( ic + i ) * lda + pc ], &A1[ ( ic + i ) * lda + pc ], lda, gamma, &packA[ ip * pb ]
545
              );
546
            }
547
548
          }
549
550
          //printf( "After packA\n" );
551
        }
552
        ic_comm.Barrier();
553
554
//        if ( is_the_last_pc_iteration )                    // fused_macro_kernel
555
//        {
556
//          if ( alpha1 == 0 || C1 == NULL ) {
557
//
558
//            //hmlp::gkmx::fused_macro_kernel
559
//            //<KC, MR, NR, PACK_MR, PACK_NR, RANK_MICROKERNEL, TA, TB, TC, TV>
560
//            //(
561
//            //  thread,
562
//            //  ic, jc, pc,
563
//            //  ib, jb, pb,
564
//            //  packA,
565
//            //  packB,
566
//            //  C0 + jc * ldc + ic, ldc,
567
//            //  rank_microkernel
568
//            //);
569
//
570
//            //printf( "before fused macro kernel\n" );
571
//            fused_macro_kernel
572
//            <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
573
//            (
574
//              thread,
575
//              ic, jc, pc,
576
//              ib, jb, pb,
577
//              packA,
578
//              packB,
579
//              C0 + jc * ldc + ic,
580
//              NULL, ldc, alpha0, 0,
581
//              stra_microkernel
582
//            );
583
//            //printf( "after fused macro kernel\n" );
584
//
585
//          } else {
586
//            fused_macro_kernel
587
//            <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
588
//            (
589
//              thread,
590
//              ic, jc, pc,
591
//              ib, jb, pb,
592
//              packA,
593
//              packB,
594
//              C0 + jc * ldc + ic,
595
//              C1 + jc * ldc + ic, ldc, alpha0, alpha1,
596
//              stra_microkernel
597
//            );
598
//          }
599
//
600
//        }
601
//        else                                               // semiring rank-k update
602
//        {
603
604
          if ( alpha1 == 0 || C1 == NULL )
605
          {
606
            //hmlp::gkmx::rank_k_macro_kernel
607
            //<KC, MR, NR, PACK_MR, PACK_NR, RANK_SEMIRINGKERNEL, TA, TB, TC, TV>
608
            //(
609
            //  thread,
610
            //  ic, jc, pc,
611
            //  ib, jb, pb,
612
            //  packA,
613
            //  packB,
614
            //  C0 + jc * ldc + ic, ldc,
615
            //  rank_semiringkernel
616
            //);
617
618
            rank_k_macro_kernel
619
            //strassen_macro_kernel
620
            <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
621
            (
622
              thread,
623
              ic, jc, pc,
624
              ib, jb, pb,
625
              packA,
626
              packB,
627
              C0 + jc * ldc + ic,
628
              NULL, ldc, alpha0, 0,
629
              stra_semiringkernel
630
            );
631
632
          }
633
          else
634
          {
635
636
            rank_k_macro_kernel
637
            //strassen_macro_kernel
638
            <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
639
            (
640
              thread,
641
              ic, jc, pc,
642
              ib, jb, pb,
643
              packA,
644
              packB,
645
              C0 + jc * ldc + ic,
646
              C1 + jc * ldc + ic, ldc, alpha0, alpha1,
647
              stra_semiringkernel
648
            );
649
650
          }
651
652
//        }
653
        ic_comm.Barrier();                                 // sync all jr_id!!
654
      }                                                    // end 4th loop
655
      pc_comm.Barrier();
656
    }                                                      // end 5th loop
657
  }                                                        // end 6th loop
658
}                                                          // end strassen_internal
659
660
661
662
663
/*
664
 *
665
 */
666
template<
667
  int MC, int NC, int KC, int MR, int NR,
668
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
669
  bool USE_STRASSEN,
670
  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
671
  typename TA, typename TB, typename TC, typename TV>
672
void straprim
673
(
674
  Worker &thread,
675
  hmlpOperation_t transA, hmlpOperation_t transB,
676
  int m, int n, int k,
677
  TA *A0, TA *A1, int lda, TA gamma, int *amap,
678
  TB *B0, TB *B1, int ldb, TB delta, int *bmap,
679
  TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1,
680
  STRA_SEMIRINGKERNEL stra_semiringkernel,
681
  STRA_MICROKERNEL stra_microkernel,
682
  int nc, int pack_nc,
683
  TA *packA,
684
  TB *packB
685
)
686
{
687
  //printf( "m: %d, n: %d, k: %d\n", m, n, k );
688
689
  packA  += ( thread.jc_id * thread.ic_nt                ) * PACK_MC * KC
690
          + ( thread.ic_id                               ) * PACK_MC * KC;
691
  packB  += ( thread.jc_id                               ) * pack_nc * KC;
692
693
  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
694
  auto loop5th = GetRange( 0, k, KC );
695
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
696
697
  for ( int jc  = loop6th.beg();
698
            jc  < loop6th.end();
699
            jc += loop6th.inc() )                          // beg 6th loop
700
  {
701
    auto &jc_comm = *thread.jc_comm;
702
    auto jb = std::min( n - jc, nc );
703
704
    for ( int pc  = loop5th.beg();
705
              pc  < loop5th.end();
706
              pc += loop5th.inc() )
707
    {
708
      auto &pc_comm = *thread.pc_comm;
709
      auto pb = std::min( k - pc, KC );
710
      auto is_the_last_pc_iteration = ( pc + KC >= k );
711
      auto looppkB = GetRange( 0, jb,      NR, thread.ic_jr, pc_comm.GetNumThreads() );
712
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
713
714
      for ( int j   = looppkB.beg(), jp  = packpkB.beg();
715
                j   < looppkB.end();
716
                j  += looppkB.inc(), jp += packpkB.inc() )
717
      {
718
719
        //printf( "before packB\n" );
720
        if ( transB == HMLP_OP_N )
721
        {
722
723
          if ( delta == 0 || B1 == NULL ) {
724
            // ldb == k
725
            pack2D<true, PACK_NR>                              // packB
726
            (
727
              std::min( jb - j, NR ), pb,
728
              &B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ]
729
            );
730
          } else {
731
            pack2D<true, PACK_NR>                              // packB
732
            (
733
              std::min( jb - j, NR ), pb,
734
              &B0[ pc ],  &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ]
735
            );
736
          }
737
738
        }
739
        else
740
        {
741
          if ( delta == 0 || B1 == NULL ) {
742
            pack2D<false, PACK_NR>                              // packB (transB)
743
            (
744
              std::min( jb - j, NR ), pb,
745
              &B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ]
746
            );
747
          } else {
748
            pack2D<false, PACK_NR>                              // packB (transB)
749
            (
750
              std::min( jb - j, NR ), pb,
751
              &B0[ pc ],  &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ]
752
            );
753
754
755
          }
756
757
        }
758
759
      }
760
      pc_comm.Barrier();
761
762
      for ( int ic  = loop4th.beg();
763
                ic  < loop4th.end();
764
                ic += loop4th.inc() )                      // beg 4th loop
765
      {
766
        auto &ic_comm = *thread.ic_comm;
767
        auto ib = std::min( m - ic, MC );
768
        auto looppkA = GetRange( 0, ib,      MR, thread.jr_id, thread.jr_nt );
769
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
770
771
        for ( int i   = looppkA.beg(), ip  = packpkA.beg();
772
                  i   < looppkA.end();
773
                  i  += looppkA.inc(), ip += packpkA.inc() )
774
        {
775
776
          //assert( lda == k );
777
          //For transpose cases, lda should be equal to k.
778
779
          if ( transA == HMLP_OP_N )
780
          {
781
782
            if ( gamma == 0 || A1 == NULL ) {
783
              pack2D<false, PACK_MR>                            // packA
784
              (
785
                std::min( ib - i, MR ), pb,
786
                &A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ]
787
              );
788
            } else {
789
              pack2D<false, PACK_MR>                            // packA
790
              (
791
                std::min( ib - i, MR ), pb,
792
                &A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ]
793
              );
794
            }
795
796
          }
797
          else
798
          {
799
800
            if ( gamma == 0 || A1 == NULL ) {
801
              pack2D<true, PACK_MR>                            // packA (transA)
802
              (
803
                std::min( ib - i, MR ), pb,
804
                &A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ]
805
              );
806
            } else {
807
              pack2D<true, PACK_MR>                            // packA (transA)
808
              (
809
                std::min( ib - i, MR ), pb,
810
                &A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ]
811
              );
812
813
            }
814
815
          }
816
817
        }
818
        ic_comm.Barrier();
819
820
//        if ( is_the_last_pc_iteration )                    // fused_macro_kernel
821
//        {
822
//          if ( alpha1 == 0 || C1 == NULL ) {
823
//
824
//            //hmlp::gkmx::fused_macro_kernel
825
//            //<KC, MR, NR, PACK_MR, PACK_NR, RANK_MICROKERNEL, TA, TB, TC, TV>
826
//            //(
827
//            //  thread,
828
//            //  ic, jc, pc,
829
//            //  ib, jb, pb,
830
//            //  packA,
831
//            //  packB,
832
//            //  C0 + jc * ldc + ic, ldc,
833
//            //  rank_microkernel
834
//            //);
835
//
836
//            //printf( "before fused macro kernel\n" );
837
//            fused_macro_kernel
838
//            <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
839
//            (
840
//              thread,
841
//              ic, jc, pc,
842
//              ib, jb, pb,
843
//              packA,
844
//              packB,
845
//              C0 + jc * ldc + ic,
846
//              NULL, ldc, alpha0, 0,
847
//              stra_microkernel
848
//            );
849
//            //printf( "after fused macro kernel\n" );
850
//
851
//          } else {
852
//            fused_macro_kernel
853
//            <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV>
854
//            (
855
//              thread,
856
//              ic, jc, pc,
857
//              ib, jb, pb,
858
//              packA,
859
//              packB,
860
//              C0 + jc * ldc + ic,
861
//              C1 + jc * ldc + ic, ldc, alpha0, alpha1,
862
//              stra_microkernel
863
//            );
864
//          }
865
//
866
//        }
867
//        else                                               // semiring rank-k update
868
//        {
869
870
          if ( alpha1 == 0 || C1 == NULL )
871
          {
872
            //hmlp::gkmx::rank_k_macro_kernel
873
            //<KC, MR, NR, PACK_MR, PACK_NR, RANK_SEMIRINGKERNEL, TA, TB, TC, TV>
874
            //(
875
            //  thread,
876
            //  ic, jc, pc,
877
            //  ib, jb, pb,
878
            //  packA,
879
            //  packB,
880
            //  C0 + jc * ldc + ic, ldc,
881
            //  rank_semiringkernel
882
            //);
883
884
            rank_k_macro_kernel
885
            //strassen_macro_kernel
886
            <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
887
            (
888
              thread,
889
              ic, jc, pc,
890
              ib, jb, pb,
891
              packA,
892
              packB,
893
              C0 + jc * ldc + ic,
894
              NULL, ldc, alpha0, 0,
895
              stra_semiringkernel
896
            );
897
898
          }
899
          else
900
          {
901
902
            rank_k_macro_kernel
903
            //strassen_macro_kernel
904
            <KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV>
905
            (
906
              thread,
907
              ic, jc, pc,
908
              ib, jb, pb,
909
              packA,
910
              packB,
911
              C0 + jc * ldc + ic,
912
              C1 + jc * ldc + ic, ldc, alpha0, alpha1,
913
              stra_semiringkernel
914
            );
915
916
          }
917
918
//        }
919
        ic_comm.Barrier();                                 // sync all jr_id!!
920
      }                                                    // end 4th loop
921
      pc_comm.Barrier();
922
    }                                                      // end 5th loop
923
  }                                                        // end 6th loop
924
}                                                          // end strassen_internal
925
926
template<typename TA, typename TB, typename TV>
927
void hmlp_dynamic_peeling
928
(
929
  hmlpOperation_t transA, hmlpOperation_t transB,
930
  int m, int n, int k,
931
  TA *A, int lda,
932
  TB *B, int ldb,
933
  TV *C, int ldc,
934
  int dim1, int dim2, int dim3
935
)
936
{
937
  //printf( "Enter dynamic peeling\n" );
938
  int mr = m % dim1;
939
  int kr = k % dim2;
940
  int nr = n % dim3;
941
  int ms = m - mr;
942
  int ns = n - nr;
943
  int ks = k - kr;
944
  TA *A_extra;
945
  TB *B_extra;
946
  TV *C_extra;
947
948
  char transA_val, transB_val;
949
  char *char_transA = &transA_val, *char_transB = &transB_val;
950
951
952
  //printf( "flag d1\n" );
953
954
  // Adjust part handled by fast matrix multiplication.
955
  // Add far column of A outer product bottom row B
956
  if ( kr > 0 ) {
957
    // In Strassen, this looks like C([1, 2], [1, 2]) += A([1, 2], 3) * B(3, [1, 2])
958
959
    //printf( "flag d2\n" );
960
961
    if ( transA == HMLP_OP_N ) {
962
      A_extra = &A[ 0 + ks * lda ];//ms * kr
963
      *char_transA = 'N';
964
    } else {
965
      A_extra = &A[ 0 * lda + ks ];//ms * kr
966
      *char_transA = 'T';
967
    }
968
969
    //printf( "flag d3\n" );
970
    if ( transB == HMLP_OP_N ) {
971
      B_extra = &B[ ks + 0 * ldb ];//kr * ns
972
      *char_transB = 'N';
973
    } else {
974
      B_extra = &B[ ks * ldb + 0 ];//kr * ns
975
      *char_transB = 'T';
976
    }
977
978
    //printf( "flag d4\n" );
979
    C_extra = &C[ 0  + 0  * ldc ];//ms * ns
980
    if ( ms > 0 && ns > 0 )
981
    {
982
      //bl_dgemm( ms, ns, kr, A_extra, lda, B_extra, ldb, C_extra, ldc );
983
      xgemm( char_transA, char_transB, ms, ns, kr, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc );
984
    }
985
  }
986
987
  //printf( "flag d5\n" );
988
989
  // Adjust for far right columns of C
990
  if ( nr > 0 ) {
991
    // In Strassen, this looks like C(:, 3) = A * B(:, 3)
992
993
    if ( transA == HMLP_OP_N ) {
994
      *char_transA = 'N';
995
    } else {
996
      *char_transA = 'T';
997
    }
998
    //printf( "flag d6\n" );
999
1000
    if ( transB == HMLP_OP_N ) {
1001
      B_extra = &B[ 0 + ns * ldb ];//k * nr
1002
      *char_transB = 'N';
1003
    } else {
1004
      B_extra = &B[ 0 * ldb + ns ];//k * nr
1005
      *char_transB = 'T';
1006
    }
1007
1008
1009
    //printf( "flag d7\n" );
1010
1011
    C_extra = &C[ 0 + ns * ldc ];//m * nr
1012
    //bl_dgemm( m, nr, k, A, lda, B_extra, ldb, C_extra, ldc );
1013
    xgemm( char_transA, char_transB, m, nr, k, 1.0,  A, lda, B_extra, ldb, 1.0, C_extra, ldc );
1014
1015
  }
1016
1017
  //printf( "flag d8\n" );
1018
1019
  // Adjust for bottom rows of C
1020
  if ( mr > 0 ) {
1021
    // In Strassen, this looks like C(3, [1, 2]) = A(3, :) * B(:, [1, 2])
1022
1023
1024
  //printf( "flag d8.1\n" );
1025
    if ( transA == HMLP_OP_N ) {
1026
1027
  //printf( "flag d8.15\n" );
1028
      A_extra = &A[ ms + 0 * lda ];// mr * k
1029
  //printf( "flag d8.16\n" );
1030
      *char_transA = 'N';
1031
1032
      //printf( "flag d8.2\n" );
1033
    } else {
1034
      A_extra = &A[ ms * lda + 0 ];// mr * k
1035
      *char_transA = 'T';
1036
      //printf( "flag d8.3\n" );
1037
    }
1038
1039
    //printf( "flag d8.4\n" );
1040
1041
    if ( transB == HMLP_OP_N ) {
1042
      B_extra = &B[ 0  + 0 * ldb ];// k  * ns
1043
      *char_transB = 'N';
1044
1045
      //printf( "flag d8.5\n" );
1046
1047
    } else {
1048
      B_extra = &B[ 0 * ldb  + 0 ];// k  * ns
1049
      *char_transB = 'T';
1050
1051
      //printf( "flag d8.6\n" );
1052
    }
1053
1054
  //printf( "flag d9\n" );
1055
1056
    TV *C_extra = &C[ ms + 0 * ldc ];// mr * ns
1057
    if ( ns > 0 )
1058
    {
1059
      //bl_dgemm( mr, ns, k, A_extra, lda, B_extra, ldb, C_extra, ldc );
1060
      xgemm( char_transA, char_transB, mr, ns, k, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc );
1061
    }
1062
  }
1063
  //printf( "Leave dynamic peeling\n" );
1064
}
1065
1066
template<
1067
  int MC, int NC, int KC, int MR, int NR,
1068
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
1069
  bool USE_STRASSEN,
1070
  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
1071
  typename TA, typename TB, typename TC, typename TV>
1072
void strassen_internal
1073
(
1074
  Worker &thread,
1075
  hmlpOperation_t transA, hmlpOperation_t transB,
1076
  int m, int n, int k,
1077
  TA *A, int lda, int *amap,
1078
  TB *B, int ldb, int *bmap,
1079
  TV *C, int ldc,
1080
  STRA_SEMIRINGKERNEL stra_semiringkernel,
1081
  STRA_MICROKERNEL stra_microkernel,
1082
  int nc, int pack_nc,
1083
  TA *packA_buff,
1084
  TB *packB_buff
1085
)
1086
{
1087
1088
  int ms, ks, ns;
1089
  int md, kd, nd;
1090
  int mr, kr, nr;
1091
1092
  mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 );
1093
  md = m - mr, kd = k - kr, nd = n - nr;
1094
1095
  // Partition code.
1096
  ms=md, ks=kd, ns=nd;
1097
  TA *A00, *A01, *A10, *A11;
1098
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 0, &A00 );
1099
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 1, &A01 );
1100
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 0, &A10 );
1101
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 1, &A11 );
1102
1103
  TB *B00, *B01, *B10, *B11;
1104
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 0, &B00 );
1105
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 1, &B01 );
1106
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 0, &B10 );
1107
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 1, &B11 );
1108
1109
  TV *C00, *C01, *C10, *C11;
1110
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 0, &C00 );
1111
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 1, &C01 );
1112
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 0, &C10 );
1113
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 1, &C11 );
1114
1115
  md = md / 2, kd = kd / 2, nd = nd / 2;
1116
1117
  // M1: C00 = 1*C00+1*(A00+A11)(B00+B11); C11 = 1*C11+1*(A00+A11)(B00+B11)
1118
  STRAPRIM_MAP( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 );
1119
  // M2: C10 = 1*C10+1*(A10+A11)B00; C11 = 1*C11-1*(A10+A11)B00
1120
  STRAPRIM_MAP( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 )
1121
  // M3: C01 = 1*C01+1*A00(B01-B11); C11 = 1*C11+1*A00(B01-B11)
1122
  STRAPRIM_MAP( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 )
1123
  // M4: C00 = 1*C00+1*A11(B10-B00); C10 = 1*C10+1*A11(B10-B00)
1124
  STRAPRIM_MAP( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 )
1125
  // M5: C00 = 1*C00-1*(A00+A01)B11; C01 = 1*C01+1*(A00+A01)B11
1126
  STRAPRIM_MAP( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 )
1127
  // M6: C11 = 1*C11+(A10-A00)(B00+B01)
1128
  STRAPRIM_MAP( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 )
1129
  // M7: C00 = 1*C00+(A01-A11)(B10+B11)
1130
  STRAPRIM_MAP( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 )
1131
1132
  if ( omp_get_thread_num() == 0 ) { //Chief thread
1133
    hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 );
1134
  }
1135
1136
}
1137
1138
template<
1139
  int MC, int NC, int KC, int MR, int NR,
1140
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
1141
  bool USE_STRASSEN,
1142
  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
1143
  typename TA, typename TB, typename TC, typename TV>
1144
void strassen_internal
1145
(
1146
  Worker &thread,
1147
  hmlpOperation_t transA, hmlpOperation_t transB,
1148
  int m, int n, int k,
1149
  TA *A, int lda,
1150
  TB *B, int ldb,
1151
  TV *C, int ldc,
1152
  STRA_SEMIRINGKERNEL stra_semiringkernel,
1153
  STRA_MICROKERNEL stra_microkernel,
1154
  int nc, int pack_nc,
1155
  TA *packA_buff,
1156
  TB *packB_buff
1157
)
1158
{
1159
1160
  int ms, ks, ns;
1161
  int md, kd, nd;
1162
  int mr, kr, nr;
1163
1164
  mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 );
1165
  md = m - mr, kd = k - kr, nd = n - nr;
1166
1167
  // Partition code.
1168
  ms=md, ks=kd, ns=nd;
1169
  TA *A00, *A01, *A10, *A11;
1170
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 0, &A00 );
1171
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 1, &A01 );
1172
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 0, &A10 );
1173
  hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 1, &A11 );
1174
1175
  TB *B00, *B01, *B10, *B11;
1176
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 0, &B00 );
1177
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 1, &B01 );
1178
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 0, &B10 );
1179
  hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 1, &B11 );
1180
1181
  TV *C00, *C01, *C10, *C11;
1182
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 0, &C00 );
1183
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 1, &C01 );
1184
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 0, &C10 );
1185
  hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 1, &C11 );
1186
1187
  md = md / 2, kd = kd / 2, nd = nd / 2;
1188
1189
  // M1: C00 = 1*C00+1*(A00+A11)(B00+B11); C11 = 1*C11+1*(A00+A11)(B00+B11)
1190
  STRAPRIM( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 );
1191
1192
  //printf( "A00:\n" );
1193
  //hmlp_printmatrix( md, kd, A00, m );
1194
  //printf( "A11:\n" );
1195
  //hmlp_printmatrix( md, kd, A11, m );
1196
  //printf( "B00:\n" );
1197
  //hmlp_printmatrix( kd, nd, B00, k );
1198
  //printf( "B11:\n" );
1199
  //hmlp_printmatrix( kd, nd, B11, k );
1200
  //printf( "C00:\n" );
1201
  //hmlp_printmatrix( md, nd, C00, m );
1202
  //printf( "C01:\n" );
1203
  //hmlp_printmatrix( md, nd, C11, m );
1204
1205
  // M2: C10 = 1*C10+1*(A10+A11)B00; C11 = 1*C11-1*(A10+A11)B00
1206
  STRAPRIM( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 )
1207
1208
  // M3: C01 = 1*C01+1*A00(B01-B11); C11 = 1*C11+1*A00(B01-B11)
1209
  STRAPRIM( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 )
1210
  // M4: C00 = 1*C00+1*A11(B10-B00); C10 = 1*C10+1*A11(B10-B00)
1211
  STRAPRIM( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 )
1212
  // M5: C00 = 1*C00-1*(A00+A01)B11; C01 = 1*C01+1*(A00+A01)B11
1213
  STRAPRIM( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 )
1214
  // M6: C11 = 1*C11+(A10-A00)(B00+B01)
1215
  STRAPRIM( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 )
1216
  // M7: C00 = 1*C00+(A01-A11)(B10+B11)
1217
  STRAPRIM( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 )
1218
1219
  //printf( "C00:" );
1220
  //hmlp_printmatrix( md, nd, C00, m );
1221
1222
  //printf( "before dynamic peeling\n" );
1223
1224
  if ( omp_get_thread_num() == 0 ) { //Chief thread
1225
    hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 );
1226
  }
1227
1228
}
1229
1230
1231
/**
1232
 *
1233
 *
1234
 */
1235
template<
1236
  int MC, int NC, int KC, int MR, int NR,
1237
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
1238
  bool USE_STRASSEN,
1239
  typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL,
1240
  typename TA, typename TB, typename TC, typename TV>
1241
void strassen
1242
(
1243
  hmlpOperation_t transA, hmlpOperation_t transB,
1244
  int m, int n, int k,
1245
  TA *A, int lda,
1246
  TB *B, int ldb,
1247
  TV *C, int ldc,
1248
  STRA_SEMIRINGKERNEL stra_semiringkernel,
1249
  STRA_MICROKERNEL stra_microkernel
1250
)
1251
{
1252
  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
1253
  int nc = NC, pack_nc = PACK_NC;
1254
  char *str;
1255
1256
  TA *packA_buff = NULL;
1257
  TB *packB_buff = NULL;
1258
1259
  // Early return if possible
1260
  if ( m == 0 || n == 0 || k == 0 ) return;
1261
1262
  // Check the environment variable.
1263
  jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
1264
  ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
1265
  jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
1266
1267
1268
  if ( jc_nt > 1 )
1269
  {
1270
    nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
1271
    pack_nc = ( nc / NR ) * PACK_NR;
1272
  }
1273
1274
  // allocate packing memory
1275
  packA_buff  = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,         sizeof(TA) );
1276
  packB_buff  = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,                 sizeof(TB) );
1277
1278
  // allocate tree communicator
1279
  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
1280
1281
  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
1282
  {
1283
    Worker thread( &my_comm );
1284
1285
    strassen_internal
1286
    <MC, NC, KC, MR, NR,
1287
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
1288
    USE_STRASSEN,
1289
    STRA_SEMIRINGKERNEL, STRA_MICROKERNEL,
1290
    TA, TB, TC, TB>
1291
    (
1292
       thread,
1293
       transA, transB,
1294
       m, n, k,
1295
       A, lda,
1296
       B, ldb,
1297
       C, ldc,
1298
       stra_semiringkernel, stra_microkernel,
1299
       nc, pack_nc,
1300
       packA_buff,
1301
       packB_buff
1302
    );
1303
1304
  }
1305
                                                        // end omp
1306
}                                                       // end strassen
1307
1308
1309
}; // end namespace strassen
1310
}; // end namespace hmlp
1311
1312
#endif // define STRASSEN_HPP
1313
1314