// NOTE(review): this view of the header is a lossy source-listing extraction.
// Original line numbers are fused into the text, the preprocessor directives
// below are merged onto a single line, and several source lines (e.g. original
// lines 3, 6-11, 13-14, 16-17, 19-20, 22-25) are missing. The line is kept
// byte-for-byte here; restore it from the original file rather than editing it.
1 #ifndef _PVFMM_INTRIN_WRAPPER_HPP_ 2 #define _PVFMM_INTRIN_WRAPPER_HPP_ 4 #include <pvfmm/math_utils.hpp> 5 #include <pvfmm/common.hpp> 12 #include <emmintrin.h> 15 #include <pmmintrin.h> 18 #include <immintrin.h> 21 #include <immintrin.h> 26 template <
// Generic (scalar) fallback: returns the literal 0 cast to T. The SSE/AVX
// specializations further down return an all-zero vector register instead.
class T>
inline T zero_intrin() {
return (T)0; }
// Scalar fallbacks for the data-movement primitives. For a plain scalar T the
// "vector" holds exactly one element:
//   set_intrin   -> converts the value to T,
//   load_intrin  -> reads one element,
//   bcast_intrin -> reads one element (nothing to replicate),
//   store_intrin -> writes one element.

/// Returns @p a converted to T.
template <class T, class Real>
inline T set_intrin(const Real& a) {
  const T value = a;
  return value;
}

/// Returns the element pointed to by @p a.
template <class T, class Real>
inline T load_intrin(Real const* a) {
  return *a;
}

/// Returns the element pointed to by @p a (a broadcast of one lane is a plain read).
template <class T, class Real>
inline T bcast_intrin(Real const* a) {
  return *a;
}

/// Writes @p b to the element pointed to by @p a.
template <class T, class Real>
inline void store_intrin(Real* a, const T& b) {
  *a = b;
}
// Scalar fallbacks for the element-wise arithmetic primitives; the SIMD
// specializations map these onto the corresponding packed instructions.

/// Returns a * b.
template <class T>
inline T mul_intrin(const T& a, const T& b) {
  const T product = a * b;
  return product;
}

/// Returns a + b.
template <class T>
inline T add_intrin(const T& a, const T& b) {
  const T sum = a + b;
  return sum;
}

/// Returns a - b.
template <class T>
inline T sub_intrin(const T& a, const T& b) {
  const T difference = a - b;
  return difference;
}
/// Scalar fallback for the SIMD "less-than" compare.
///
/// Returns a value of type T whose bytes are all 0xFF when a < b and all zero
/// otherwise. This mirrors the per-lane mask produced by the SSE/AVX
/// specializations (_mm_cmplt_ps / _mm256_cmp_ps(..., _CMP_LT_OS)), so the
/// result can be combined with and_intrin exactly like a vector mask.
/// For floating-point T the "true" result is a NaN bit pattern; only its
/// bits are meaningful.
template <class T>
inline T cmplt_intrin(const T& a, const T& b) {
  T r = T();  // value-initialized: all-zero bits for arithmetic T ("false")
  if (a < b) {
    uint8_t* r_ = reinterpret_cast<uint8_t*>(&r);
    // unsigned index: sizeof(T) is unsigned, avoids a signed/unsigned compare
    for (unsigned int i = 0; i < sizeof(T); i++) r_[i] = ~(uint8_t)0;
  }
  return r;
}
/// Scalar fallback for the SIMD bitwise AND.
///
/// Returns the byte-wise AND of the object representations of @p a and @p b,
/// matching _mm_and_ps/_mm256_and_ps lane semantics (e.g. masking a value
/// with a cmplt_intrin result).
template <class T>
inline T and_intrin(const T& a, const T& b) {
  T r = a;  // result storage; every byte is overwritten below
  const uint8_t* a_ = reinterpret_cast<const uint8_t*>(&a);
  const uint8_t* b_ = reinterpret_cast<const uint8_t*>(&b);
  uint8_t* r_ = reinterpret_cast<uint8_t*>(&r);
  for (unsigned int i = 0; i < sizeof(T); i++) r_[i] = a_[i] & b_[i];
  return r;
}
59 template <
class T>
inline T rsqrt_approx_intrin(
const T& r2) {
60 if (r2 != 0)
return 1.0 / pvfmm::sqrt<T>(r2);
/// One (unnormalized) Newton-Raphson refinement of an inverse-sqrt estimate:
///
///   rinv <- rinv * (nwtn_const - r2 * rinv * rinv)
///
/// With nwtn_const == 3 and rinv close to 1/sqrt(r2), the result is close to
/// 2/sqrt(r2). The usual factor 1/2 is omitted, so callers choose nwtn_const
/// to account for the accumulated scale.
template <class T, class Real>
inline void rsqrt_newton_intrin(T& rinv, const T& r2, const Real& nwtn_const) {
  const T refined = rinv * (nwtn_const - r2 * rinv * rinv);
  rinv = refined;
}
66 template <
class T>
inline T rsqrt_single_intrin(
const T& r2) {
67 if (r2 != 0)
return 1.0 / pvfmm::sqrt<T>(r2);
/// Scalar fallback for the SIMD maximum.
///
/// Returns a if a > b, otherwise b. This reproduces MAXPS/MAXPD semantics:
/// when either operand is NaN, or both are zero, the second operand is
/// returned.
template <class T>
inline T max_intrin(const T& a, const T& b) {
  return (a > b) ? a : b;
}

/// Scalar fallback for the SIMD minimum.
///
/// Returns a if a < b, otherwise b. This reproduces MINPS/MINPD semantics:
/// when either operand is NaN, or both are zero, the second operand is
/// returned.
template <class T>
inline T min_intrin(const T& a, const T& b) {
  return (a < b) ? a : b;
}
85 template <
class T>
inline T sin_intrin(
const T& t) {
return pvfmm::sin<T>(t); }
87 template <
class T>
inline T cos_intrin(
const T& t) {
return pvfmm::cos<T>(t); }
90 template <>
inline __m128 zero_intrin() {
return _mm_setzero_ps(); }
92 template <>
inline __m128d zero_intrin() {
return _mm_setzero_pd(); }
94 template <>
inline __m128 set_intrin(
const float& a) {
return _mm_set_ps1(a); }
96 template <>
inline __m128d set_intrin(
const double& a) {
return _mm_set1_pd(a); }
98 template <>
inline __m128 load_intrin(
float const* a) {
return _mm_load_ps(a); }
100 template <>
inline __m128d load_intrin(
double const* a) {
return _mm_load_pd(a); }
102 template <>
inline __m128 bcast_intrin(
float const* a) {
return _mm_set_ps1(a[0]); }
104 template <>
inline __m128d bcast_intrin(
double const* a) {
return _mm_load_pd1(a); }
106 template <>
inline void store_intrin(
float* a,
const __m128& b) {
return _mm_store_ps(a, b); }
108 template <>
inline void store_intrin(
double* a,
const __m128d& b) {
return _mm_store_pd(a, b); }
110 template <>
inline __m128 mul_intrin(
const __m128& a,
const __m128& b) {
return _mm_mul_ps(a, b); }
112 template <>
inline __m128d mul_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_mul_pd(a, b); }
114 template <>
inline __m128 add_intrin(
const __m128& a,
const __m128& b) {
return _mm_add_ps(a, b); }
116 template <>
inline __m128d add_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_add_pd(a, b); }
118 template <>
inline __m128 sub_intrin(
const __m128& a,
const __m128& b) {
return _mm_sub_ps(a, b); }
120 template <>
inline __m128d sub_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_sub_pd(a, b); }
122 template <>
inline __m128 cmplt_intrin(
const __m128& a,
const __m128& b) {
return _mm_cmplt_ps(a, b); }
124 template <>
inline __m128d cmplt_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_cmplt_pd(a, b); }
126 template <>
inline __m128 and_intrin(
const __m128& a,
const __m128& b) {
return _mm_and_ps(a, b); }
128 template <>
inline __m128d and_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_and_pd(a, b); }
130 template <>
inline __m128 rsqrt_approx_intrin(
const __m128& r2) {
131 #define VEC_INTRIN __m128 132 #define RSQRT_INTRIN(a) _mm_rsqrt_ps(a) 133 #define CMPEQ_INTRIN(a, b) _mm_cmpeq_ps(a, b) 134 #define ANDNOT_INTRIN(a, b) _mm_andnot_ps(a, b) 137 return ANDNOT_INTRIN(CMPEQ_INTRIN(r2, zero_intrin<VEC_INTRIN>()), RSQRT_INTRIN(r2));
145 template <>
inline __m128d rsqrt_approx_intrin(
const __m128d& r2) {
146 #define PD2PS(a) _mm_cvtpd_ps(a) 147 #define PS2PD(a) _mm_cvtps_pd(a) 148 return PS2PD(rsqrt_approx_intrin(PD2PS(r2)));
153 template <>
inline void rsqrt_newton_intrin(__m128& rinv,
const __m128& r2,
const float& nwtn_const) {
154 #define VEC_INTRIN __m128 157 rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
161 template <>
inline void rsqrt_newton_intrin(__m128d& rinv,
const __m128d& r2,
const double& nwtn_const) {
162 #define VEC_INTRIN __m128d 165 rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
169 template <>
inline __m128 rsqrt_single_intrin(
const __m128& r2) {
170 #define VEC_INTRIN __m128 171 VEC_INTRIN rinv = rsqrt_approx_intrin(r2);
172 rsqrt_newton_intrin(rinv, r2, (
float)3.0);
177 template <>
inline __m128d rsqrt_single_intrin(
const __m128d& r2) {
178 #define PD2PS(a) _mm_cvtpd_ps(a) 179 #define PS2PD(a) _mm_cvtps_pd(a) 180 return PS2PD(rsqrt_single_intrin(PD2PS(r2)));
185 template <>
inline __m128 max_intrin(
const __m128& a,
const __m128& b) {
return _mm_max_ps(a, b); }
187 template <>
inline __m128d max_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_max_pd(a, b); }
189 template <>
inline __m128 min_intrin(
const __m128& a,
const __m128& b) {
return _mm_min_ps(a, b); }
191 template <>
inline __m128d min_intrin(
const __m128d& a,
const __m128d& b) {
return _mm_min_pd(a, b); }
193 #ifdef PVFMM_HAVE_INTEL_SVML 194 template <>
inline __m128 sin_intrin(
const __m128& t) {
return _mm_sin_ps(t); }
196 template <>
inline __m128 cos_intrin(
const __m128& t) {
return _mm_cos_ps(t); }
198 template <>
inline __m128d sin_intrin(
const __m128d& t) {
return _mm_sin_pd(t); }
200 template <>
inline __m128d cos_intrin(
const __m128d& t) {
return _mm_cos_pd(t); }
202 template <>
inline __m128 sin_intrin(
const __m128& t_) {
207 store_intrin(t.e, t_);
208 return _mm_set_ps(pvfmm::sin<float>(t.e[3]), pvfmm::sin<float>(t.e[2]), pvfmm::sin<float>(t.e[1]), pvfmm::sin<float>(t.e[0]));
211 template <>
inline __m128 cos_intrin(
const __m128& t_) {
216 store_intrin(t.e, t_);
217 return _mm_set_ps(pvfmm::cos<float>(t.e[3]), pvfmm::cos<float>(t.e[2]), pvfmm::cos<float>(t.e[1]), pvfmm::cos<float>(t.e[0]));
220 template <>
inline __m128d sin_intrin(
const __m128d& t_) {
225 store_intrin(t.e, t_);
226 return _mm_set_pd(pvfmm::sin<double>(t.e[1]), pvfmm::sin<double>(t.e[0]));
229 template <>
inline __m128d cos_intrin(
const __m128d& t_) {
234 store_intrin(t.e, t_);
235 return _mm_set_pd(pvfmm::cos<double>(t.e[1]), pvfmm::cos<double>(t.e[0]));
241 template <>
inline __m256 zero_intrin() {
return _mm256_setzero_ps(); }
243 template <>
inline __m256d zero_intrin() {
return _mm256_setzero_pd(); }
245 template <>
inline __m256 set_intrin(
const float& a) {
return _mm256_set_ps(a, a, a, a, a, a, a, a); }
247 template <>
inline __m256d set_intrin(
const double& a) {
return _mm256_set_pd(a, a, a, a); }
249 template <>
inline __m256 load_intrin(
float const* a) {
return _mm256_load_ps(a); }
251 template <>
inline __m256d load_intrin(
double const* a) {
return _mm256_load_pd(a); }
253 template <>
inline __m256 bcast_intrin(
float const* a) {
return _mm256_broadcast_ss(a); }
255 template <>
inline __m256d bcast_intrin(
double const* a) {
return _mm256_broadcast_sd(a); }
257 template <>
inline void store_intrin(
float* a,
const __m256& b) {
return _mm256_store_ps(a, b); }
259 template <>
inline void store_intrin(
double* a,
const __m256d& b) {
return _mm256_store_pd(a, b); }
261 template <>
inline __m256 mul_intrin(
const __m256& a,
const __m256& b) {
return _mm256_mul_ps(a, b); }
263 template <>
inline __m256d mul_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_mul_pd(a, b); }
265 template <>
inline __m256 add_intrin(
const __m256& a,
const __m256& b) {
return _mm256_add_ps(a, b); }
267 template <>
inline __m256d add_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_add_pd(a, b); }
269 template <>
inline __m256 sub_intrin(
const __m256& a,
const __m256& b) {
return _mm256_sub_ps(a, b); }
271 template <>
inline __m256d sub_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_sub_pd(a, b); }
273 template <>
inline __m256 cmplt_intrin(
const __m256& a,
const __m256& b) {
return _mm256_cmp_ps(a, b, _CMP_LT_OS); }
275 template <>
inline __m256d cmplt_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_cmp_pd(a, b, _CMP_LT_OS); }
277 template <>
inline __m256 and_intrin(
const __m256& a,
const __m256& b) {
return _mm256_and_ps(a, b); }
279 template <>
inline __m256d and_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_and_pd(a, b); }
281 template <>
inline __m256 rsqrt_approx_intrin(
const __m256& r2) {
282 #define VEC_INTRIN __m256 283 #define RSQRT_INTRIN(a) _mm256_rsqrt_ps(a) 284 #define CMPEQ_INTRIN(a, b) _mm256_cmp_ps(a, b, _CMP_EQ_OS) 285 #define ANDNOT_INTRIN(a, b) _mm256_andnot_ps(a, b) 288 return ANDNOT_INTRIN(CMPEQ_INTRIN(r2, zero_intrin<VEC_INTRIN>()), RSQRT_INTRIN(r2));
296 template <>
inline __m256d rsqrt_approx_intrin(
const __m256d& r2) {
297 #define PD2PS(a) _mm256_cvtpd_ps(a) 298 #define PS2PD(a) _mm256_cvtps_pd(a) 299 return PS2PD(rsqrt_approx_intrin(PD2PS(r2)));
304 template <>
inline void rsqrt_newton_intrin(__m256& rinv,
const __m256& r2,
const float& nwtn_const) {
305 #define VEC_INTRIN __m256 308 rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
312 template <>
inline void rsqrt_newton_intrin(__m256d& rinv,
const __m256d& r2,
const double& nwtn_const) {
313 #define VEC_INTRIN __m256d 316 rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
320 template <>
inline __m256 rsqrt_single_intrin(
const __m256& r2) {
321 #define VEC_INTRIN __m256 322 VEC_INTRIN rinv = rsqrt_approx_intrin(r2);
323 rsqrt_newton_intrin(rinv, r2, (
float)3.0);
328 template <>
inline __m256d rsqrt_single_intrin(
const __m256d& r2) {
329 #define PD2PS(a) _mm256_cvtpd_ps(a) 330 #define PS2PD(a) _mm256_cvtps_pd(a) 331 return PS2PD(rsqrt_single_intrin(PD2PS(r2)));
336 template <>
inline __m256 max_intrin(
const __m256& a,
const __m256& b) {
return _mm256_max_ps(a, b); }
338 template <>
inline __m256d max_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_max_pd(a, b); }
340 template <>
inline __m256 min_intrin(
const __m256& a,
const __m256& b) {
return _mm256_min_ps(a, b); }
342 template <>
inline __m256d min_intrin(
const __m256d& a,
const __m256d& b) {
return _mm256_min_pd(a, b); }
344 #ifdef PVFMM_HAVE_INTEL_SVML 345 template <>
inline __m256 sin_intrin(
const __m256& t) {
return _mm256_sin_ps(t); }
347 template <>
inline __m256 cos_intrin(
const __m256& t) {
return _mm256_cos_ps(t); }
349 template <>
inline __m256d sin_intrin(
const __m256d& t) {
return _mm256_sin_pd(t); }
351 template <>
inline __m256d cos_intrin(
const __m256d& t) {
return _mm256_cos_pd(t); }
353 template <>
inline __m256 sin_intrin(
const __m256& t_) {
358 store_intrin(t.e, t_);
359 return _mm256_set_ps(pvfmm::sin<float>(t.e[7]), pvfmm::sin<float>(t.e[6]), pvfmm::sin<float>(t.e[5]), pvfmm::sin<float>(t.e[4]), pvfmm::sin<float>(t.e[3]), pvfmm::sin<float>(t.e[2]), pvfmm::sin<float>(t.e[1]), pvfmm::sin<float>(t.e[0]));
362 template <>
inline __m256 cos_intrin(
const __m256& t_) {
367 store_intrin(t.e, t_);
368 return _mm256_set_ps(pvfmm::cos<float>(t.e[7]), pvfmm::cos<float>(t.e[6]), pvfmm::cos<float>(t.e[5]), pvfmm::cos<float>(t.e[4]), pvfmm::cos<float>(t.e[3]), pvfmm::cos<float>(t.e[2]), pvfmm::cos<float>(t.e[1]), pvfmm::cos<float>(t.e[0]));
371 template <>
inline __m256d sin_intrin(
const __m256d& t_) {
376 store_intrin(t.e, t_);
377 return _mm256_set_pd(pvfmm::sin<double>(t.e[3]), pvfmm::sin<double>(t.e[2]), pvfmm::sin<double>(t.e[1]), pvfmm::sin<double>(t.e[0]));
380 template <>
inline __m256d cos_intrin(
const __m256d& t_) {
385 store_intrin(t.e, t_);
386 return _mm256_set_pd(pvfmm::cos<double>(t.e[3]), pvfmm::cos<double>(t.e[2]), pvfmm::cos<double>(t.e[1]), pvfmm::cos<double>(t.e[0]));
// rsqrt_intrin0: inverse square root of r2, where VEC may be a scalar or one
// of the SIMD types specialized above and Real is its element type.
// NOTE(review): this listing of the definition is incomplete. The declaration
// of rinv, the const_nwtn1..3 constants, the preprocessor conditions that
// select between the alternative lines below, and the return statement are
// not visible, so the body does not compile as shown. Restore it from the
// original source; lines kept byte-for-byte.
391 template <
class VEC,
class Real>
inline VEC rsqrt_intrin0(VEC r2) {
// Initial estimate: either rsqrt_single_intrin (hardware estimate plus one
// unnormalized Newton step) or rsqrt_approx_intrin (estimate only).
// Presumably only one of these two lines is compiled in. TODO confirm.
404 rinv = rsqrt_single_intrin(r2);
406 rinv = rsqrt_approx_intrin(r2);
// Optional further Newton refinements with constants const_nwtn1..3. Their
// definitions and enabling conditions are not visible here; presumably
// disabled for the "0" variant. TODO confirm.
410 rsqrt_newton_intrin(rinv, r2, const_nwtn1);
413 rsqrt_newton_intrin(rinv, r2, const_nwtn2);
416 rsqrt_newton_intrin(rinv, r2, const_nwtn3);
// rsqrt_intrin1: inverse square root of r2 with scale-compensated Newton
// refinement.
// NOTE(review): this listing is incomplete. The declarations of scal and
// rinv, the NWTN* flags, const_nwtn2/const_nwtn3, the preprocessor
// conditionals, and the return statement are not visible, so the body does
// not compile as shown.
427 template <
class VEC,
class Real>
inline VEC rsqrt_intrin1(VEC r2) {
// Scale bookkeeping: suppose rinv currently overestimates 1/sqrt(r2) by a
// factor scal. A step with constant 3*scal^2 then gives
// rinv*(3*scal^2 - scal^2) = 2*scal^3/sqrt(r2), so the new factor is
// 2*scal^3. NWTN0 is presumably a compile-time flag meaning the initial
// estimate already applied such a step (as rsqrt_single_intrin does).
// scal's initial value is not visible (presumably 1).
434 scal = (NWTN0 ? 2 * scal * scal * scal : scal);
435 Real const_nwtn1 = 3 * scal * scal;
// Initial estimate; presumably exactly one of these lines is compiled in.
441 rinv = rsqrt_single_intrin(r2);
443 rinv = rsqrt_approx_intrin(r2);
// Newton refinements; const_nwtn2/const_nwtn3 are not defined in the visible
// lines, so presumably those steps are disabled here. TODO confirm.
447 rsqrt_newton_intrin(rinv, r2, const_nwtn1);
450 rsqrt_newton_intrin(rinv, r2, const_nwtn2);
453 rsqrt_newton_intrin(rinv, r2, const_nwtn3);
// rsqrt_intrin2: inverse square root of r2 with up to two scale-compensated
// Newton refinements.
// NOTE(review): this listing is incomplete. The declarations of scal and
// rinv, the NWTN* flags, const_nwtn3, the preprocessor conditionals, and the
// return statement are not visible, so the body does not compile as shown.
464 template <
class VEC,
class Real>
inline VEC rsqrt_intrin2(VEC r2) {
// Scale bookkeeping: if rinv overestimates 1/sqrt(r2) by scal, a Newton step
// with constant 3*scal^2 leaves a factor of 2*scal^3. Each NWTNk flag
// (presumably "step k was applied") advances scal, and const_nwtnK is the
// matching constant for the next step.
471 scal = (NWTN0 ? 2 * scal * scal * scal : scal);
472 Real const_nwtn1 = 3 * scal * scal;
473 scal = (NWTN1 ? 2 * scal * scal * scal : scal);
474 Real const_nwtn2 = 3 * scal * scal;
// Initial estimate; presumably exactly one of these lines is compiled in.
479 rinv = rsqrt_single_intrin(r2);
481 rinv = rsqrt_approx_intrin(r2);
// Newton refinements; const_nwtn3 is not defined in the visible lines, so the
// third step is presumably disabled here. TODO confirm.
485 rsqrt_newton_intrin(rinv, r2, const_nwtn1);
488 rsqrt_newton_intrin(rinv, r2, const_nwtn2);
491 rsqrt_newton_intrin(rinv, r2, const_nwtn3);
// rsqrt_intrin3: inverse square root of r2 with up to three scale-compensated
// Newton refinements.
// NOTE(review): this listing is incomplete. The declarations of scal and
// rinv, the NWTN* flags, the preprocessor conditionals, and the return
// statement are not visible, so the body does not compile as shown.
502 template <
class VEC,
class Real>
inline VEC rsqrt_intrin3(VEC r2) {
// Scale bookkeeping: if rinv overestimates 1/sqrt(r2) by scal, a Newton step
// with constant 3*scal^2 leaves a factor of 2*scal^3. Each NWTNk flag
// (presumably "step k was applied") advances scal, and const_nwtnK is the
// matching constant for the next step.
509 scal = (NWTN0 ? 2 * scal * scal * scal : scal);
510 Real const_nwtn1 = 3 * scal * scal;
511 scal = (NWTN1 ? 2 * scal * scal * scal : scal);
512 Real const_nwtn2 = 3 * scal * scal;
513 scal = (NWTN2 ? 2 * scal * scal * scal : scal);
514 Real const_nwtn3 = 3 * scal * scal;
// Initial estimate; presumably exactly one of these lines is compiled in.
518 rinv = rsqrt_single_intrin(r2);
520 rinv = rsqrt_approx_intrin(r2);
// Up to three Newton refinements, each with its scale-matched constant.
524 rsqrt_newton_intrin(rinv, r2, const_nwtn1);
527 rsqrt_newton_intrin(rinv, r2, const_nwtn2);
530 rsqrt_newton_intrin(rinv, r2, const_nwtn3);
542 #endif //_PVFMM_INTRIN_WRAPPER_HPP_ Definition: cheb_utils.hpp:12