GCC Code Coverage Report
Directory: . Exec Total Coverage
File: frame/primitives/gsks.hpp Lines: 0 58 0.0 %
Date: 2019-01-14 Branches: 0 95 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
23
24
#ifndef GSKS_HXX
25
#define GSKS_HXX
26
27
#include <math.h>
28
#include <vector>
29
30
#include <hmlp.h>
31
#include <hmlp_internal.hpp>
32
#include <hmlp_base.hpp>
33
34
#include <KernelMatrix.hpp>
35
36
37
namespace hmlp
38
{
39
namespace gsks
40
{
41
42
#define min( i, j ) ( (i)<(j) ? (i): (j) )
43
#define KS_RHS 1
44
45
/**
46
 *
47
 */
48
template<
49
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
50
  typename SEMIRINGKERNEL,
51
  typename TA, typename TB, typename TC, typename TV>
52
void rank_k_macro_kernel
53
(
54
  Worker &thread,
55
  int ic, int jc, int pc,
56
  int  m, int n,  int  k,
57
  TA *packA,
58
  TB *packB,
59
  TV *packC, int ldc,
60
  SEMIRINGKERNEL semiringkernel
61
)
62
{
63
  thread_communicator &ic_comm = *thread.ic_comm;
64
65
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
66
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
67
  auto loop2nd = GetRange( 0, m,      MR );
68
  auto pack2nd = GetRange( 0, m, PACK_MR );
69
70
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
71
            j   < loop3rd.end();
72
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
73
  {
74
    struct aux_s<TA, TB, TC, TV> aux;
75
    aux.pc       = pc;
76
    aux.b_next   = packB;
77
    aux.do_packC = 1;
78
    aux.jb       = min( n - j, NR );
79
80
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
81
              i  < loop2nd.end();
82
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
83
    {
84
      aux.ib = min( m - i, MR );
85
      if ( i + MR >= m )
86
      {
87
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
88
      }
89
      semiringkernel
90
      (
91
        k,
92
        &packA[ ip * k ],
93
        &packB[ jp * k ],
94
        //&packC[ j * ldc + i * NR ], ldc,
95
        &packC[ j * ldc + i * NR ], 1, MR,
96
        &aux
97
      );
98
    }                                                      // end 2nd loop
99
  }                                                        // end 3rd loop
100
}                                                          // end rank_k_macro_kernel
101
102
/**
103
 *
104
 */
105
template<
106
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
107
  typename MICROKERNEL,
108
  typename TA, typename TB, typename TC, typename TV>
109
void fused_macro_kernel
110
(
111
  //ks_t *kernel,
112
  kernel_s<TV, TC> *kernel,
113
  Worker &thread,
114
  int ic, int jc, int pc,
115
  int  m,  int n,  int k,
116
  TC *packu,
117
  TA *packA, TA *packA2, TV *packAh,
118
  TB *packB, TB *packB2, TV *packBh,
119
  TC *packw,
120
  TV *packC, int ldc,
121
  MICROKERNEL microkernel
122
)
123
{
124
  thread_communicator &ic_comm = *thread.ic_comm;
125
126
  auto loop3rd = GetRange( 0, n,      NR, thread.jr_id, ic_comm.GetNumThreads() );
127
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
128
  auto loop2nd = GetRange( 0, m,      MR );
129
  auto pack2nd = GetRange( 0, m, PACK_MR );
130
131
  for ( int j   = loop3rd.beg(), jp  = pack3rd.beg();
132
            j   < loop3rd.end();
133
            j  += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
134
  {
135
    struct aux_s<TA, TB, TC, TV> aux;
136
    aux.pc       = pc;
137
    aux.b_next   = packB;
138
    aux.do_packC = 1;
139
    aux.jb       = min( n - j, NR );
140
141
    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg();
142
              i  < loop2nd.end();
143
              i += loop2nd.inc(), ip += pack2nd.inc() )    // beg 2nd loop
144
    {
145
      aux.ib = min( m - i, MR );
146
      if ( i + MR >= m )
147
      {
148
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
149
      }
150
      aux.hi = packAh + ip;
151
      aux.hj = packBh + jp;
152
      microkernel
153
      (
154
        kernel,
155
        k,
156
        KS_RHS,
157
        packu  + ip * KS_RHS,
158
        packA  + ip * k,
159
        packA2 + ip,
160
        packB  + jp * k,
161
        packB2 + jp,
162
        packw  + jp * KS_RHS,
163
        packC  + j * ldc + i * NR, MR,                     // packed
164
        &aux
165
      );
166
    }                                                      // end 2nd loop
167
  }                                                        // end 3rd loop
168
}                                                          // end fused_macro_kernel
169
170
171
/**
172
 *
173
 */
174
/**
 *  @brief  Thread body of the blocked GSKS driver.  Implements the GotoBLAS-style
 *          6/5/4th loop nest: jc partitions columns (n) across jc threads,
 *          pc partitions the contraction dimension (k) in KC panels, and ic
 *          partitions rows (m) across ic threads.  A/B columns are gathered
 *          through amap/bmap while packing.  On the last pc panel the fused
 *          macro kernel also packs w, the squared norms (A2/B2) and, for
 *          variable-bandwidth kernels, per-point bandwidths (kernel->hi/hj),
 *          then accumulates results from packu into u under an atomic update.
 *
 *  All pack* pointers are shared buffers; each thread first offsets them by
 *  its (jc_id, ic_id, jr_id) coordinates so threads write disjoint regions.
 */
template<
  int MC, int NC, int KC, int MR, int NR,
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_L2NORM, bool USE_VAR_BANDWIDTH, bool USE_STRASSEN,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void gsks_internal
(
  Worker &thread,
  //ks_t *kernel,
  kernel_s<TV, TC> *kernel,
  int m, int n, int k,
  TC *u,         int *umap,
  TA *A, TA *A2, int *amap,
  TB *B, TB *B2, int *bmap,
  TC *w,         int *wmap,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel,
  int nc, int pack_nc,
  TC *packu,
  TA *packA, TA *packA2, TA *packAh,
  TB *packB, TB *packB2, TB *packBh,
  TC *packw,
  TV *packC, int ldpackc, int padn
)
{
  /** Offset each shared packing buffer to this thread's private region. */
  packu  += ( thread.jc_id * thread.ic_nt * thread.jr_nt ) * PACK_MC * KS_RHS
          + ( thread.ic_id * thread.jr_nt + thread.jr_id ) * PACK_MC * KS_RHS;
  packA  += ( thread.jc_id * thread.ic_nt                ) * PACK_MC * KC
          + ( thread.ic_id                               ) * PACK_MC * KC;
  packA2 += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
  packAh += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
  packB  += ( thread.jc_id                               ) * pack_nc * KC;
  packB2 += ( thread.jc_id                               ) * pack_nc;
  packBh += ( thread.jc_id                               ) * pack_nc;
  packw  += ( thread.jc_id                               ) * pack_nc;
  packC  += ( thread.jc_id                               ) * ldpackc * padn;

  /** 6th loop: columns over jc threads; 5th: k panels; 4th: rows over ic threads. */
  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
  auto loop5th = GetRange( 0, k, KC );
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );

  for ( int jc  = loop6th.beg();
            jc  < loop6th.end();
            jc += loop6th.inc() )                          // beg 6th loop
  {
    auto &jc_comm = *thread.jc_comm;                       // NOTE(review): bound but unused here
    auto jb = min( n - jc, nc );                           // edge column block

    for ( int pc  = loop5th.beg();
              pc  < loop5th.end();
              pc += loop5th.inc() )
    {
      auto &pc_comm = *thread.pc_comm;
      auto pb = min( k - pc, KC );                         // edge k panel
      /** Only the final k panel applies the kernel function and updates u. */
      auto is_the_last_pc_iteration = ( pc + KC >= k );

      /** Cooperative packing of B: pc siblings split the NR column panels. */
      auto looppkB = GetRange( 0, jb,      NR, thread.ic_jr, pc_comm.GetNumThreads() );
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );

      for ( int j   = looppkB.beg(), jp  = packpkB.beg();
                j   < looppkB.end();
                j  += looppkB.inc(), jp += packpkB.inc() )
      {
        pack2D<true, PACK_NR>                              // packB
        (
          min( jb - j, NR ), pb,
          &B[ pc ], k, &bmap[ jc + j ], &packB[ jp * pb ]
        );


        if ( is_the_last_pc_iteration )
        {
          pack2D<true, PACK_NR, true>                      // packw
          (
            min( jb - j, NR ), 1,
            &w[ 0 ], 1, &wmap[ jc + j ], &packw[ jp * 1 ]
          );

          if ( USE_L2NORM )
          {
            pack2D<true, PACK_NR>                          // packB2 (squared norms)
            (
              min( jb - j, NR ), 1,
              &B2[ 0 ], 1, &bmap[ jc + j ], &packB2[ jp * 1 ]
            );
          }

          if ( USE_VAR_BANDWIDTH )
          {
            pack2D<true, PACK_NR>                          // packBh (per-point bandwidths)
            (
              min( jb - j, NR ), 1,
              kernel->hj, 1, &bmap[ jc + j ], &packBh[ jp * 1 ]
            );
          }
        }
      }
      pc_comm.Barrier();                                   // B panel fully packed

      for ( int ic  = loop4th.beg();
                ic  < loop4th.end();
                ic += loop4th.inc() )                      // beg 4th loop
      {
        auto &ic_comm = *thread.ic_comm;
        auto ib = min( m - ic, MC );                       // edge row block

        /** Cooperative packing of A: jr siblings split the MR row panels. */
        auto looppkA = GetRange( 0, ib,      MR, thread.jr_id, thread.jr_nt );
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );

        for ( int i   = looppkA.beg(), ip  = packpkA.beg();
                  i   < looppkA.end();
                  i  += looppkA.inc(), ip += packpkA.inc() )
        {
          pack2D<true, PACK_MR>                            // packA
          (
            min( ib - i, MR ), pb,
            &A[ pc ], k, &amap[ ic + i ], &packA[ ip * pb ]
          );

          if ( is_the_last_pc_iteration )
          {
            if ( USE_L2NORM )
            {
              pack2D<true, PACK_MR>                        // packA2 (squared norms)
              (
                min( ib - i, MR ), 1,
                &A2[ 0 ], 1, &amap[ ic + i ], &packA2[ ip * 1 ]
              );
            }

            if ( USE_VAR_BANDWIDTH )                       // variable bandwidths
            {
              pack2D<true, PACK_MR>                        // packAh
              (
                min( ib - i, MR ), 1,
                kernel->hi, 1, &amap[ ic + i ], &packAh[ ip * 1 ]
              );
            }
          }
        }

        if ( is_the_last_pc_iteration )                    // Initialize packu to zeros.
        {
          for ( auto i = 0, ip = 0; i < ib; i += MR, ip += PACK_MR )
          {
            for ( auto ir = 0; ir < min( ib - i, MR ); ir ++ )
            {
              packu[ ip + ir ] = 0.0;
            }
          }
        }
        ic_comm.Barrier();                                 // A packed, packu zeroed

        if ( is_the_last_pc_iteration )                    // fused_macro_kernel
        {
          fused_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
          (
            kernel,
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packu,
            packA, packA2, packAh,
            packB, packB2, packBh,
            packw,
            packC + ic * padn,                             // packed
            ( ( ib - 1 ) / MR + 1 ) * MR,                  // packed ldc (MR-rounded)
            microkernel
          );
        }
        else                                               // semiring rank-k update
        {
          rank_k_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
          (
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            packC + ic * padn,                             // packed
            ( ( ib - 1 ) / MR + 1 ) * MR,                  // packed ldc (MR-rounded)
            semiringkernel
          );
        }
        ic_comm.Barrier();                                 // sync all jr_id!!

        if ( is_the_last_pc_iteration )
        {
          /** Scatter packu back to u through umap; u may be shared across
           *  jc threads, hence the atomic accumulation. */
          for ( auto i = 0, ip = 0; i < ib; i += MR, ip += PACK_MR )
          {
            for ( auto ir = 0; ir < min( ib - i, MR ); ir ++ )
            {
              TC *uptr = &( u[ umap[ ic + i + ir ] ] );
              #pragma omp atomic update                    // concurrent write
              *uptr += packu[ ip + ir ];
            }
          }
          ic_comm.Barrier();                               // sync all jr_id!!
        }
      }                                                    // end 4th loop
      pc_comm.Barrier();
    }                                                      // end 5th loop
  }                                                        // end 6th loop
}                                                          // end gsks_internal
382
383
384
385
386
387
/**
388
 *
389
 */
390
template<
391
  int MC, int NC, int KC, int MR, int NR,
392
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
393
  bool USE_L2NORM, bool USE_VAR_BANDWIDTH, bool USE_STRASSEN,
394
  typename SEMIRINGKERNEL, typename MICROKERNEL,
395
  typename TA, typename TB, typename TC, typename TV>
396
void gsks
397
(
398
  kernel_s<TV, TC> *kernel,
399
  int m, int n, int k,
400
  TC *u,         int *umap,
401
  TA *A, TA *A2, int *amap,
402
  TB *B, TB *B2, int *bmap,
403
  TC *w,         int *wmap,
404
  SEMIRINGKERNEL semiringkernel,
405
  MICROKERNEL microkernel
406
)
407
{
408
  int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1;
409
  int ldpackc = 0, padn = 0, nc = NC, pack_nc = PACK_NC;
410
  char *str;
411
412
  TC *packu_buff = NULL;
413
  TA *packA_buff = NULL, *packA2_buff = NULL, *packAh_buff = NULL;
414
  TB *packB_buff = NULL, *packB2_buff = NULL, *packBh_buff = NULL;
415
  TC *packw_buff = NULL;
416
  TV *packC_buff = NULL;
417
418
  // Early return if possible
419
  if ( m == 0 || n == 0 || k == 0 ) return;
420
421
  // Check the environment variable.
422
  jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" );
423
  ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" );
424
  jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" );
425
426
  if ( jc_nt > 1 )
427
  {
428
    nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR;
429
    pack_nc = ( nc / NR ) * PACK_NR;
430
  }
431
432
  // allocate packing memory
433
  {
434
    packA_buff  = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt,         sizeof(TA) );
435
    packB_buff  = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt,                 sizeof(TB) );
436
    packu_buff  = hmlp_malloc<ALIGN_SIZE, TC>(  1, ( PACK_MC + 1 ) * jc_nt * ic_nt * jr_nt, sizeof(TC) );
437
    packw_buff  = hmlp_malloc<ALIGN_SIZE, TC>(  1, ( pack_nc + 1 ) * jc_nt,                 sizeof(TC) );
438
  }
439
440
  // allocate extra packing buffer
441
  if ( USE_L2NORM )
442
  {
443
    packA2_buff = hmlp_malloc<ALIGN_SIZE, TA>(  1, ( PACK_MC + 1 ) * jc_nt * ic_nt,         sizeof(TA) );
444
    packB2_buff = hmlp_malloc<ALIGN_SIZE, TB>(  1, ( pack_nc + 1 ) * jc_nt,                 sizeof(TB) );
445
  }
446
447
  if ( USE_VAR_BANDWIDTH )
448
  {
449
    packAh_buff = hmlp_malloc<ALIGN_SIZE, TA>(  1, ( PACK_MC + 1 ) * jc_nt * ic_nt,         sizeof(TA) );
450
    packBh_buff = hmlp_malloc<ALIGN_SIZE, TB>(  1, ( pack_nc + 1 ) * jc_nt,                 sizeof(TB) );
451
  }
452
453
  // Temporary bufferm <TV> to store the semi-ring rank-k update
454
  if ( k > KC )
455
  {
456
    ldpackc  = ( ( m - 1 ) / PACK_MR + 1 ) * PACK_MR;
457
    padn = pack_nc;
458
    if ( n < nc ) padn = ( ( n - 1 ) / PACK_NR + 1 ) * PACK_NR ;
459
    packC_buff = hmlp_malloc<ALIGN_SIZE, TV>( ldpackc, padn * jc_nt, sizeof(TV) );
460
  }
461
462
  // allocate tree communicator
463
  thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt );
464
465
466
  #pragma omp parallel num_threads( my_comm.GetNumThreads() )
467
  {
468
    Worker thread( &my_comm );
469
470
    if ( USE_STRASSEN )
471
    {
472
      printf( "gsks: strassen algorithms haven't been implemented." );
473
      exit( 1 );
474
    }
475
476
    gsks_internal
477
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
478
    USE_L2NORM, USE_VAR_BANDWIDTH, USE_STRASSEN,
479
    SEMIRINGKERNEL, MICROKERNEL,
480
    TA, TB, TC, TB>
481
    (
482
      thread,
483
      kernel,
484
      m, n, k,
485
      u,     umap,
486
      A, A2, amap,
487
      B, B2, bmap,
488
      w,     wmap,
489
      semiringkernel, microkernel,
490
      nc, pack_nc,
491
      packu_buff,
492
      packA_buff, packA2_buff, packAh_buff,
493
      packB_buff, packB2_buff, packBh_buff,
494
      packw_buff,
495
      packC_buff, ldpackc, padn
496
    );
497
498
  } /** end omp region */
499
500
  hmlp_free( packA_buff );
501
  hmlp_free( packB_buff );
502
  hmlp_free( packu_buff );
503
  hmlp_free( packw_buff );
504
  if ( USE_L2NORM )
505
  {
506
    hmlp_free( packA2_buff );
507
    hmlp_free( packB2_buff );
508
  }
509
} /** end gsks() */
510
511
512
/**
513
 *
514
 */
515
/**
 *  @brief  Reference (unblocked) implementation of GSKS used for testing:
 *          u( umap ) += K( A( :, amap ), B( :, bmap ) ) * w( wmap ).
 *          Gathers A/B/u/w through the index maps into dense buffers, forms
 *          C = -2 A^T B, finishes the squared distances with the A2/B2 norms,
 *          applies the Gaussian kernel elementwise, then multiplies by the
 *          packed weights and scatters back into u.
 *
 *  Supports GAUSSIAN and GAUSSIAN_VAR_BANDWIDTH; exits on any other type.
 */
template<typename T>
void gsks_ref
(
  //ks_t *kernel,
  kernel_s<T, T> *kernel,
  int m, int n, int k,
  T *u,        int *umap,
  T *A, T *A2, int *amap,
  T *B, T *B2, int *bmap,
  T *w,        int *wmap
)
{
  int nrhs = KS_RHS;                                       // number of right-hand sides
  T rank_k_scale, fone = 1.0, fzero = 0.0;
  std::vector<T> packA, packB, C, packu, packw;

  // Early return if possible
  if ( m == 0 || n == 0 || k == 0 ) return;

  packA.resize( k * m );
  packB.resize( k * n );
  C.resize( m * n );
  packu.resize( m );
  packw.resize( n );

  /** Both Gaussian variants use -2 for the cross term of ||a-b||^2. */
  switch ( kernel->type )
  {
    case GAUSSIAN:
      rank_k_scale = -2.0;
      break;
    case GAUSSIAN_VAR_BANDWIDTH:
      rank_k_scale = -2.0;
      break;
    default:                                               // unsupported kernel type
      exit( 1 );
  }

  /*
   *  Collect packA and packu (gather through amap/umap)
   */
  #pragma omp parallel for
  for ( int i = 0; i < m; i ++ )
  {
    for ( int p = 0; p < k; p ++ )
    {
      packA[ i * k + p ] = A[ amap[ i ] * k + p ];
    }
    for ( int p = 0; p < KS_RHS; p ++ )
    {
      packu[ p * m + i ] = u[ umap[ i ] * KS_RHS + p ];
    }
  }

  /*
   *  Collect packB and packw (gather through bmap/wmap)
   */
  #pragma omp parallel for
  for ( int j = 0; j < n; j ++ )
  {
    for ( int p = 0; p < k; p ++ )
    {
      packB[ j * k + p ] = B[ bmap[ j ] * k + p ];
    }
    for ( int p = 0; p < KS_RHS; p ++ )
    {
      packw[ p * n + j ] = w[ wmap[ j ] * KS_RHS + p ];
    }
  }

  /*
   *  C = -2.0 * A^T * B (GEMM)
   */
#ifdef USE_BLAS
  xgemm
  (
    "T", "N",
    m, n, k,
    rank_k_scale, packA.data(), k,
                  packB.data(), k,
    fzero,        C.data(),     m
  );
#else
  /** Fallback: naive triple loop, then scale. */
  #pragma omp parallel for
  for ( int j = 0; j < n; j ++ )
  {
    for ( int i = 0; i < m; i ++ )
    {
      C[ j * m + i ] = 0.0;
      for ( int p = 0; p < k; p ++ )
      {
        C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ];
      }
    }
  }
  #pragma omp parallel for
  for ( int j = 0; j < n; j ++ )
  {
    for ( int i = 0; i < m; i ++ )
    {
      C[ j * m + i ] *= rank_k_scale;
    }
  }
#endif

  /** Finish ||a-b||^2 = ||a||^2 - 2 a.b + ||b||^2 and apply exp elementwise. */
  switch ( kernel->type )
  {
    case GAUSSIAN:
      {
        #pragma omp parallel for
        for ( int j = 0; j < n; j ++ )
        {
          for ( int i = 0; i < m; i ++ )
          {
            C[ j * m + i ] += A2[ amap[ i ] ];
            C[ j * m + i ] += B2[ bmap[ j ] ];
            C[ j * m + i ] *= kernel->scal;                // e.g. -1/(2h^2)
          }
          for ( int i = 0; i < m; i ++ )
          {
            C[ j * m + i ] = exp( C[ j * m + i ] );
          }
        }
        break;
      }
    case GAUSSIAN_VAR_BANDWIDTH:
      {
        /** Per-pair scaling: -0.5 * hi[i] * hj[j] * ||a-b||^2. */
        #pragma omp parallel for
        for ( int j = 0; j < n; j ++ )
        {
          for ( int i = 0; i < m; i ++ )
          {
            C[ j * m + i ] += A2[ amap[ i ] ];
            C[ j * m + i ] += B2[ bmap[ j ] ];
            C[ j * m + i ] *= -0.5;
            C[ j * m + i ] *= kernel->hi[ i ];
            C[ j * m + i ] *= kernel->hj[ j ];
          }
          for ( int i = 0; i < m; i ++ )
          {
            C[ j * m + i ] = exp( C[ j * m + i ] );
          }
        }
        break;
      }
    default:                                               // unreachable: checked above
      exit( 1 );
  }

  /*
   *  Kernel Summation: packu += C * packw
   */
#ifdef USE_BLAS
  xgemm
  (
    "N", "N",
    m, nrhs, n,
    fone, C.data(),     m,
          packw.data(), n,
    fone, packu.data(), m
  );
#else
  #pragma omp parallel for
  for ( int i = 0; i < m; i ++ )
  {
    for ( int j = 0; j < nrhs; j ++ )
    {
      for ( int p = 0; p < n; p ++ )
      {
        packu[ j * m + i ] += C[ p * m + i ] * packw[ j * n + p ];
      }
    }
  }
#endif

  /*
   *  Assemble packu back (scatter through umap)
   */
  #pragma omp parallel for
  for ( int i = 0; i < m; i ++ )
  {
    for ( int p = 0; p < KS_RHS; p ++ )
    {
      u[ umap[ i ] * KS_RHS + p ] = packu[ p * m + i ];
    }
  }

} // end void gsks_ref
702
703
704
}; /** end namespace gsks */
705
}; /** end namespace hmlp */
706
707
#endif // define GSKS_HXX