#include <hmlp_util.hpp>
#include <hmlp_internal.hpp>
#include <primitives/gsknn.hpp>

/* Packing parameters of the micro-kernel: 8 rows of A by 4 columns of B. */
const size_t pack_mr = 8;
const size_t pack_nr = 4;
const size_t align_size = 32;   /* 32-byte alignment for AVX loads/stores */
const bool row_major = true;
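/*
 * Reference sketch (illustrative; the function name and layout assumptions are
 * not from the original header).  It spells out, in scalar form, what the AVX
 * micro-kernel below appears to compute, assuming the packed layouts implied by
 * its loads -- a[ p * 8 + i ], b[ p * 4 + j ] -- and that aa[ i ], bb[ j ] hold
 * the squared norms of the corresponding points:
 *
 *   c[ i ][ j ] = aa[ i ] - 2 * <a_i, b_j> + bb[ j ] = || a_i - b_j ||^2.
 */
inline void knn_ref_d8x4
(
  int k,
  const double *a, const double *aa,
  const double *b, const double *bb,
  double *c                              /* 8-by-4, stored row-major */
)
{
  for ( int i = 0; i < 8; i ++ )
  {
    for ( int j = 0; j < 4; j ++ )
    {
      double dot = 0.0;
      for ( int p = 0; p < k; p ++ ) dot += a[ p * 8 + i ] * b[ p * 4 + j ];
      double d = aa[ i ] - 2.0 * dot + bb[ j ];
      c[ i * 4 + j ] = ( d > 0.0 ) ? d : 0.0;   /* clamp rounding-induced negatives */
    }
  }
}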
inline void operator()
(
  /* ... */
  double *a,
  double *aa,
  double *b,
  double *bb,
  int *bmap,
  double *Keys,
  int *Values,
  int ldr,
  /* ... */
)
{
  /* Accumulators: after the untangling stage, c03_j / c47_j hold column j of
     the 8x4 block for rows 0-3 / 4-7. */
  v4df_t c03_0, c03_1, c03_2, c03_3;
  v4df_t c47_0, c47_1, c47_2, c47_3;
  v4df_t tmpc03_0, tmpc03_1, tmpc03_2, tmpc03_3;
  v4df_t tmpc47_0, tmpc47_1, tmpc47_2, tmpc47_3;
  /* Working registers (declarations elided in this excerpt; the types follow
     from their use below). */
  v4df_t a03, a47, A03, A47;
  v4df_t b0, b1, b2, b3, B0;
  v4df_t c_tmp, aa_tmp, bb_tmp;
  int i;
  /* Prefetch the packed A panel, the next B panel, and the output buffer c. */
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( a ) );
  __asm__ volatile( "prefetcht2 0(%0) \n\t" : : "r"( aux->b_next ) );
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( c +  3 ) );
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( c + 11 ) );
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( c + 19 ) );
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( c + 27 ) );
  c03_0.v = _mm256_setzero_pd();
  c03_1.v = _mm256_setzero_pd();
  c03_2.v = _mm256_setzero_pd();
  c03_3.v = _mm256_setzero_pd();
  c47_0.v = _mm256_setzero_pd();
  c47_1.v = _mm256_setzero_pd();
  c47_2.v = _mm256_setzero_pd();
  c47_3.v = _mm256_setzero_pd();
  /* Preload the first slice of the packed panels: 8 values of A, 4 of B. */
  a03.v = _mm256_load_pd( (double*)a );
  a47.v = _mm256_load_pd( (double*)( a + 4 ) );
  b0.v  = _mm256_load_pd( (double*)b );
  /* Main loop over k, unrolled by two rank-1 updates per iteration.  Each
     update multiplies a03/a47 (8 values of A) by four in-register rotations of
     b0 (4 values of B), so every accumulator collects one "diagonal" of the
     4x4 products. */
  for ( i = 0; i < k_iter; ++i )
  {
    __asm__ volatile( "prefetcht0 192(%0) \n\t" : : "r"( a ) );

    /* Unrolled iteration #0. */
    A03.v = _mm256_load_pd( (double*)( a + 8 ) );

    c_tmp.v = _mm256_mul_pd( a03.v, b0.v );
    c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b0.v );
    c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );

    A47.v = _mm256_load_pd( (double*)( a + 12 ) );

    /* Rotate b within each 128-bit half: { b1, b0, b3, b2 }. */
    b1.v = _mm256_shuffle_pd( b0.v, b0.v, 0x5 );

    c_tmp.v = _mm256_mul_pd( a03.v, b1.v );
    c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b1.v );
    c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );

    /* Swap the 128-bit halves: { b3, b2, b1, b0 }. */
    b2.v = _mm256_permute2f128_pd( b1.v, b1.v, 0x1 );

    B0.v = _mm256_load_pd( (double*)( b + 4 ) );

    c_tmp.v = _mm256_mul_pd( a03.v, b2.v );
    c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b2.v );
    c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );

    /* Last rotation: { b2, b3, b0, b1 }. */
    b3.v = _mm256_shuffle_pd( b2.v, b2.v, 0x5 );

    c_tmp.v = _mm256_mul_pd( a03.v, b3.v );
    c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b3.v );
    c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );

    /* Unrolled iteration #1, using A03/A47 and B0 loaded above. */
    __asm__ volatile( "prefetcht0 512(%0) \n\t" : : "r"( a ) );

    a03.v = _mm256_load_pd( (double*)( a + 16 ) );

    c_tmp.v = _mm256_mul_pd( A03.v, B0.v );
    c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );

    b1.v = _mm256_shuffle_pd( B0.v, B0.v, 0x5 );

    c_tmp.v = _mm256_mul_pd( A47.v, B0.v );
    c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );
    c_tmp.v = _mm256_mul_pd( A03.v, b1.v );
    c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );

    a47.v = _mm256_load_pd( (double*)( a + 20 ) );

    b2.v = _mm256_permute2f128_pd( b1.v, b1.v, 0x1 );

    c_tmp.v = _mm256_mul_pd( A47.v, b1.v );
    c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );
    c_tmp.v = _mm256_mul_pd( A03.v, b2.v );
    c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );

    b3.v = _mm256_shuffle_pd( b2.v, b2.v, 0x5 );

    c_tmp.v = _mm256_mul_pd( A47.v, b2.v );
    c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );

    b0.v = _mm256_load_pd( (double*)( b + 8 ) );

    c_tmp.v = _mm256_mul_pd( A03.v, b3.v );
    c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
    c_tmp.v = _mm256_mul_pd( A47.v, b3.v );
    c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );

    /* Advance past the two slices consumed above. */
    a += 16;
    b += 8;
  }
  /* Remainder loop: one rank-1 update per iteration. */
  for ( i = 0; i < k_left; ++i )
  {
    a03.v = _mm256_load_pd( (double*)a );
    a47.v = _mm256_load_pd( (double*)( a + 4 ) );
    b0.v  = _mm256_load_pd( (double*)b );

    c_tmp.v = _mm256_mul_pd( a03.v, b0.v );
    c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b0.v );
    c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );

    b1.v = _mm256_shuffle_pd( b0.v, b0.v, 0x5 );

    c_tmp.v = _mm256_mul_pd( a03.v, b1.v );
    c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b1.v );
    c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );

    b2.v = _mm256_permute2f128_pd( b1.v, b1.v, 0x1 );

    c_tmp.v = _mm256_mul_pd( a03.v, b2.v );
    c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b2.v );
    c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );

    b3.v = _mm256_shuffle_pd( b2.v, b2.v, 0x5 );

    c_tmp.v = _mm256_mul_pd( a03.v, b3.v );
    c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
    c_tmp.v = _mm256_mul_pd( a47.v, b3.v );
    c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );

    /* Advance to the next slice of the packed panels. */
    a += 8;
    b += 4;
  }
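  /* After the two loops, each accumulator holds one rotated "diagonal" of the
     8x4 products: per rank-1 update, for rows 0-3,
       c03_0 += { a0*b0, a1*b1, a2*b2, a3*b3 }
       c03_1 += { a0*b1, a1*b0, a2*b3, a3*b2 }
       c03_2 += { a0*b3, a1*b2, a2*b1, a3*b0 }
       c03_3 += { a0*b2, a1*b3, a2*b0, a3*b1 }
     and likewise c47_* for rows 4-7.  The blend/permute stage below untangles
     these into straight columns. */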
  /* Prefetch the aa / bb vectors (squared norms of the A rows and B columns). */
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( aa ) );
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( bb ) );
  /* Untangle the rotated diagonals back into straight columns (blend lane
     pairs, then swap 128-bit halves). */
  tmpc03_0.v = _mm256_blend_pd( c03_0.v, c03_1.v, 0x6 );
  tmpc03_1.v = _mm256_blend_pd( c03_1.v, c03_0.v, 0x6 );
  tmpc03_2.v = _mm256_blend_pd( c03_2.v, c03_3.v, 0x6 );
  tmpc03_3.v = _mm256_blend_pd( c03_3.v, c03_2.v, 0x6 );

  tmpc47_0.v = _mm256_blend_pd( c47_0.v, c47_1.v, 0x6 );
  tmpc47_1.v = _mm256_blend_pd( c47_1.v, c47_0.v, 0x6 );
  tmpc47_2.v = _mm256_blend_pd( c47_2.v, c47_3.v, 0x6 );
  tmpc47_3.v = _mm256_blend_pd( c47_3.v, c47_2.v, 0x6 );

  c03_0.v = _mm256_permute2f128_pd( tmpc03_0.v, tmpc03_2.v, 0x30 );
  c03_3.v = _mm256_permute2f128_pd( tmpc03_2.v, tmpc03_0.v, 0x30 );
  c03_1.v = _mm256_permute2f128_pd( tmpc03_1.v, tmpc03_3.v, 0x30 );
  c03_2.v = _mm256_permute2f128_pd( tmpc03_3.v, tmpc03_1.v, 0x30 );

  c47_0.v = _mm256_permute2f128_pd( tmpc47_0.v, tmpc47_2.v, 0x30 );
  c47_3.v = _mm256_permute2f128_pd( tmpc47_2.v, tmpc47_0.v, 0x30 );
  c47_1.v = _mm256_permute2f128_pd( tmpc47_1.v, tmpc47_3.v, 0x30 );
  c47_2.v = _mm256_permute2f128_pd( tmpc47_3.v, tmpc47_1.v, 0x30 );
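  /* Lane bookkeeping: _mm256_blend_pd with mask 0x6 takes lanes { 0, 3 } from
     its first operand and lanes { 1, 2 } from its second; _mm256_permute2f128_pd
     with 0x30 keeps the low 128-bit half of its first operand and the high half
     of its second.  Together they undo the b-rotations applied in the loops, so
     c03_j / c47_j now hold column j of the 8x4 block. */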
  /* Add whatever is already stored in c (partial results from an earlier pass
     over a previous block of k). */
  c_tmp.v = _mm256_load_pd( c +  0 );
  c03_0.v = _mm256_add_pd( c_tmp.v, c03_0.v );
  c_tmp.v = _mm256_load_pd( c +  4 );
  c47_0.v = _mm256_add_pd( c_tmp.v, c47_0.v );
  c_tmp.v = _mm256_load_pd( c +  8 );
  c03_1.v = _mm256_add_pd( c_tmp.v, c03_1.v );
  c_tmp.v = _mm256_load_pd( c + 12 );
  c47_1.v = _mm256_add_pd( c_tmp.v, c47_1.v );
  c_tmp.v = _mm256_load_pd( c + 16 );
  c03_2.v = _mm256_add_pd( c_tmp.v, c03_2.v );
  c_tmp.v = _mm256_load_pd( c + 20 );
  c47_2.v = _mm256_add_pd( c_tmp.v, c47_2.v );
  c_tmp.v = _mm256_load_pd( c + 24 );
  c03_3.v = _mm256_add_pd( c_tmp.v, c03_3.v );
  c_tmp.v = _mm256_load_pd( c + 28 );
  c47_3.v = _mm256_add_pd( c_tmp.v, c47_3.v );
  /* Prefetch the neighbor-id ( I ) and distance ( D ) arrays for this row block. */
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( I ) );
  __asm__ volatile( "prefetcht0 0(%0) \n\t" : : "r"( D ) );
  /* Scale the accumulated inner products by -2: the cross term of
     || a_i - b_j ||^2 = aa[ i ] - 2 * <a_i, b_j> + bb[ j ]. */
  double neg2 = -2.0;
  aa_tmp.v = _mm256_broadcast_sd( &neg2 );
  c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
  c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
  c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
  c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
  c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );
  c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );
  c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );
  c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );
  /* Add the squared norms of the A rows ( aa[ 0 .. 7 ] ). */
  aa_tmp.v = _mm256_load_pd( (double*)aa );
  c03_0.v = _mm256_add_pd( aa_tmp.v, c03_0.v );
  c03_1.v = _mm256_add_pd( aa_tmp.v, c03_1.v );
  c03_2.v = _mm256_add_pd( aa_tmp.v, c03_2.v );
  c03_3.v = _mm256_add_pd( aa_tmp.v, c03_3.v );

  aa_tmp.v = _mm256_load_pd( (double*)( aa + 4 ) );
  c47_0.v = _mm256_add_pd( aa_tmp.v, c47_0.v );
  c47_1.v = _mm256_add_pd( aa_tmp.v, c47_1.v );
  c47_2.v = _mm256_add_pd( aa_tmp.v, c47_2.v );
  c47_3.v = _mm256_add_pd( aa_tmp.v, c47_3.v );
  /* Add the squared norms of the B columns ( bb[ 0 .. 3 ] ), one broadcast per column. */
  bb_tmp.v = _mm256_broadcast_sd( (double*)bb );
  c03_0.v = _mm256_add_pd( bb_tmp.v, c03_0.v );
  c47_0.v = _mm256_add_pd( bb_tmp.v, c47_0.v );

  bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 1 ) );
  c03_1.v = _mm256_add_pd( bb_tmp.v, c03_1.v );
  c47_1.v = _mm256_add_pd( bb_tmp.v, c47_1.v );

  bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 2 ) );
  c03_2.v = _mm256_add_pd( bb_tmp.v, c03_2.v );
  c47_2.v = _mm256_add_pd( bb_tmp.v, c47_2.v );

  bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 3 ) );
  c03_3.v = _mm256_add_pd( bb_tmp.v, c03_3.v );
  c47_3.v = _mm256_add_pd( bb_tmp.v, c47_3.v );
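  /* Each lane of c03_j / c47_j now holds aa[ i ] - 2 * <a_i, b_j> + bb[ j ],
     i.e. the squared Euclidean distance between row i of A and column j of B. */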
  /* Exact squared distances are nonnegative, but floating-point cancellation
     can produce tiny negative values; clamp them at zero. */
  double dzero = 0.0;
  c_tmp.v = _mm256_broadcast_sd( &dzero );
  c03_0.v = _mm256_max_pd( c_tmp.v, c03_0.v );
  c03_1.v = _mm256_max_pd( c_tmp.v, c03_1.v );
  c03_2.v = _mm256_max_pd( c_tmp.v, c03_2.v );
  c03_3.v = _mm256_max_pd( c_tmp.v, c03_3.v );
  c47_0.v = _mm256_max_pd( c_tmp.v, c47_0.v );
  c47_1.v = _mm256_max_pd( c_tmp.v, c47_1.v );
  c47_2.v = _mm256_max_pd( c_tmp.v, c47_2.v );
  c47_3.v = _mm256_max_pd( c_tmp.v, c47_3.v );
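  /* Worked example of the distance assembled above: for a_i = ( 1, 2 ) and
     b_j = ( 3, 4 ), aa[ i ] = 5, bb[ j ] = 25, <a_i, b_j> = 11, so
     5 - 2 * 11 + 25 = 8, which is indeed ( 1 - 3 )^2 + ( 2 - 4 )^2. */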
  /* Transpose the 4x4 blocks from columns to rows: after this, c03_i / c47_i
     hold the four distances of row i / row 4 + i, ready to be stored row-major
     into c. */
  tmpc03_0.v = _mm256_shuffle_pd( c03_0.v, c03_1.v, 0x0 );
  tmpc03_1.v = _mm256_shuffle_pd( c03_0.v, c03_1.v, 0xF );
  tmpc03_2.v = _mm256_shuffle_pd( c03_2.v, c03_3.v, 0x0 );
  tmpc03_3.v = _mm256_shuffle_pd( c03_2.v, c03_3.v, 0xF );

  tmpc47_0.v = _mm256_shuffle_pd( c47_0.v, c47_1.v, 0x0 );
  tmpc47_1.v = _mm256_shuffle_pd( c47_0.v, c47_1.v, 0xF );
  tmpc47_2.v = _mm256_shuffle_pd( c47_2.v, c47_3.v, 0x0 );
  tmpc47_3.v = _mm256_shuffle_pd( c47_2.v, c47_3.v, 0xF );

  c03_0.v = _mm256_permute2f128_pd( tmpc03_0.v, tmpc03_2.v, 0x20 );
  c03_2.v = _mm256_permute2f128_pd( tmpc03_0.v, tmpc03_2.v, 0x31 );
  c03_1.v = _mm256_permute2f128_pd( tmpc03_1.v, tmpc03_3.v, 0x20 );
  c03_3.v = _mm256_permute2f128_pd( tmpc03_1.v, tmpc03_3.v, 0x31 );

  c47_0.v = _mm256_permute2f128_pd( tmpc47_0.v, tmpc47_2.v, 0x20 );
  c47_2.v = _mm256_permute2f128_pd( tmpc47_0.v, tmpc47_2.v, 0x31 );
  c47_1.v = _mm256_permute2f128_pd( tmpc47_1.v, tmpc47_3.v, 0x20 );
  c47_3.v = _mm256_permute2f128_pd( tmpc47_1.v, tmpc47_3.v, 0x31 );
  /* Row-by-row k-NN update.  D[ i * ldr ] is the root of row i's max-heap, i.e.
     the current r-th smallest distance; only if some new distance beats it do we
     store the row into c and run the heap selection. */
  aa_tmp.v = _mm256_broadcast_sd( D );
  b0.v = _mm256_cmp_pd( c03_0.v, aa_tmp.v, 0x1 );   /* 0x1 = _CMP_LT_OS */
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c     , c03_0.v );
    hmlp::heap_select<double>( aux->jb, r, c + 0, bmap, D + 0 * ldr, I + 0 * ldr );
  }
  aa_tmp.v = _mm256_broadcast_sd( D + ldr );
  b0.v = _mm256_cmp_pd( c03_1.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c +  4, c03_1.v );
    hmlp::heap_select<double>( aux->jb, r, c +  4, bmap, D + 1 * ldr, I + 1 * ldr );
  }

  aa_tmp.v = _mm256_broadcast_sd( D + 2 * ldr );
  b0.v = _mm256_cmp_pd( c03_2.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c +  8, c03_2.v );
    hmlp::heap_select<double>( aux->jb, r, c +  8, bmap, D + 2 * ldr, I + 2 * ldr );
  }

  aa_tmp.v = _mm256_broadcast_sd( D + 3 * ldr );
  b0.v = _mm256_cmp_pd( c03_3.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c + 12, c03_3.v );
    hmlp::heap_select<double>( aux->jb, r, c + 12, bmap, D + 3 * ldr, I + 3 * ldr );
  }

  aa_tmp.v = _mm256_broadcast_sd( D + 4 * ldr );
  b0.v = _mm256_cmp_pd( c47_0.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c + 16, c47_0.v );
    hmlp::heap_select<double>( aux->jb, r, c + 16, bmap, D + 4 * ldr, I + 4 * ldr );
  }

  aa_tmp.v = _mm256_broadcast_sd( D + 5 * ldr );
  b0.v = _mm256_cmp_pd( c47_1.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c + 20, c47_1.v );
    hmlp::heap_select<double>( aux->jb, r, c + 20, bmap, D + 5 * ldr, I + 5 * ldr );
  }

  aa_tmp.v = _mm256_broadcast_sd( D + 6 * ldr );
  b0.v = _mm256_cmp_pd( c47_2.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c + 24, c47_2.v );
    hmlp::heap_select<double>( aux->jb, r, c + 24, bmap, D + 6 * ldr, I + 6 * ldr );
  }

  aa_tmp.v = _mm256_broadcast_sd( D + 7 * ldr );
  b0.v = _mm256_cmp_pd( c47_3.v, aa_tmp.v, 0x1 );
  if ( !_mm256_testz_pd( b0.v, b0.v ) )
  {
    _mm256_store_pd( c + 28, c47_3.v );
    hmlp::heap_select<double>( aux->jb, r, c + 28, bmap, D + 7 * ldr, I + 7 * ldr );
  }
}
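/*
 * Illustration of the selection step (an assumption about semantics, not
 * hmlp::heap_select's actual implementation): for each row, D[ 0 .. r-1 ] and
 * I[ 0 .. r-1 ] are assumed to form a max-heap of the r smallest distances seen
 * so far and their column ids, so D[ 0 ] is the current r-th smallest -- which
 * is why the guards above only fire when some new distance is below it.  The
 * hypothetical helper below shows one candidate insertion in scalar form.
 */
inline void heap_replace_root_sketch( int r, double key, int val, double *D, int *I )
{
  if ( key >= D[ 0 ] ) return;            /* not better than the r-th smallest kept */
  D[ 0 ] = key;                           /* overwrite the root ...                 */
  I[ 0 ] = val;
  int p = 0;                              /* ... and sift down to restore the heap  */
  while ( 2 * p + 1 < r )
  {
    int child = 2 * p + 1;
    if ( child + 1 < r && D[ child + 1 ] > D[ child ] ) child ++;
    if ( D[ p ] >= D[ child ] ) break;
    double dt = D[ p ]; D[ p ] = D[ child ]; D[ child ] = dt;
    int    it = I[ p ]; I[ p ] = I[ child ]; I[ child ] = it;
    p = child;
  }
}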