#include <hmlp_internal.hpp>
#include <cstdio>
#include <cmath>

/** Reference 8x4 double-precision micro-kernel for the variable
 *  bandwidth Gaussian kernel. */
struct variable_bandwidth_gaussian_ref_d8x4
{
  inline void operator()
  (
    kernel_s<double> *kernel,
    int k,
    int rhs,
    double *u,
    double *a, double *a2,
    double *b, double *b2,
    double *w,
    double *c, int ldc,
    aux_s<double, double, double, double> *aux
  ) const
  {
    double c_reg[ 8 * 4 ] = { 0.0 };
    // Semiring rank-k update.
    for ( int p = 0; p < k; p ++ )
    {
      for ( int j = 0; j < 4; j ++ )
      {
        for ( int i = 0; i < 8; i ++ )
        {
          c_reg[ j * 8 + i ] += a[ p * 8 + i ] * b[ p * 4 + j ];
        }
      }
    }
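    /*
     *  a is packed as an 8-row panel ( a[ p * 8 + i ] ) and b as a
     *  4-column panel ( b[ p * 4 + j ] ), so after the loop above
     *  c_reg[ j * 8 + i ] holds the inner product <a_i, b_j> over the
     *  k-dimensional slice handled by this call.
     */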
    // Accumulate the previous rank-k panel stored in c.
    if ( aux->pc )
    {
      for ( int j = 0; j < 4; j ++ )
      {
        for ( int i = 0; i < 8; i ++ )
        {
          c_reg[ j * 8 + i ] += c[ j * ldc + i ];
        }
      }
    }
    // Debug output of the accumulated inner products.
    printf( "variable_bandwidth_gaussian_ref_d8x4: c_reg\n" );
    for ( int i = 0; i < 8; i ++ )
    {
      for ( int j = 0; j < 4; j ++ )
      {
        printf( "%E ", c_reg[ j * 8 + i ] );
      }
      printf( "\n" );
    }
    for ( int j = 0; j < 4; j ++ )
    {
      for ( int i = 0; i < 8; i ++ )
      {
        c_reg[ j * 8 + i ] *= -2.0;
        c_reg[ j * 8 + i ] += a2[ i ] + b2[ j ];
        c_reg[ j * 8 + i ] *= -0.5;
        c_reg[ j * 8 + i ] *= aux->hi[ i ];
        c_reg[ j * 8 + i ] *= aux->hj[ j ];
        c_reg[ j * 8 + i ] = exp( c_reg[ j * 8 + i ] );
      }
    }
    // Weighted sum: u[ i ] += K_ij * w[ j ].
    for ( int j = 0; j < 4; j ++ )
    {
      for ( int i = 0; i < 8; i ++ )
      {
        u[ i ] += c_reg[ j * 8 + i ] * w[ j ];
      }
    }
  }
}; /** end struct variable_bandwidth_gaussian_ref_d8x4 */
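/*
 *  Both micro-kernels in this file compute the same fused 8x4 tile
 *  update
 *
 *    u[ i ] += sum_{ j = 0 .. 3 } K( a_i, b_j ) * w[ j ],  i = 0 .. 7,
 *
 *  the reference version above with scalar loops, and the version
 *  below with AVX intrinsics plus the shared rank-k, exp, and
 *  weighted-sum components.
 */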
/** AVX 8x4 double-precision micro-kernel for the variable bandwidth
 *  Gaussian kernel. */
struct variable_bandwidth_gaussian_int_d8x4
{
  inline void operator()
  (
    kernel_s<double> *ker,
    int k,
    int rhs,
    double *u,
    double *a, double *aa,
    double *b, double *bb,
    double *w,
    double *c, int ldc,
    aux_s<double, double, double, double> *aux
  ) const
  {
    double neg2    = -2.0;
    double neghalf = -0.5;
    double dzero   =  0.0;

    v4df_t c03_0, c03_1, c03_2, c03_3;
    v4df_t c47_0, c47_1, c47_2, c47_3;
    v4df_t tmpc03_0, tmpc03_1, tmpc03_2, tmpc03_3;
    v4df_t tmpc47_0, tmpc47_1, tmpc47_2, tmpc47_3;
    v4df_t u03, u47;
    v4df_t a03, a47, A03, A47;
    v4df_t b0, b1, b2, b3, B0;
    v4df_t c_tmp, aa_tmp, bb_tmp, w_tmp;
    #include "component/rank_k_int_d8x4.hpp"

    // Prefetch the squared norms.
    __asm__ volatile( "prefetcht0 0(%0)    \n\t" : : "r"( aa ) );
    __asm__ volatile( "prefetcht0 0(%0)    \n\t" : : "r"( bb ) );
    if ( aux->pc )
    {
      // Load the previous rank-k panel of C and accumulate it.
      tmpc03_0.v = _mm256_load_pd( (double*)( c               ) );
      tmpc47_0.v = _mm256_load_pd( (double*)( c           + 4 ) );
      tmpc03_1.v = _mm256_load_pd( (double*)( c + 1 * ldc     ) );
      tmpc47_1.v = _mm256_load_pd( (double*)( c + 1 * ldc + 4 ) );
      tmpc03_2.v = _mm256_load_pd( (double*)( c + 2 * ldc     ) );
      tmpc47_2.v = _mm256_load_pd( (double*)( c + 2 * ldc + 4 ) );
      tmpc03_3.v = _mm256_load_pd( (double*)( c + 3 * ldc     ) );
      tmpc47_3.v = _mm256_load_pd( (double*)( c + 3 * ldc + 4 ) );

      c03_0.v = _mm256_add_pd( tmpc03_0.v, c03_0.v );
      c47_0.v = _mm256_add_pd( tmpc47_0.v, c47_0.v );
      c03_1.v = _mm256_add_pd( tmpc03_1.v, c03_1.v );
      c47_1.v = _mm256_add_pd( tmpc47_1.v, c47_1.v );
      c03_2.v = _mm256_add_pd( tmpc03_2.v, c03_2.v );
      c47_2.v = _mm256_add_pd( tmpc47_2.v, c47_2.v );
      c03_3.v = _mm256_add_pd( tmpc03_3.v, c03_3.v );
      c47_3.v = _mm256_add_pd( tmpc47_3.v, c47_3.v );
    }
    // c = -2.0 * c;
    aa_tmp.v = _mm256_broadcast_sd( &neg2 );
    c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );
    c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );
    c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );
    c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );
    // c += aa ( squared 2-norms of the 8 rows of A ).
    aa_tmp.v = _mm256_load_pd( (double*)aa );
    c03_0.v = _mm256_add_pd( aa_tmp.v, c03_0.v );
    c03_1.v = _mm256_add_pd( aa_tmp.v, c03_1.v );
    c03_2.v = _mm256_add_pd( aa_tmp.v, c03_2.v );
    c03_3.v = _mm256_add_pd( aa_tmp.v, c03_3.v );

    aa_tmp.v = _mm256_load_pd( (double*)( aa + 4 ) );
    c47_0.v = _mm256_add_pd( aa_tmp.v, c47_0.v );
    c47_1.v = _mm256_add_pd( aa_tmp.v, c47_1.v );
    c47_2.v = _mm256_add_pd( aa_tmp.v, c47_2.v );
    c47_3.v = _mm256_add_pd( aa_tmp.v, c47_3.v );
    // Prefetch u.
    __asm__ volatile( "prefetcht0 0(%0)    \n\t" : : "r"( u ) );

    // c += bb ( squared 2-norms of the 4 columns of B ).
    bb_tmp.v = _mm256_broadcast_sd( (double*)bb );
    c03_0.v = _mm256_add_pd( bb_tmp.v, c03_0.v );
    c47_0.v = _mm256_add_pd( bb_tmp.v, c47_0.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 1 ) );
    c03_1.v = _mm256_add_pd( bb_tmp.v, c03_1.v );
    c47_1.v = _mm256_add_pd( bb_tmp.v, c47_1.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 2 ) );
    c03_2.v = _mm256_add_pd( bb_tmp.v, c03_2.v );
    c47_2.v = _mm256_add_pd( bb_tmp.v, c47_2.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 3 ) );
    c03_3.v = _mm256_add_pd( bb_tmp.v, c03_3.v );
    c47_3.v = _mm256_add_pd( bb_tmp.v, c47_3.v );
    // Clip c = max( c, 0 ): cancellation in the norm expansion can
    // leave a squared distance slightly below zero.
    c_tmp.v = _mm256_broadcast_sd( &dzero );
    c03_0.v = _mm256_max_pd( c_tmp.v, c03_0.v );
    c03_1.v = _mm256_max_pd( c_tmp.v, c03_1.v );
    c03_2.v = _mm256_max_pd( c_tmp.v, c03_2.v );
    c03_3.v = _mm256_max_pd( c_tmp.v, c03_3.v );
    c47_0.v = _mm256_max_pd( c_tmp.v, c47_0.v );
    c47_1.v = _mm256_max_pd( c_tmp.v, c47_1.v );
    c47_2.v = _mm256_max_pd( c_tmp.v, c47_2.v );
    c47_3.v = _mm256_max_pd( c_tmp.v, c47_3.v );
    // c *= -0.5;
    aa_tmp.v = _mm256_broadcast_sd( &neghalf );
    c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );
    c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );
    c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );
    c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );
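    /*
     *  Each register now holds -0.5 * || a_i - b_j ||^2 for its four
     *  rows, matching the scalar reference path above; only the
     *  bandwidth scalings and the exponential remain.
     */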
    // Touch u early; these registers are overwritten below and u is
    // reloaded right before the weighted sum.
    u03.v = _mm256_load_pd( (double*)u );
    u47.v = _mm256_load_pd( (double*)( u + 4 ) );
    // Scale column j by the target bandwidth hj[ j ].
    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 0 );
    c03_0.v = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c47_0.v = _mm256_mul_pd( aa_tmp.v, c47_0.v );

    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 1 );
    c03_1.v = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c47_1.v = _mm256_mul_pd( aa_tmp.v, c47_1.v );

    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 2 );
    c03_2.v = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c47_2.v = _mm256_mul_pd( aa_tmp.v, c47_2.v );

    aa_tmp.v = _mm256_broadcast_sd( aux->hj + 3 );
    c03_3.v = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_3.v = _mm256_mul_pd( aa_tmp.v, c47_3.v );
    // Scale row i by the source bandwidth hi[ i ]
    // ( u03 / u47 serve as scratch registers here ).
    u03.v = _mm256_load_pd( aux->hi + 0 );
    u47.v = _mm256_load_pd( aux->hi + 4 );

    c03_0.v = _mm256_mul_pd( u03.v, c03_0.v );
    c03_1.v = _mm256_mul_pd( u03.v, c03_1.v );
    c03_2.v = _mm256_mul_pd( u03.v, c03_2.v );
    c03_3.v = _mm256_mul_pd( u03.v, c03_3.v );
    c47_0.v = _mm256_mul_pd( u47.v, c47_0.v );
    c47_1.v = _mm256_mul_pd( u47.v, c47_1.v );
    c47_2.v = _mm256_mul_pd( u47.v, c47_2.v );
    c47_3.v = _mm256_mul_pd( u47.v, c47_3.v );
    // Reload u for the weighted sum.
    u03.v = _mm256_load_pd( (double*)u );
    u47.v = _mm256_load_pd( (double*)( u + 4 ) );
    // Prefetch the next micro-panel of u and the weights.
    __asm__ volatile( "prefetcht0 0(%0)    \n\t" : : "r"( u + 8 ) );
    __asm__ volatile( "prefetcht0 0(%0)    \n\t" : : "r"( w ) );
    // c = exp( c );
    #include "component/exp_int_d8x4.hpp"

    // u += c * w;
    #include "component/weighted_sum_int_d8x4.hpp"
  }
}; /** end struct variable_bandwidth_gaussian_int_d8x4 */
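/*
 *  A minimal standalone sketch ( not part of the original header ) that
 *  checks the algebra the micro-kernels rely on: evaluating the variable
 *  bandwidth Gaussian kernel through the norm expansion used above must
 *  agree with the direct squared distance. Compile it separately with
 *  -DVBG_D8X4_EXAMPLE; the guard macro and all names inside it are
 *  hypothetical, chosen only for illustration.
 */
#ifdef VBG_D8X4_EXAMPLE
#include <cmath>
#include <cstdio>
int main()
{
  const int k = 3;                       // feature dimension
  double a[ k ] = { 0.1, -0.4,  0.7 };   // source point a_i
  double b[ k ] = { 0.3,  0.2, -0.5 };   // target point b_j
  double hi = 0.9, hj = 1.2;             // per-point bandwidths

  // Direct evaluation: K = exp( -0.5 * hi * hj * || a - b ||^2 ).
  double dist2 = 0.0;
  for ( int p = 0; p < k; p ++ )
  {
    dist2 += ( a[ p ] - b[ p ] ) * ( a[ p ] - b[ p ] );
  }
  double K_direct = std::exp( -0.5 * hi * hj * dist2 );

  // Expanded evaluation, mirroring the micro-kernel update sequence:
  // c = <a, b>; c *= -2; c += ||a||^2 + ||b||^2; c *= -0.5 * hi * hj.
  double c = 0.0, a2 = 0.0, b2 = 0.0;
  for ( int p = 0; p < k; p ++ )
  {
    c  += a[ p ] * b[ p ];
    a2 += a[ p ] * a[ p ];
    b2 += b[ p ] * b[ p ];
  }
  c *= -2.0;
  c += a2 + b2;
  c *= -0.5;
  c *= hi;
  c *= hj;
  double K_expanded = std::exp( c );

  // Both paths should print the same kernel value.
  printf( "direct %E expanded %E\n", K_direct, K_expanded );
  return 0;
}
#endif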