HMLP: High-performance Machine Learning Primitives
knn_d8x4.hpp
1 #include <stdio.h>
2 #include <math.h>
3 
4 // AVX
5 #include <immintrin.h>
6 
7 #include <hmlp.h>
8 #include <hmlp_util.hpp>
9 #include <hmlp_internal.hpp>
10 #include <primitives/gsknn.hpp>
11 #include <avx_type.h> // self-defined vector type
12 
14 {
15  const size_t mr = 8;
16  const size_t nr = 4;
17  const size_t pack_mr = 8;
18  const size_t pack_nr = 4;
19  const size_t align_size = 32;
20  const bool row_major = true;
21 
22  //inline void operator()
23  //(
24  // int k,
25  // int r,
26  // double *a, double *aa,
27  // double *b, double *bb,
28  // double *c,
29  // aux_s<double, double, double, double> *aux,
30  // int *bmap
31  //) const
32  inline void operator()
33  (
34  int k,
35  int r,
36  double *a, double *aa,
37  double *b, double *bb, int *bmap,
38  double *c,
39  double *Keys, int *Values, int ldr,
41  ) const
42  {
43  int i, j; //ldr;
44  double *D = Keys;
45  int *I = Values;
46  //int *I = aux->I;
47  //double *D = aux->D;
48  // int ldc = aux->ldc;
49 
50  double neg2 = -2.0;
51  double dzero = 0.0;
52  v4df_t c03_0, c03_1, c03_2, c03_3;
53  v4df_t c47_0, c47_1, c47_2, c47_3;
54  v4df_t tmpc03_0, tmpc03_1, tmpc03_2, tmpc03_3;
55  v4df_t tmpc47_0, tmpc47_1, tmpc47_2, tmpc47_3;
56  v4df_t c_tmp;
57  v4df_t a03, a47;
58  v4df_t A03, A47; // prefetched A
59 
60  v4df_t b0, b1, b2, b3;
61  v4df_t B0; // prefetched B
62  v4df_t aa_tmp, bb_tmp;
63 
64  int k_iter = k / 2;
65  int k_left = k % 2;
66 
67  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( a ) );
68  __asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"( aux->b_next ) );
69  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( c + 3 ) );
70  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( c + 11 ) );
71  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( c + 19 ) );
72  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( c + 27 ) );
73 
74  c03_0.v = _mm256_setzero_pd();
75  c03_1.v = _mm256_setzero_pd();
76  c03_2.v = _mm256_setzero_pd();
77  c03_3.v = _mm256_setzero_pd();
78  c47_0.v = _mm256_setzero_pd();
79  c47_1.v = _mm256_setzero_pd();
80  c47_2.v = _mm256_setzero_pd();
81  c47_3.v = _mm256_setzero_pd();
82 
83 
84  // Load a03
85  a03.v = _mm256_load_pd( (double*)a );
86  // Load a47
87  a47.v = _mm256_load_pd( (double*)( a + 4 ) );
88  // Load ( b0, b1, b2, b3 )
89  b0.v = _mm256_load_pd( (double*)b );
90 
91  for ( i = 0; i < k_iter; ++i )
92  {
93  __asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) );
94 
95  // Preload A03
96  A03.v = _mm256_load_pd( (double*)( a + 8 ) );
97 
98  c_tmp.v = _mm256_mul_pd( a03.v , b0.v );
99  c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
100  c_tmp.v = _mm256_mul_pd( a47.v , b0.v );
101  c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );
102 
103  // Preload A47
104  A47.v = _mm256_load_pd( (double*)( a + 12 ) );
105 
106  // Shuffle b ( 1, 0, 3, 2 )
107  b1.v = _mm256_shuffle_pd( b0.v, b0.v, 0x5 );
108 
109  c_tmp.v = _mm256_mul_pd( a03.v , b1.v );
110  c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
111  c_tmp.v = _mm256_mul_pd( a47.v , b1.v );
112  c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );
113 
114  // Permute b ( 3, 2, 1, 0 )
115  b2.v = _mm256_permute2f128_pd( b1.v, b1.v, 0x1 );
116 
117  // Preload B0
118  B0.v = _mm256_load_pd( (double*)( b + 4 ) );
119 
120  c_tmp.v = _mm256_mul_pd( a03.v , b2.v );
121  c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
122  c_tmp.v = _mm256_mul_pd( a47.v , b2.v );
123  c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );
124 
125  // Shuffle b ( 3, 2, 1, 0 )
126  b3.v = _mm256_shuffle_pd( b2.v, b2.v, 0x5 );
127 
128  c_tmp.v = _mm256_mul_pd( a03.v , b3.v );
129  c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
130  c_tmp.v = _mm256_mul_pd( a47.v , b3.v );
131  c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );
132 
133 
134  // Iteration #1
135  __asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) );
136 
137  // Preload a03 ( next iteration )
138  a03.v = _mm256_load_pd( (double*)( a + 16 ) );
139 
140  c_tmp.v = _mm256_mul_pd( A03.v , B0.v );
141  c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
142 
143  b1.v = _mm256_shuffle_pd( B0.v, B0.v, 0x5 );
144 
145  c_tmp.v = _mm256_mul_pd( A47.v , B0.v );
146  c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );
147  c_tmp.v = _mm256_mul_pd( A03.v , b1.v );
148  c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
149 
150  // Preload a47 ( next iteration )
151  a47.v = _mm256_load_pd( (double*)( a + 20 ) );
152 
153  // Permute b ( 3, 2, 1, 0 )
154  b2.v = _mm256_permute2f128_pd( b1.v, b1.v, 0x1 );
155 
156  c_tmp.v = _mm256_mul_pd( A47.v , b1.v );
157  c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );
158  c_tmp.v = _mm256_mul_pd( A03.v , b2.v );
159  c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
160 
161  // Shuffle b ( 3, 2, 1, 0 )
162  b3.v = _mm256_shuffle_pd( b2.v, b2.v, 0x5 );
163 
164  c_tmp.v = _mm256_mul_pd( A47.v , b2.v );
165  c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );
166 
167  // Load b0 ( next iteration )
168  b0.v = _mm256_load_pd( (double*)( b + 8 ) );
169 
170  c_tmp.v = _mm256_mul_pd( A03.v , b3.v );
171  c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
172  c_tmp.v = _mm256_mul_pd( A47.v , b3.v );
173  c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );
174 
175  a += 16;
176  b += 8;
177  }
178 
179  for ( i = 0; i < k_left; ++i )
180  {
181  a03.v = _mm256_load_pd( (double*)a );
182 
183  a47.v = _mm256_load_pd( (double*)( a + 4 ) );
184 
185  b0.v = _mm256_load_pd( (double*)b );
186 
187  c_tmp.v = _mm256_mul_pd( a03.v , b0.v );
188  c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
189  c_tmp.v = _mm256_mul_pd( a47.v , b0.v );
190  c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );
191 
192  // Shuffle b ( 1, 0, 3, 2 )
193  b1.v = _mm256_shuffle_pd( b0.v, b0.v, 0x5 );
194 
195  c_tmp.v = _mm256_mul_pd( a03.v , b1.v );
196  c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
197  c_tmp.v = _mm256_mul_pd( a47.v , b1.v );
198  c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );
199 
200  // Permute b ( 3, 2, 1, 0 )
201  b2.v = _mm256_permute2f128_pd( b1.v, b1.v, 0x1 );
202 
203  c_tmp.v = _mm256_mul_pd( a03.v , b2.v );
204  c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
205  c_tmp.v = _mm256_mul_pd( a47.v , b2.v );
206  c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );
207 
208  // Shuffle b ( 3, 2, 1, 0 )
209  b3.v = _mm256_shuffle_pd( b2.v, b2.v, 0x5 );
210 
211  c_tmp.v = _mm256_mul_pd( a03.v , b3.v );
212  c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
213  c_tmp.v = _mm256_mul_pd( a47.v , b3.v );
214  c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );
215 
216  a += 8;
217  b += 4;
218  }
219 
220 
221  // Prefetch aa and bb
222  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( aa ) );
223  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( bb ) );
224 
225 
226  tmpc03_0.v = _mm256_blend_pd( c03_0.v, c03_1.v, 0x6 );
227  tmpc03_1.v = _mm256_blend_pd( c03_1.v, c03_0.v, 0x6 );
228 
229  tmpc03_2.v = _mm256_blend_pd( c03_2.v, c03_3.v, 0x6 );
230  tmpc03_3.v = _mm256_blend_pd( c03_3.v, c03_2.v, 0x6 );
231 
232  tmpc47_0.v = _mm256_blend_pd( c47_0.v, c47_1.v, 0x6 );
233  tmpc47_1.v = _mm256_blend_pd( c47_1.v, c47_0.v, 0x6 );
234 
235  tmpc47_2.v = _mm256_blend_pd( c47_2.v, c47_3.v, 0x6 );
236  tmpc47_3.v = _mm256_blend_pd( c47_3.v, c47_2.v, 0x6 );
237 
238  c03_0.v = _mm256_permute2f128_pd( tmpc03_0.v, tmpc03_2.v, 0x30 );
239  c03_3.v = _mm256_permute2f128_pd( tmpc03_2.v, tmpc03_0.v, 0x30 );
240 
241  c03_1.v = _mm256_permute2f128_pd( tmpc03_1.v, tmpc03_3.v, 0x30 );
242  c03_2.v = _mm256_permute2f128_pd( tmpc03_3.v, tmpc03_1.v, 0x30 );
243 
244  c47_0.v = _mm256_permute2f128_pd( tmpc47_0.v, tmpc47_2.v, 0x30 );
245  c47_3.v = _mm256_permute2f128_pd( tmpc47_2.v, tmpc47_0.v, 0x30 );
246 
247  c47_1.v = _mm256_permute2f128_pd( tmpc47_1.v, tmpc47_3.v, 0x30 );
248  c47_2.v = _mm256_permute2f128_pd( tmpc47_3.v, tmpc47_1.v, 0x30 );
249 
250 
251  if ( aux->pc )
252  {
253  c_tmp.v = _mm256_load_pd( c + 0 );
254  c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
255 
256  c_tmp.v = _mm256_load_pd( c + 4 );
257  c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );
258 
259  c_tmp.v = _mm256_load_pd( c + 8 );
260  c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
261 
262  c_tmp.v = _mm256_load_pd( c + 12 );
263  c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );
264 
265  c_tmp.v = _mm256_load_pd( c + 16 );
266  c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
267 
268  c_tmp.v = _mm256_load_pd( c + 20 );
269  c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );
270 
271  c_tmp.v = _mm256_load_pd( c + 24 );
272  c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
273 
274  c_tmp.v = _mm256_load_pd( c + 28 );
275  c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );
276  }
277 
278 
279  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( I ) );
280  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( D ) );
281 
282  aa_tmp.v = _mm256_broadcast_sd( &neg2 );
283 
284  c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
285  c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
286  c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
287  c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
288  c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );
289  c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );
290  c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );
291  c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );
292 
293 
294  aa_tmp.v = _mm256_load_pd( (double*)aa );
295  c03_0.v = _mm256_add_pd( aa_tmp.v, c03_0.v );
296  c03_1.v = _mm256_add_pd( aa_tmp.v, c03_1.v );
297  c03_2.v = _mm256_add_pd( aa_tmp.v, c03_2.v );
298  c03_3.v = _mm256_add_pd( aa_tmp.v, c03_3.v );
299 
300  aa_tmp.v = _mm256_load_pd( (double*)( aa + 4 ) );
301  c47_0.v = _mm256_add_pd( aa_tmp.v, c47_0.v );
302  c47_1.v = _mm256_add_pd( aa_tmp.v, c47_1.v );
303  c47_2.v = _mm256_add_pd( aa_tmp.v, c47_2.v );
304  c47_3.v = _mm256_add_pd( aa_tmp.v, c47_3.v );
305 
306 
307  bb_tmp.v = _mm256_broadcast_sd( (double*)bb );
308  c03_0.v = _mm256_add_pd( bb_tmp.v, c03_0.v );
309  c47_0.v = _mm256_add_pd( bb_tmp.v, c47_0.v );
310 
311  bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 1 ) );
312  c03_1.v = _mm256_add_pd( bb_tmp.v, c03_1.v );
313  c47_1.v = _mm256_add_pd( bb_tmp.v, c47_1.v );
314 
315  bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 2 ) );
316  c03_2.v = _mm256_add_pd( bb_tmp.v, c03_2.v );
317  c47_2.v = _mm256_add_pd( bb_tmp.v, c47_2.v );
318 
319  bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 3 ) );
320  c03_3.v = _mm256_add_pd( bb_tmp.v, c03_3.v );
321  c47_3.v = _mm256_add_pd( bb_tmp.v, c47_3.v );
322 
323 
324 
325  // Check if there is any illegle value
326  c_tmp.v = _mm256_broadcast_sd( &dzero );
327  c03_0.v = _mm256_max_pd( c_tmp.v, c03_0.v );
328  c03_1.v = _mm256_max_pd( c_tmp.v, c03_1.v );
329  c03_2.v = _mm256_max_pd( c_tmp.v, c03_2.v );
330  c03_3.v = _mm256_max_pd( c_tmp.v, c03_3.v );
331  c47_0.v = _mm256_max_pd( c_tmp.v, c47_0.v );
332  c47_1.v = _mm256_max_pd( c_tmp.v, c47_1.v );
333  c47_2.v = _mm256_max_pd( c_tmp.v, c47_2.v );
334  c47_3.v = _mm256_max_pd( c_tmp.v, c47_3.v );
335 
336 
337  // Transpose c03/c47 _0, _1, _2, _3 to be the row vector
338  tmpc03_0.v = _mm256_shuffle_pd( c03_0.v, c03_1.v, 0x0 );
339  tmpc03_1.v = _mm256_shuffle_pd( c03_0.v, c03_1.v, 0xF );
340 
341  tmpc03_2.v = _mm256_shuffle_pd( c03_2.v, c03_3.v, 0x0 );
342  tmpc03_3.v = _mm256_shuffle_pd( c03_2.v, c03_3.v, 0xF );
343 
344  tmpc47_0.v = _mm256_shuffle_pd( c47_0.v, c47_1.v, 0x0 );
345  tmpc47_1.v = _mm256_shuffle_pd( c47_0.v, c47_1.v, 0xF );
346 
347  tmpc47_2.v = _mm256_shuffle_pd( c47_2.v, c47_3.v, 0x0 );
348  tmpc47_3.v = _mm256_shuffle_pd( c47_2.v, c47_3.v, 0xF );
349 
350  c03_0.v = _mm256_permute2f128_pd( tmpc03_0.v, tmpc03_2.v, 0x20 );
351  c03_2.v = _mm256_permute2f128_pd( tmpc03_0.v, tmpc03_2.v, 0x31 );
352 
353  c03_1.v = _mm256_permute2f128_pd( tmpc03_1.v, tmpc03_3.v, 0x20 );
354  c03_3.v = _mm256_permute2f128_pd( tmpc03_1.v, tmpc03_3.v, 0x31 );
355 
356  c47_0.v = _mm256_permute2f128_pd( tmpc47_0.v, tmpc47_2.v, 0x20 );
357  c47_2.v = _mm256_permute2f128_pd( tmpc47_0.v, tmpc47_2.v, 0x31 );
358 
359  c47_1.v = _mm256_permute2f128_pd( tmpc47_1.v, tmpc47_3.v, 0x20 );
360  c47_3.v = _mm256_permute2f128_pd( tmpc47_1.v, tmpc47_3.v, 0x31 );
361 
362 
363  //ldr = aux->ldr;
364 
365 
366  aa_tmp.v = _mm256_broadcast_sd( D );
367  b0.v = _mm256_cmp_pd( c03_0.v, aa_tmp.v, 0x1 );
368  if ( !_mm256_testz_pd( b0.v, b0.v ) )
369  {
370  _mm256_store_pd( c , c03_0.v );
371  hmlp::heap_select<double>( aux->jb, r, c + 0, bmap, D + 0 * ldr, I + 0 * ldr );
372  }
373 
374  if ( aux->ib > 1 )
375  {
376  aa_tmp.v = _mm256_broadcast_sd( D + ldr );
377  b0.v = _mm256_cmp_pd( c03_1.v, aa_tmp.v, 0x1 );
378  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
379  _mm256_store_pd( c + 4, c03_1.v );
380  hmlp::heap_select<double>( aux->jb, r, c + 4, bmap, D + 1 * ldr, I + 1 * ldr );
381  }
382  }
383 
384  if ( aux->ib > 2 )
385  {
386  aa_tmp.v = _mm256_broadcast_sd( D + 2 * ldr );
387  b0.v = _mm256_cmp_pd( c03_2.v, aa_tmp.v, 0x1 );
388  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
389  _mm256_store_pd( c + 8, c03_2.v );
390  hmlp::heap_select<double>( aux->jb, r, c + 8, bmap, D + 2 * ldr, I + 2 * ldr );
391  }
392  }
393 
394  if ( aux->ib > 3 )
395  {
396  aa_tmp.v = _mm256_broadcast_sd( D + 3 * ldr );
397  b0.v = _mm256_cmp_pd( c03_3.v, aa_tmp.v, 0x1 );
398  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
399  _mm256_store_pd( c + 12, c03_3.v );
400  hmlp::heap_select<double>( aux->jb, r, c + 12, bmap, D + 3 * ldr, I + 3 * ldr );
401  }
402  }
403 
404  if ( aux->ib > 4 )
405  {
406  aa_tmp.v = _mm256_broadcast_sd( D + 4 * ldr );
407  b0.v = _mm256_cmp_pd( c47_0.v, aa_tmp.v, 0x1 );
408  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
409  _mm256_store_pd( c + 16, c47_0.v );
410  hmlp::heap_select<double>( aux->jb, r, c + 16, bmap, D + 4 * ldr, I + 4 * ldr );
411  }
412  }
413 
414  if ( aux->ib > 5 )
415  {
416  aa_tmp.v = _mm256_broadcast_sd( D + 5 * ldr );
417  b0.v = _mm256_cmp_pd( c47_1.v, aa_tmp.v, 0x1 );
418  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
419  _mm256_store_pd( c + 20, c47_1.v );
420  hmlp::heap_select<double>( aux->jb, r, c + 20, bmap, D + 5 * ldr, I + 5 * ldr );
421  }
422  }
423 
424  if ( aux->ib > 6 )
425  {
426  aa_tmp.v = _mm256_broadcast_sd( D + 6 * ldr );
427  b0.v = _mm256_cmp_pd( c47_2.v, aa_tmp.v, 0x1 );
428  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
429  _mm256_store_pd( c + 24, c47_2.v );
430  hmlp::heap_select<double>( aux->jb, r, c + 24, bmap, D + 6 * ldr, I + 6 * ldr );
431  }
432  }
433 
434  if ( aux->ib > 7 )
435  {
436  aa_tmp.v = _mm256_broadcast_sd( D + 7 * ldr );
437  b0.v = _mm256_cmp_pd( c47_3.v, aa_tmp.v, 0x1 );
438  if ( !_mm256_testz_pd( b0.v, b0.v ) ) {
439  _mm256_store_pd( c + 28, c47_3.v );
440  hmlp::heap_select<double>( aux->jb, r, c + 28, bmap, D + 7 * ldr, I + 7 * ldr );
441  }
442  }
443  }
444 };
Definition: knn_d8x4.hpp:13
Definition: hmlp_internal.hpp:38
Definition: avx_type.h:13