GCC Code Coverage Report
Directory: . Exec Total Coverage
File: frame/primitives/conv2d.hpp Lines: 0 175 0.0 %
Date: 2019-01-14 Branches: 0 84 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
23
24
#ifndef CNN_HPP
25
#define CNN_HPP
26
27
#include <hmlp.h>
28
#include <hmlp_internal.hpp>
29
#include <hmlp_base.hpp>
30
31
// #define DEBUG_CONV2D 1
32
33
namespace hmlp
34
{
35
namespace cnn
36
{
37
38
/**
39
 *
40
 */
41
template<
42
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
43
  typename SEMIRINGKERNEL,
44
  typename TA, typename TB, typename TC, typename TV>
45
void rank_k_macro_kernel
46
(
47
  Worker &thread,
48
  int ic, int jc, int pc,
49
  int  m, int n,  int  k,
50
  TA *packA,
51
  TB *packB,
52
  TV *C, int ldc,
53
  SEMIRINGKERNEL semiringkernel
54
)
55
{
56
  thread_communicator &ic_comm = *thread.ic_comm;
57
58
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
59
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
60
  auto loop2nd = GetRange( 0, m,      MR );
61
  auto pack2nd = GetRange( 0, m, PACK_MR );
62
63
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
64
            j   < loop3rd.end();
65
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
66
  {
67
    struct aux_s<TA, TB, TC, TV> aux;
68
    aux.pc       = pc;
69
    aux.b_next   = packB;
70
    aux.do_packC = 0;
71
    aux.jb       = std::min( n - j, NR );
72
73
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
74
              i  < loop2nd.end();
75
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
76
    {
77
      aux.ib = std::min( m - i, MR );
78
      if ( aux.ib != MR )
79
      {
80
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
81
      }
82
83
      if ( aux.jb == NR && aux.ib == MR )
84
      {
85
        semiringkernel
86
        (
87
          k,
88
          &packA[ ip * k ],
89
          &packB[ jp * k ],
90
          &C[ j * ldc + i ], 1, ldc,
91
          &aux
92
        );
93
      }
94
      else                                                 // corner case
95
      {
96
        // TODO: this should be initC.
97
        TV ctmp[ MR * NR ] = { (TV)0.0 };
98
        semiringkernel
99
        (
100
          k,
101
          &packA[ ip * k ],
102
          &packB[ jp * k ],
103
          ctmp, 1, MR,
104
          &aux
105
        );
106
        if ( pc )
107
        {
108
          for ( auto jj = 0; jj < aux.jb; jj ++ )
109
          {
110
            for ( auto ii = 0; ii < aux.ib; ii ++ )
111
            {
112
              C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
113
            }
114
          }
115
        }
116
        else
117
        {
118
          for ( auto jj = 0; jj < aux.jb; jj ++ )
119
          {
120
            for ( auto ii = 0; ii < aux.ib; ii ++ )
121
            {
122
              C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
123
            }
124
          }
125
        }
126
      }
127
    }                                                      // end 2nd loop
128
  }                                                        // end 3rd loop
129
}                                                          // end rank_k_macro_kernel
130
131
/**
132
 *
133
 */
134
template<
135
  int KC,
136
  int MR,
137
  int NR,
138
  int PACK_MR,
139
  int PACK_NR,
140
  typename MICROKERNEL,
141
  typename TA, typename TB, typename TC, typename TV>
142
void fused_macro_kernel
143
(
144
  Worker &thread,
145
  int ic, int jc, int pc,
146
  int  m,  int n,  int k,
147
  TA *packA,
148
  TB *packB,
149
  TV *C, int ldc,
150
  MICROKERNEL microkernel
151
)
152
{
153
  thread_communicator &ic_comm = *thread.ic_comm;
154
155
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
156
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
157
  auto loop2nd = GetRange( 0, m,      MR );
158
  auto pack2nd = GetRange( 0, m, PACK_MR );
159
160
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
161
            j   < loop3rd.end();
162
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
163
  {
164
    struct aux_s<TA, TB, TC, TV> aux;
165
    aux.pc       = pc;
166
    aux.b_next   = packB;
167
    aux.do_packC = 0;
168
    aux.jb       = std::min( n - j, NR );
169
170
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
171
              i  < loop2nd.end();
172
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
173
    {
174
      aux.ib = std::min( m - i, MR );
175
      if ( aux.ib != MR )
176
      {
177
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
178
      }
179
180
      if ( aux.jb == NR && aux.ib == MR )
181
      {
182
        microkernel
183
        (
184
          k,
185
          &packA[ ip * k ],
186
          &packB[ jp * k ],
187
          &C[ j * ldc + i ], 1, ldc,
188
          &aux
189
        );
190
      }
191
      else                                                 // corner case
192
      {
193
        TV ctmp[ MR * NR ] = { (TV)0.0 };
194
        microkernel
195
        (
196
          k,
197
          &packA[ ip * k ],
198
          &packB[ jp * k ],
199
          ctmp, 1, MR,
200
          &aux
201
        );
202
203
        if ( pc )
204
        {
205
          for ( auto jj = 0; jj < aux.jb; jj ++ )
206
          {
207
            for ( auto ii = 0; ii < aux.ib; ii ++ )
208
            {
209
              C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ];
210
            }
211
          }
212
        }
213
        else
214
        {
215
          for ( auto jj = 0; jj < aux.jb; jj ++ )
216
          {
217
            for ( auto ii = 0; ii < aux.ib; ii ++ )
218
            {
219
              C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];
220
            }
221
          }
222
        }
223
      }
224
    }                                                      // end 2nd loop
225
  }                                                        // end 3rd loop
226
};                                                         // end fused_macro_kernel
227
228
229
230
/*
231
 *
232
 */
233
template<
234
  int MC, int NC, int KC, int MR, int NR,
235
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
236
  bool USE_STRASSEN,
237
  typename SEMIRINGKERNEL, typename MICROKERNEL,
238
  typename TA, typename TB, typename TC, typename TV>
239
void conv2d_internal
240
(
241
  Worker &thread,
242
  int w0, int h0, int d0, int s, int p,
243
  TB *B,
244
  int w1, int h1, int d1,
245
  TA *A,
246
  TC *C,
247
  SEMIRINGKERNEL semiringkernel,
248
  MICROKERNEL microkernel,
249
  int nc, int pack_nc,
250
  TA *packA,
251
  TB *packB
252
)
253
{
254
  packA  += ( thread.jc_id * thread.ic_nt                ) * PACK_MC * KC
255
          + ( thread.ic_id                               ) * PACK_MC * KC;
256
  packB  += ( thread.jc_id                               ) * pack_nc * KC;
257
258
259
  // Now compute parameters such that I can transform the problem into GEMM.
260
  int m = d1;
261
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
262
  int ny = ( h0 - h1 + 2 * p ) / s + 1;
263
  int n = nx * ny;
264
  int k = w1 * h1 * d0;
265
266
  //auto loop6th = GetRange( HMLP_SCHEDULE_HEFT, 0, n, nc, thread.jc_id, thread.jc_nt );
267
  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
268
  auto loop5th = GetRange( 0, k, KC );
269
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );
270
271
  //printf( "tid %d beg %d end %d inc %d\n", thread.jc_id, loop6th.beg(), loop6th.end(), loop6th.inc() );
272
273
  //double my_beg = omp_get_wtime();
274
  /*
275
   *  @CHENHAN: loop over your filters.
276
   */
277
  for ( int jc  = loop6th.beg();
278
            jc  < loop6th.end();
279
            jc += loop6th.inc() )                          // beg 6th loop
280
  {
281
    auto &jc_comm = *thread.jc_comm;
282
    auto jb = std::min( n - jc, nc );
283
284
    /*
285
     *  @CHENHAN: loop over your window size ( w1 * h1 * d0 ).
286
     */
287
    for ( int pc  = loop5th.beg();
288
              pc  < loop5th.end();
289
              pc += loop5th.inc() )
290
    {
291
      auto &pc_comm = *thread.pc_comm;
292
      auto pb = std::min( k - pc, KC );
293
      auto is_the_last_pc_iteration = ( pc + KC >= k );
294
295
      /*
296
       *  @CHENHAN: pack image into packB.
297
       */
298
      auto looppkB = GetRange( 0, jb,      NR, thread.ic_jr, pc_comm.GetNumThreads() );
299
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );
300
301
      for ( int j   = looppkB.beg(), jp  = packpkB.beg();
302
                j   < looppkB.end();
303
                j  += looppkB.inc(), jp += packpkB.inc() )
304
      {
305
        auto x0 = ( ( jc + j ) % nx ) * s - p; // top-left
306
        auto y0 = ( ( jc + j ) / nx ) * s - p; // top-left
307
308
#ifdef DEBUG_CONV2D
309
     printf( "x0 %4d y0 %4d\n", x0, y0 );
310
#endif
311
312
        pack2Dimg<PACK_NR>                            // packB
313
        (
314
          std::min( jb - j, NR ), pb,
315
          &packB[ jp  * pb ],
316
          x0, y0, pc,
317
          B,
318
          w0, h0, d0, s, p,
319
          w1, h1
320
        );
321
      }
322
      pc_comm.Barrier();
323
324
325
#ifdef DEBUG_CONV2D
326
      for ( int i = 0; i < pb; i ++ )
327
      {
328
        for ( int jj = 0; jj < jb; jj += NR )
329
        {
330
          for ( int j = 0; j < NR; j ++ )
331
          {
332
            printf( "%5.2lf ", packB[ jj * pb + i * NR + j ] );
333
          }
334
          printf( "   " );
335
        }
336
        printf( "\n" );
337
      }
338
      printf( "\n" );
339
#endif
340
341
342
      for ( int ic  = loop4th.beg();
343
                ic  < loop4th.end();
344
                ic += loop4th.inc() )                      // beg 4th loop
345
      {
346
        auto &ic_comm = *thread.ic_comm;
347
        auto ib = std::min( m - ic, MC );
348
349
        auto looppkA = GetRange( 0, ib,      MR, thread.jr_id, thread.jr_nt );
350
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );
351
352
        /*
353
         *  @CHENHAN: assume filters were already packed format.
354
         */
355
        for ( int i   = looppkA.beg(), ip  = packpkA.beg();
356
                  i   < looppkA.end();
357
                  i  += looppkA.inc(), ip += packpkA.inc() )
358
        {
359
          pack2D<true, PACK_MR>                          // packA (transA)
360
          (
361
            std::min( ib - i, MR ), pb,
362
            &A[ ( ic + i ) * k + pc ], k, &packA[ ip * pb ]
363
          );
364
        }
365
366
        if ( is_the_last_pc_iteration )                    // fused_macro_kernel
367
        {
368
          fused_macro_kernel
369
          <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
370
          (
371
            thread,
372
            ic, jc, pc,
373
            ib, jb, pb,
374
            packA,
375
            packB,
376
            C + jc * m + ic, m,
377
            microkernel
378
          );
379
        }
380
        else                                               // semiring rank-k update
381
        {
382
          rank_k_macro_kernel
383
          <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
384
          (
385
            thread,
386
            ic, jc, pc,
387
            ib, jb, pb,
388
            packA,
389
            packB,
390
            C + jc * m + ic, m,
391
            semiringkernel
392
          );
393
        }
394
        ic_comm.Barrier();                                 // sync all jr_id!!
395
      }                                                    // end 4th loop
396
      pc_comm.Barrier();
397
    }                                                      // end 5th loop
398
  }                                                        // end 6th loop
399
  //double my_time = omp_get_wtime() - my_beg;
400
  //double my_flop = ( ( loop6th.end() - loop6th.beg() ) / 1e+9 ) * 2 * m * k;
401
  ////printf( "tid %d GFLOPS %5.2lf\n", thread.jc_id, my_flop / my_time );
402
  //printf( "tid %d GFLOPS %5.2lf\n", thread.jc_id, my_time );
403
};                                                         // end cnn_internal
404
405
406
407
408
409
/**
410
 *  @CHENHAN:
411
 *
412
 *  These templates (the same as gkmx.hpp) define a general matrix-matrix multiplication.
413
 *  You will be using these existing code to write a convolution operation.
414
 *
415
 *  (First) you should define what parameters you need. For convolution, your
416
 *  input A will be a image (tensor). B is filters (tensor). C is the output,
417
 *  which again should be a tensor. Since tensors need more attributes to
418
 *  describe. You will need to think about what you need instead of m, n, k,
419
 *  lda, ldb, ldc.
420
 *
421
 *  (Second) you need to restructure the loop to loop over each convolution
422
 *  window. The window size (width*length) is the k dimension of your GEMM.
423
 *  Notice for each loop in the original GEMM operation you may need more than
424
 *  one loop in the convolution expression.
425
 *
426
 *  The jc loop (6th) will loop over each NC filters.
427
 *  The pc loop (5th) will loop over each KC elements in one window.
428
 *  The ic loop (4th) will loop over each MC windows of your image.
429
 *
430
 *  You probably don't need to change anything about the macro kernels we
431
 *  define here (3rd, 2nd loops), since in 4th loop you already transformed the
432
 *  problem into a GEMM operation.
433
 *
434
 *  (Third) finally you need to write two packing routines and one unpacking
435
 *  routine. Think about how to pack your image into packA and how to pack your
436
 *  filters into packB. Finally, you need to reshape your C back to the
437
 *  original tensor shape.
438
 *
439
 *  (Fourth) write a reference function cnn_ref and a test function
440
 *  /hmlp/test/test_cnn.cpp to compare your results.
441
 *
442
 *  Good luck and have fun!
443
 *
444
 *
445
 */
446
template<
447
  int MC, int NC, int KC, int MR, int NR,
448
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
449
  bool USE_STRASSEN,
450
  typename SEMIRINGKERNEL, typename MICROKERNEL,
451
  typename TA, typename TB, typename TC, typename TV>
452
void conv2d
453
(
454
  int w0, int h0, int d0, int s, int p,
455
  TA *B,
456
  int w1, int h1, int d1,
457
  TB *A,
458
  TC *C,
459
  SEMIRINGKERNEL semiringkernel,
460
  MICROKERNEL microkernel
461
)
462
{
463
  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
464
  int nc = NC, pack_nc = PACK_NC;
465
  char *str;
466
467
  int m = d1;
468
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
469
  int ny = ( h0 - h1 + 2 * p ) / s + 1;
470
  int n = nx * ny;
471
  int k = w1 * h1 * d0;
472
473
474
  //printf( "m %4d n %4d k %4d\n", m, n, k );
475
476
  TA *packA_buff = NULL;
477
  TB *packB_buff = NULL;
478
479
  // Early return if possible
480
481
  // Check the environment variable.
482
  if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 )
483
  {
484
    jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
485
    ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
486
    jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
487
  }
488
489
490
  if ( jc_nt > 1 )
491
  {
492
    nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
493
    //if ( nc > NC ) nc = NC;
494
    pack_nc = ( nc / NR ) * PACK_NR;
495
  }
496
497
  // allocate packing memory
498
  packA_buff  = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,         sizeof(TA) );
499
  packB_buff  = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,                 sizeof(TB) );
500
501
  //#pragma omp parallel for
502
  //for ( int i = 0; i < KC * ( PACK_MC + 1 ) * jc_nt * ic_nt; i ++ ) packA_buff[ i ] = 1.0;
503
504
505
  // allocate tree communicator
506
  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
507
508
509
  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
510
  {
511
    Worker thread( &my_comm );
512
513
    if ( USE_STRASSEN )
514
    {
515
      printf( "cnn: strassen algorithms haven't been implemented." );
516
      exit( 1 );
517
    }
518
519
    conv2d_internal
520
    <MC, NC, KC, MR, NR,
521
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
522
    USE_STRASSEN,
523
    SEMIRINGKERNEL, MICROKERNEL,
524
    TA, TB, TC, TB>
525
    (
526
      thread,
527
      w0, h0, d0, s, p,
528
      B,
529
      w1, h1, d1,
530
      A,
531
      C,
532
      semiringkernel, microkernel,
533
      nc, pack_nc,
534
      packA_buff,
535
      packB_buff
536
    );
537
  }                                                        // end omp
538
539
#ifdef DEBUG_CONV2D
540
  for ( int j = 0; j < ny; j ++ )
541
  {
542
    for ( int i = 0; i < nx; i ++ )
543
    {
544
      printf( "%5.2lf ", C[ j * nx + i ] );
545
    }
546
    printf( "\n" );
547
  }
548
#endif
549
550
};                                                         // end cnn
551
552
553
//template<
554
//  int MC, int NC, int KC, int MR, int NR,
555
//  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
556
//  bool USE_STRASSEN,
557
//  typename SEMIRINGKERNEL, typename MICROKERNEL,
558
//  typename TA, typename TB, typename TC, typename TV>
559
//void conv2d
560
//(
561
//  int w0, int h0, int d0,
562
//  TA *B,
563
//  int w1, int h1, int d1,
564
//  TB *A,
565
//  TC *C,
566
//  SEMIRINGKERNEL semiringkernel,
567
//  MICROKERNEL microkernel
568
//)
569
//{
570
//  // Deciding s and p given the output size is also (w0, h0).
571
//  // w0 = ( w0 - w1 + 2 * p ) / s + 1
572
//  // h0 = ( h0 - h1 + 2 * p ) / s + 1
573
//  // if s = 1, then p = ( w1 - 1 ) / 2
574
//  //                p = ( h1 - 1 ) / 2
575
//  // that is w1 and h1 must be odd.
576
//
577
//  assert( w1 == h1 );
578
//
579
//  conv2d
580
//  <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
581
//  USE_STRASSEN,
582
//  SEMIRINGKERNEL, MICROKERNEL,
583
//  TA, TB, TC, TV>
584
//  (
585
//    w0, h0, d0, 1, ( w1 - 1 ) / 2,
586
//    B,
587
//    w1, h1, d1,
588
//    A,
589
//    C,
590
//    semiringkernel,
591
//    microkernel
592
//  );
593
//};
594
595
/**
 *  @brief Batched 2D convolution: applies the single-image conv2d to each of
 *         batchSize images in an OpenMP parallel loop.
 *
 *  Image b is read from B + b * ( w0 * h0 * d0 ) and its result is written to
 *  C + b * ( nx * ny * d1 ), where nx = ( w0 - w1 + 2 * p ) / s + 1 and
 *  ny = ( h0 - h1 + 2 * p ) / s + 1. The same filters A are used for every
 *  image. The per-image offsets are computed in 64-bit arithmetic so that
 *  large batches do not overflow int.
 *
 *  The function asserts w1 == h1.
 *
 *  @param w0, h0, d0      image width, height and depth
 *  @param s, p            values used as stride and padding in the formulas above
 *  @param batchSize       number of images stored back to back in B
 *  @param B               input images
 *  @param w1, h1, d1      filter width, filter height and number of filters
 *  @param A               filters
 *  @param C               outputs, stored back to back per image
 *  @param semiringkernel  kernel forwarded to the single-image conv2d
 *  @param microkernel     kernel forwarded to the single-image conv2d
 */
template<
  int MC, int NC, int KC, int MR, int NR,
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_STRASSEN,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void conv2d
(
  int w0, int h0, int d0, int s, int p, int batchSize,
  TA *B,
  int w1, int h1, int d1,
  TB *A,
  TC *C,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  const int nx = ( w0 - w1 + 2 * p ) / s + 1;
  const int ny = ( h0 - h1 + 2 * p ) / s + 1;

  assert( w1 == h1 );

  // Number of elements per image and per output, widened before multiplying.
  const long long image_stride  = static_cast<long long>( w0 ) * h0 * d0;
  const long long output_stride = static_cast<long long>( nx ) * ny * d1;

  #pragma omp parallel for
  for ( int b = 0; b < batchSize; b ++ )
  {
    conv2d
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    USE_STRASSEN,
    SEMIRINGKERNEL, MICROKERNEL,
    TA, TB, TC, TV>
    (
      w0, h0, d0, s, p,
      B + b * image_stride,
      w1, h1, d1,
      A,
      C + b * output_stride,
      semiringkernel,
      microkernel
    );
  }
}                                                          // end batched conv2d
644
645
646
/**
647
 *  @CHENHAN: write a reference function using GEMM. The signiture of xgemm can
648
 *  be found in hmlp_blas_lapack.h.
649
 */
650
template<typename T>
651
void conv2d_ref
652
(
653
  int w0, int h0, int d0, int s, int p,
654
  T *B,
655
  int w1, int h1, int d1,
656
  T *A,
657
  T *C
658
)
659
{
660
  int m = d1;
661
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
662
  int ny = ( h0 - h1 + 2 * p ) / s + 1;
663
  int n = nx * ny;
664
  int k = w1 * h1 * d0;
665
666
  T *packA = A;
667
  T *packB = hmlp_malloc<16, T>( k, n, sizeof(T) );
668
669
  double beg = omp_get_wtime();
670
  im2col<T>
671
  (
672
    n, k,
673
    packB, B,
674
    w0, h0, d0, s, p,
675
    w1, h1
676
  );
677
  double im2col_t = omp_get_wtime() - beg;
678
  printf( "im2col( B ) %3.1Es\n", im2col_t ); fflush( stdout );
679
680
#ifdef DEBUG_CONV2D
681
  printf( "packB\n" );
682
  for ( int p = 0; p < k; p ++ )
683
  {
684
    for ( int j = 0; j < n; j ++ )
685
    {
686
      printf( "%5.2lf ", packB[ j * k + p ] );
687
    }
688
    printf( "\n" );
689
  }
690
#endif
691
692
693
#ifdef USE_BLAS
694
  xgemm
695
  (
696
    "T", "N",
697
    m, n, k,
698
    1.0, packA, k,
699
         packB, k,
700
    0.0,     C, m
701
  );
702
#else
703
  #pragma omp parallel for
704
  for ( int j = 0; j < n; j ++ )
705
  {
706
    for ( int i = 0; i < m; i ++ )
707
    {
708
      C[ j * m + i ] = 0.0;
709
      for ( int p = 0; p < k; p ++ )
710
      {
711
        C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
712
      }
713
    }
714
  }
715
#endif
716
}; // end void conv2d_ref
717
718
/**
 *  @brief Batched reference convolution: runs the single-image conv2d_ref on
 *         each of batchSize images in an OpenMP parallel loop.
 *
 *  Image b is read from B + b * ( w0 * h0 * d0 ) and its result is written to
 *  C + b * ( nx * ny * d1 ), where nx = ( w0 - w1 + 2 * p ) / s + 1 and
 *  ny = ( h0 - h1 + 2 * p ) / s + 1. The per-image offsets are computed in
 *  64-bit arithmetic so that large batches do not overflow int.
 *
 *  @param w0, h0, d0  image width, height and depth
 *  @param s, p        values used as stride and padding in the formulas above
 *  @param batchSize   number of images stored back to back in B
 *  @param B           input images
 *  @param w1, h1, d1  filter width, filter height and number of filters
 *  @param A           filters, shared by all images
 *  @param C           outputs, stored back to back per image
 */
template<typename T>
void conv2d_ref
(
  int w0, int h0, int d0, int s, int p, int batchSize,
  T *B,
  int w1, int h1, int d1,
  T *A,
  T *C
)
{
  const int nx = ( w0 - w1 + 2 * p ) / s + 1;
  const int ny = ( h0 - h1 + 2 * p ) / s + 1;

  // Number of elements per image and per output, widened before multiplying.
  const long long image_stride  = static_cast<long long>( w0 ) * h0 * d0;
  const long long output_stride = static_cast<long long>( nx ) * ny * d1;

  #pragma omp parallel for
  for ( int b = 0; b < batchSize; b ++ )
  {
    conv2d_ref<T>
    (
      w0, h0, d0, s, p,
      B + b * image_stride,
      w1, h1, d1,
      A,
      C + b * output_stride
    );
  }
} // end batched conv2d_ref
744
745
}; // end namespace cnn
746
}; // end namespace hmlp
747
748
#endif // define CNN_HPP