HMLP: High-performance Machine Learning Primitives
variable_bandwidth_gaussian_d8x4.hpp
#include <stdio.h>
#include <math.h>
#include <immintrin.h> // AVX

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <avx_type.h> // self-defined vector type

// #define DEBUG_MICRO 1

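// Both micro-kernels in this file evaluate an 8x4 block of the
// variable-bandwidth Gaussian kernel
//
//   K( x_i, x_j ) = exp( -0.5 * hi[ i ] * hj[ j ] * || x_i - x_j ||^2 ),
//
// using the expansion
//
//   || x_i - x_j ||^2 = || x_i ||^2 + || x_j ||^2 - 2 * ( x_i' * x_j ),
//
// so the inner products come from a rank-k update while the squared norms
// arrive precomputed ( a2/b2 below, aa/bb in the intrinsics kernel ). The
// block is then combined with the weights w into the potentials u += K * w.
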
struct variable_bandwidth_gaussian_ref_d8x4
{
  inline void operator()
  (
    //ks_t *kernel,
    kernel_s<double> *kernel,
    int k,
    int nrhs,
    double *u,
    double *a, double *a2,
    double *b, double *b2,
    double *w,
    double *c, int ldc,
    aux_s<double, double, double, double> *aux
  ) const
  {
    double c_reg[ 8 * 4 ] = { 0.0 };

    for ( int p = 0; p < k; p ++ )
    {
      #pragma unroll
      for ( int j = 0; j < 4; j ++ )
      {
        #pragma unroll
        for ( int i = 0; i < 8; i ++ )
        {
          c_reg[ j * 8 + i ] += a[ p * 8 + i ] * b[ p * 4 + j ];
        }
      }
    }

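    // If this is not the first rank-k panel ( aux->pc != 0 ), c holds the
    // partial inner products of the previous panels; accumulate them.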
    if ( aux->pc )
    {
      #pragma unroll
      for ( int j = 0; j < 4; j ++ )
      {
        #pragma unroll
        for ( int i = 0; i < 8; i ++ )
        {
          c_reg[ j * 8 + i ] += c[ j * ldc + i ];
        }
      }
    }

#ifdef DEBUG_MICRO
    printf( "variable_bandwidth_gaussian_ref_d8x4: c_reg\n" );
    for ( int i = 0; i < 8; i ++ )
    {
      for ( int j = 0; j < 4; j ++ )
      {
        //printf( "%E (%E) ", c_reg[ j * 8 + i ], c[ j * 8 + i ] );
        printf( "%E ", c_reg[ j * 8 + i ] );
      }
      printf( "\n" );
    }
#endif

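    // Assemble and exponentiate the kernel entries: -2 * ( x_i' * x_j )
    // + a2[ i ] + b2[ j ] is the squared distance, which is scaled by
    // -0.5 and by the two bandwidths before calling exp().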
    #pragma unroll
    for ( int j = 0; j < 4; j ++ )
    {
      #pragma unroll
      for ( int i = 0; i < 8; i ++ )
      {
        c_reg[ j * 8 + i ] *= -2.0;
        c_reg[ j * 8 + i ] += a2[ i ] + b2[ j ];
        c_reg[ j * 8 + i ] *= -0.5;
        c_reg[ j * 8 + i ] *= aux->hi[ i ];
        c_reg[ j * 8 + i ] *= aux->hj[ j ];
        c_reg[ j * 8 + i ] = exp( c_reg[ j * 8 + i ] );
      }
    }

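    // Weighted kernel summation: u[ i ] += K( x_i, x_j ) * w[ j ].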
    #pragma unroll
    for ( int j = 0; j < 4; j ++ )
    {
      #pragma unroll
      for ( int i = 0; i < 8; i ++ )
      {
        u[ i ] += c_reg[ j * 8 + i ] * w[ j ];
      }
    }

  }; // end inline void operator()
}; // end struct variable_bandwidth_gaussian_ref_d8x4


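// A minimal usage sketch ( hypothetical, not part of HMLP ): the micro-kernel
// consumes packed panels a ( 8 x k ) and b ( 4 x k ), squared norms a2[ 8 ]
// and b2[ 4 ], and weights w, accumulating potentials into u. It assumes
// aux_s exposes assignable hi/hj bandwidth pointers and a pc panel index.
//
//   kernel_s<double> kernel;
//   aux_s<double, double, double, double> aux;
//   aux.pc = 0;               // first k-panel: nothing in c to accumulate
//   aux.hi = hi;              // 8 per-row bandwidths
//   aux.hj = hj;              // 4 per-column bandwidths
//   variable_bandwidth_gaussian_ref_d8x4 microkernel;
//   microkernel( &kernel, k, 1, u, a, a2, b, b2, w, c, ldc, &aux );
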
struct variable_bandwidth_gaussian_int_d8x4
{
  inline void operator()
  (
    //ks_t *ker,
    kernel_s<double> *ker,
    int k,
    int rhs,
    double *u,
    double *a, double *aa,
    double *b, double *bb,
    double *w,
    double *c, int ldc,
    aux_s<double, double, double, double> *aux
  ) const
  {
    int i, rhs_left;
    double neg2 = -2.0;
    double neghalf = -0.5;
    double dzero = 0.0;

    v4df_t c03_0, c03_1, c03_2, c03_3;
    v4df_t c47_0, c47_1, c47_2, c47_3;
    v4df_t tmpc03_0, tmpc03_1, tmpc03_2, tmpc03_3;
    v4df_t tmpc47_0, tmpc47_1, tmpc47_2, tmpc47_3;
    v4df_t u03, u47;
    v4df_t a03, a47, A03, A47; // prefetched A
    v4df_t b0, b1, b2, b3, B0; // prefetched B
    v4df_t c_tmp, aa_tmp, bb_tmp, w_tmp;

    // Rank-k update segment
    #include "component/rank_k_int_d8x4.hpp"

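    // The included component leaves the 8x4 rank-k inner products in the
    // registers c03_0..c03_3 ( rows 0-3 ) and c47_0..c47_3 ( rows 4-7 ).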
    // Prefetch the squared norms aa and bb.
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( aa ) );
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( bb ) );

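    // As in the reference kernel, accumulate the partial results of the
    // previous rank-k panels stored in c when aux->pc != 0.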
    if ( aux->pc )
    {
      tmpc03_0.v = _mm256_load_pd( (double*)( c     ) );
      tmpc47_0.v = _mm256_load_pd( (double*)( c + 4 ) );

      tmpc03_1.v = _mm256_load_pd( (double*)( c + 1 * ldc     ) );
      tmpc47_1.v = _mm256_load_pd( (double*)( c + 1 * ldc + 4 ) );

      tmpc03_2.v = _mm256_load_pd( (double*)( c + 2 * ldc     ) );
      tmpc47_2.v = _mm256_load_pd( (double*)( c + 2 * ldc + 4 ) );

      tmpc03_3.v = _mm256_load_pd( (double*)( c + 3 * ldc     ) );
      tmpc47_3.v = _mm256_load_pd( (double*)( c + 3 * ldc + 4 ) );

      c03_0.v = _mm256_add_pd( tmpc03_0.v, c03_0.v );
      c47_0.v = _mm256_add_pd( tmpc47_0.v, c47_0.v );

      c03_1.v = _mm256_add_pd( tmpc03_1.v, c03_1.v );
      c47_1.v = _mm256_add_pd( tmpc47_1.v, c47_1.v );

      c03_2.v = _mm256_add_pd( tmpc03_2.v, c03_2.v );
      c47_2.v = _mm256_add_pd( tmpc47_2.v, c47_2.v );

      c03_3.v = _mm256_add_pd( tmpc03_3.v, c03_3.v );
      c47_3.v = _mm256_add_pd( tmpc47_3.v, c47_3.v );
    }

    // Scale -2
    aa_tmp.v = _mm256_broadcast_sd( &neg2 );
    c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );
    c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );
    c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );
    c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );

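    // Add || x_i ||^2: aa holds the squared norms of the eight rows.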
    aa_tmp.v = _mm256_load_pd( (double*)aa );
    c03_0.v = _mm256_add_pd( aa_tmp.v, c03_0.v );
    c03_1.v = _mm256_add_pd( aa_tmp.v, c03_1.v );
    c03_2.v = _mm256_add_pd( aa_tmp.v, c03_2.v );
    c03_3.v = _mm256_add_pd( aa_tmp.v, c03_3.v );

    aa_tmp.v = _mm256_load_pd( (double*)( aa + 4 ) );
    c47_0.v = _mm256_add_pd( aa_tmp.v, c47_0.v );
    c47_1.v = _mm256_add_pd( aa_tmp.v, c47_1.v );
    c47_2.v = _mm256_add_pd( aa_tmp.v, c47_2.v );
    c47_3.v = _mm256_add_pd( aa_tmp.v, c47_3.v );

    // Prefetch u
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( u ) );

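    // Add || x_j ||^2: bb holds the squared norms of the four columns,
    // broadcast one column at a time.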
    bb_tmp.v = _mm256_broadcast_sd( (double*)bb );
    c03_0.v = _mm256_add_pd( bb_tmp.v, c03_0.v );
    c47_0.v = _mm256_add_pd( bb_tmp.v, c47_0.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 1 ) );
    c03_1.v = _mm256_add_pd( bb_tmp.v, c03_1.v );
    c47_1.v = _mm256_add_pd( bb_tmp.v, c47_1.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 2 ) );
    c03_2.v = _mm256_add_pd( bb_tmp.v, c03_2.v );
    c47_2.v = _mm256_add_pd( bb_tmp.v, c47_2.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 3 ) );
    c03_3.v = _mm256_add_pd( bb_tmp.v, c03_3.v );
    c47_3.v = _mm256_add_pd( bb_tmp.v, c47_3.v );

    // Clamp illegal values: the computed squared distances can dip below
    // zero due to floating-point rounding, so take the max with zero.
    c_tmp.v = _mm256_broadcast_sd( &dzero );
    c03_0.v = _mm256_max_pd( c_tmp.v, c03_0.v );
    c03_1.v = _mm256_max_pd( c_tmp.v, c03_1.v );
    c03_2.v = _mm256_max_pd( c_tmp.v, c03_2.v );
    c03_3.v = _mm256_max_pd( c_tmp.v, c03_3.v );
    c47_0.v = _mm256_max_pd( c_tmp.v, c47_0.v );
    c47_1.v = _mm256_max_pd( c_tmp.v, c47_1.v );
    c47_2.v = _mm256_max_pd( c_tmp.v, c47_2.v );
    c47_3.v = _mm256_max_pd( c_tmp.v, c47_3.v );

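    // c now holds the clamped squared distances; scale by -0.5 before
    // applying the bandwidths and exponentiating.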
    aa_tmp.v = _mm256_broadcast_sd( &neghalf );
    c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );
    c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );
    c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );
    c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );

    // Touch u early; u03/u47 are overwritten with hi below and reloaded
    // from u afterwards, so these loads mainly warm the cache.
    u03.v = _mm256_load_pd( (double*)u );
    u47.v = _mm256_load_pd( (double*)( u + 4 ) );

    // Scale columns with hj
    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 0 );
    c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );

    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 1 );
    c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );

    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 2 );
    c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );

    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 3 );
    c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );

    // Scale rows with hi ( u03/u47 are reused here to hold the bandwidths )
    u03.v = _mm256_load_pd( aux->hi + 0 );
    u47.v = _mm256_load_pd( aux->hi + 4 );

    c03_0.v = _mm256_mul_pd( u03.v, c03_0.v );
    c03_1.v = _mm256_mul_pd( u03.v, c03_1.v );
    c03_2.v = _mm256_mul_pd( u03.v, c03_2.v );
    c03_3.v = _mm256_mul_pd( u03.v, c03_3.v );
    c47_0.v = _mm256_mul_pd( u47.v, c47_0.v );
    c47_1.v = _mm256_mul_pd( u47.v, c47_1.v );
    c47_2.v = _mm256_mul_pd( u47.v, c47_2.v );
    c47_3.v = _mm256_mul_pd( u47.v, c47_3.v );

    // Preload u03, u47
    u03.v = _mm256_load_pd( (double*)u );
    u47.v = _mm256_load_pd( (double*)( u + 4 ) );

    // Prefetch u and w
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( u + 8 ) );
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( w ) );

    // c = exp( c );
    #include "component/exp_int_d8x4.hpp"

    // Multiple rhs kernel summation.
    #include "component/weighted_sum_int_d8x4.hpp"

  }; // end inline void operator()
}; // end struct variable_bandwidth_gaussian_int_d8x4