HMLP: High-performance Machine Learning Primitives
gsks_ref_mrxnr.hpp
1 #ifndef GSKS_REF_MRXNR_HPP
2 #define GSKS_REF_MRXNR_HPP
3 
4 #include <KernelMatrix.hpp>
5 
6 using namespace std;
7 using namespace hmlp;
8 
9 namespace hmlp
10 {
11 
12 template<int MR, int NR, typename T>
14 {
15  inline void operator()
16  (
17  kernel_s<T, T> *kernel,
18  int k,
19  int nrhs,
20  T *u,
21  T *a, T *a2,
22  T *b, T *b2,
23  T *w,
24  T *c, int ldc,
25  aux_s<T, T, T, T> *aux
26  ) const
27  {
29  T c_reg[ MR * NR ] = { 0.0 };
30 
32  for ( int p = 0; p < k; p ++ )
33  #pragma unroll
34  for ( int j = 0; j < NR; j ++ )
35  #pragma unroll
36  for ( int i = 0; i < MR; i ++ )
37  c_reg[ j * MR + i ] += a[ p * MR + i ] * b[ p * NR + j ];
38 
40  if ( aux->pc )
41  {
42  #pragma unroll
43  for ( int j = 0; j < NR; j ++ )
44  #pragma unroll
45  for ( int i = 0; i < MR; i ++ )
46  c_reg[ j * MR + i ] += c[ j * ldc + i ];
47  }
48 
49  switch ( kernel->type )
50  {
51  case GAUSSIAN:
52  {
53  #pragma unroll
54  for ( int j = 0; j < NR; j ++ )
55  {
56  #pragma unroll
57  for ( int i = 0; i < MR; i ++ )
58  {
59  c_reg[ j * MR + i ] *= -2.0;
60  c_reg[ j * MR + i ] += a2[ i ] + b2[ j ];
61  c_reg[ j * MR + i ] *= kernel->scal;
62  c_reg[ j * MR + i ] = std::exp( c_reg[ j * MR + i ] );
63  }
64  }
65  break;
66  }
67  default:
68  {
69  exit( 1 );
70  }
71  }
72 
74  #pragma unroll
75  for ( int j = 0; j < NR; j ++ )
76  #pragma unroll
77  for ( int i = 0; i < MR; i ++ )
78  u[ i ] += c_reg[ j * MR + i ] * w[ j ];
79 
80  };
81 };
83 };
85 #endif
Definition: gsks_ref_mrxnr.hpp:13
Definition: hmlp_internal.hpp:38
Definition: gofmm.hpp:83