HMLP: High-performance Machine Learning Primitives
gaussian_d24x8.hpp
1 #include <stdio.h>
2 #include <math.h>
3 #include <omp.h>
4 #include <immintrin.h> // AVX
5 
6 #include <hmlp.h>
7 #include <hmlp_internal.hpp>
8 #include <avx_type.h> // self-defined vector type
9 
10 // #define DEBUG_MICRO 1
11 
12 
13 // void gaussian_ref_d24x8(
14 // int k,
15 // int rhs,
16 // //double *h,
17 // double *u,
18 // double *aa,
19 // double *a,
20 // double *bb,
21 // double *b,
22 // double *w,
23 // double *c,
24 // ks_t *ker,
25 // aux_t *aux
26 // )
27 // {
28 // int i, j, p;
29 // double K[ 24 * 8 ] = {{ 0.0 }};
30 //
31 // #include <rank_k_ref_d24x8.h>
32 //
33 // // Gaussian kernel
34 // for ( j = 0; j < 8; j ++ ) {
35 // for ( i = 0; i < 24; i ++ ) {
36 // K[ j * 24 + i ] = aa[ i ] - 2.0 * K[ j * 24 + i ] + bb[ j ];
37 // K[ j * 24+ i ] = exp( ker->scal * K[ j * 24 + i ] );
38 // u[ i ] += K[ j * 24 + i ] * w[ j ];
39 // }
40 // }
41 // }
42 //
43 //
44 // void gaussian_int_s48x8(
45 // int k,
46 // int rhs,
47 // //float *h,
48 // float *u,
49 // float *aa,
50 // float *a,
51 // float *bb,
52 // float *b,
53 // float *w,
54 // float *c,
55 // ks_t *ker,
56 // aux_t *aux
57 // )
58 // {
59 // printf( "gaussian_int_s48x8 not yet implemented.\n" );
60 // }
61 
62 
63 
64 
65 
67 {
68  inline void operator()(
69  ks_t *ker,
70  int k,
71  int nrhs,
72  double *u,
73  double *a, double *aa,
74  double *b, double *bb,
75  double *w,
76  double *c, int ldc,
78  {
79  int i;
80  double alpha = ker->scal;
81 
82  // 24 avx512 registers
83  v8df_t c07_0, c07_1, c07_2, c07_3, c07_4, c07_5, c07_6, c07_7;
84  v8df_t c15_0, c15_1, c15_2, c15_3, c15_4, c15_5, c15_6, c15_7;
85  v8df_t c23_0, c23_1, c23_2, c23_3, c23_4, c23_5, c23_6, c23_7;
86 
87  // 8 avx512 registers
88  v8df_t a07, a15, a23;
89  v8df_t A07, A15, A23;
90  v8df_t b0, b1;
91 
92 
93 
94  //printf( "a\n" );
95  //printf( "%E, %E, %E, %E, %E, %E, %E, %E\n", a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7] );
96  //printf( "b\n" );
97  //printf( "%E, %E, %E, %E, %E, %E, %E, %E\n", b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7] );
98 
99  #include <rank_k_int_d24x8.segment>
100 
101 // #pragma omp critical
102 // {
103 // printf( "c (tid %d)\n", omp_get_thread_num() );
104 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[0], c07_1.d[0], c07_2.d[0], c07_3.d[0], c07_4.d[0], c07_5.d[0], c07_6.d[0], c07_7.d[0] );
105 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[1], c07_1.d[1], c07_2.d[1], c07_3.d[1], c07_4.d[1], c07_5.d[1], c07_6.d[1], c07_7.d[1] );
106 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[2], c07_1.d[2], c07_2.d[2], c07_3.d[2], c07_4.d[2], c07_5.d[2], c07_6.d[2], c07_7.d[2] );
107 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[3], c07_1.d[3], c07_2.d[3], c07_3.d[3], c07_4.d[3], c07_5.d[3], c07_6.d[3], c07_7.d[3] );
108 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[4], c07_1.d[4], c07_2.d[4], c07_3.d[4], c07_4.d[4], c07_5.d[4], c07_6.d[4], c07_7.d[4] );
109 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[5], c07_1.d[5], c07_2.d[5], c07_3.d[5], c07_4.d[5], c07_5.d[5], c07_6.d[5], c07_7.d[5] );
110 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[6], c07_1.d[6], c07_2.d[6], c07_3.d[6], c07_4.d[6], c07_5.d[6], c07_6.d[6], c07_7.d[6] );
111 // printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[7], c07_1.d[7], c07_2.d[7], c07_3.d[7], c07_4.d[7], c07_5.d[7], c07_6.d[7], c07_7.d[7] );
112 // }
113 
114  #include <sq2nrm_int_d24x8.segment>
115 
116  //printf( "sq2\n" );
117  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[0], c07_1.d[0], c07_2.d[0], c07_3.d[0], c07_4.d[0], c07_5.d[0], c07_6.d[0], c07_7.d[0] );
118  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[1], c07_1.d[1], c07_2.d[1], c07_3.d[1], c07_4.d[1], c07_5.d[1], c07_6.d[1], c07_7.d[1] );
119  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[2], c07_1.d[2], c07_2.d[2], c07_3.d[2], c07_4.d[2], c07_5.d[2], c07_6.d[2], c07_7.d[2] );
120  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[3], c07_1.d[3], c07_2.d[3], c07_3.d[3], c07_4.d[3], c07_5.d[3], c07_6.d[3], c07_7.d[3] );
121  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[4], c07_1.d[4], c07_2.d[4], c07_3.d[4], c07_4.d[4], c07_5.d[4], c07_6.d[4], c07_7.d[4] );
122  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[5], c07_1.d[5], c07_2.d[5], c07_3.d[5], c07_4.d[5], c07_5.d[5], c07_6.d[5], c07_7.d[5] );
123  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[6], c07_1.d[6], c07_2.d[6], c07_3.d[6], c07_4.d[6], c07_5.d[6], c07_6.d[6], c07_7.d[6] );
124  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[7], c07_1.d[7], c07_2.d[7], c07_3.d[7], c07_4.d[7], c07_5.d[7], c07_6.d[7], c07_7.d[7] );
125 
126 
127  // Scale before the kernel evaluation
128  a07.v = _mm512_set1_pd( alpha );
129  c07_0.v = _mm512_mul_pd( a07.v, c07_0.v );
130  c07_1.v = _mm512_mul_pd( a07.v, c07_1.v );
131  c07_2.v = _mm512_mul_pd( a07.v, c07_2.v );
132  c07_3.v = _mm512_mul_pd( a07.v, c07_3.v );
133  c07_4.v = _mm512_mul_pd( a07.v, c07_4.v );
134  c07_5.v = _mm512_mul_pd( a07.v, c07_5.v );
135  c07_6.v = _mm512_mul_pd( a07.v, c07_6.v );
136  c07_7.v = _mm512_mul_pd( a07.v, c07_7.v );
137 
138  c15_0.v = _mm512_mul_pd( a07.v, c15_0.v );
139  c15_1.v = _mm512_mul_pd( a07.v, c15_1.v );
140  c15_2.v = _mm512_mul_pd( a07.v, c15_2.v );
141  c15_3.v = _mm512_mul_pd( a07.v, c15_3.v );
142  c15_4.v = _mm512_mul_pd( a07.v, c15_4.v );
143  c15_5.v = _mm512_mul_pd( a07.v, c15_5.v );
144  c15_6.v = _mm512_mul_pd( a07.v, c15_6.v );
145  c15_7.v = _mm512_mul_pd( a07.v, c15_7.v );
146 
147  c23_0.v = _mm512_mul_pd( a07.v, c23_0.v );
148  c23_1.v = _mm512_mul_pd( a07.v, c23_1.v );
149  c23_2.v = _mm512_mul_pd( a07.v, c23_2.v );
150  c23_3.v = _mm512_mul_pd( a07.v, c23_3.v );
151  c23_4.v = _mm512_mul_pd( a07.v, c23_4.v );
152  c23_5.v = _mm512_mul_pd( a07.v, c23_5.v );
153  c23_6.v = _mm512_mul_pd( a07.v, c23_6.v );
154  c23_7.v = _mm512_mul_pd( a07.v, c23_7.v );
155 
156  //printf( "-1/(2h^2)\n" );
157  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[0], c07_1.d[0], c07_2.d[0], c07_3.d[0], c07_4.d[0], c07_5.d[0], c07_6.d[0], c07_7.d[0] );
158  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[1], c07_1.d[1], c07_2.d[1], c07_3.d[1], c07_4.d[1], c07_5.d[1], c07_6.d[1], c07_7.d[1] );
159  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[2], c07_1.d[2], c07_2.d[2], c07_3.d[2], c07_4.d[2], c07_5.d[2], c07_6.d[2], c07_7.d[2] );
160  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[3], c07_1.d[3], c07_2.d[3], c07_3.d[3], c07_4.d[3], c07_5.d[3], c07_6.d[3], c07_7.d[3] );
161  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[4], c07_1.d[4], c07_2.d[4], c07_3.d[4], c07_4.d[4], c07_5.d[4], c07_6.d[4], c07_7.d[4] );
162  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[5], c07_1.d[5], c07_2.d[5], c07_3.d[5], c07_4.d[5], c07_5.d[5], c07_6.d[5], c07_7.d[5] );
163  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[6], c07_1.d[6], c07_2.d[6], c07_3.d[6], c07_4.d[6], c07_5.d[6], c07_6.d[6], c07_7.d[6] );
164  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[7], c07_1.d[7], c07_2.d[7], c07_3.d[7], c07_4.d[7], c07_5.d[7], c07_6.d[7], c07_7.d[7] );
165 
166 
167 
168 
169 
170  // Prefetch u, w
171  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( u ) );
172  __asm__ volatile( "prefetcht0 64(%0) \n\t" : :"r"( u ) );
173  __asm__ volatile( "prefetcht0 128(%0) \n\t" : :"r"( u ) );
174  __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( w ) );
175 
176  // c = exp( c )
177  if ( 0 )
178  {
179  #include "exp_int_d24x8.h"
180  }
181  else
182  {
183  c07_0.v = _mm512_exp_pd( c07_0.v );
184  c07_1.v = _mm512_exp_pd( c07_1.v );
185  c07_2.v = _mm512_exp_pd( c07_2.v );
186  c07_3.v = _mm512_exp_pd( c07_3.v );
187  c07_4.v = _mm512_exp_pd( c07_4.v );
188  c07_5.v = _mm512_exp_pd( c07_5.v );
189  c07_6.v = _mm512_exp_pd( c07_6.v );
190  c07_7.v = _mm512_exp_pd( c07_7.v );
191 
192  c15_0.v = _mm512_exp_pd( c15_0.v );
193  c15_1.v = _mm512_exp_pd( c15_1.v );
194  c15_2.v = _mm512_exp_pd( c15_2.v );
195  c15_3.v = _mm512_exp_pd( c15_3.v );
196  c15_4.v = _mm512_exp_pd( c15_4.v );
197  c15_5.v = _mm512_exp_pd( c15_5.v );
198  c15_6.v = _mm512_exp_pd( c15_6.v );
199  c15_7.v = _mm512_exp_pd( c15_7.v );
200 
201  c23_0.v = _mm512_exp_pd( c23_0.v );
202  c23_1.v = _mm512_exp_pd( c23_1.v );
203  c23_2.v = _mm512_exp_pd( c23_2.v );
204  c23_3.v = _mm512_exp_pd( c23_3.v );
205  c23_4.v = _mm512_exp_pd( c23_4.v );
206  c23_5.v = _mm512_exp_pd( c23_5.v );
207  c23_6.v = _mm512_exp_pd( c23_6.v );
208  c23_7.v = _mm512_exp_pd( c23_7.v );
209  }
210 
211  //printf( "exp\n" );
212  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[0], c07_1.d[0], c07_2.d[0], c07_3.d[0], c07_4.d[0], c07_5.d[0], c07_6.d[0], c07_7.d[0] );
213  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[1], c07_1.d[1], c07_2.d[1], c07_3.d[1], c07_4.d[1], c07_5.d[1], c07_6.d[1], c07_7.d[1] );
214  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[2], c07_1.d[2], c07_2.d[2], c07_3.d[2], c07_4.d[2], c07_5.d[2], c07_6.d[2], c07_7.d[2] );
215  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[3], c07_1.d[3], c07_2.d[3], c07_3.d[3], c07_4.d[3], c07_5.d[3], c07_6.d[3], c07_7.d[3] );
216  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[4], c07_1.d[4], c07_2.d[4], c07_3.d[4], c07_4.d[4], c07_5.d[4], c07_6.d[4], c07_7.d[4] );
217  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[5], c07_1.d[5], c07_2.d[5], c07_3.d[5], c07_4.d[5], c07_5.d[5], c07_6.d[5], c07_7.d[5] );
218  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[6], c07_1.d[6], c07_2.d[6], c07_3.d[6], c07_4.d[6], c07_5.d[6], c07_6.d[6], c07_7.d[6] );
219  //printf( "%lf, %lf, %lf, %lf, %lf, %lf, %lf, %lf\n", c07_0.d[7], c07_1.d[7], c07_2.d[7], c07_3.d[7], c07_4.d[7], c07_5.d[7], c07_6.d[7], c07_7.d[7] );
220 
221 
222 
223  // Preload u03, u47
224  a07.v = _mm512_load_pd( u );
225  a15.v = _mm512_load_pd( u + 8 );
226  a23.v = _mm512_load_pd( u + 16 );
227 
228  // Multiple rhs weighted sum.
229  #include<weighted_sum_int_d24x8.segment>
230 
231  //if ( u[ 0 ] != u[ 0 ] ) printf( "u[ 0 ] nan\n" );
232  //if ( u[ 1 ] != u[ 1 ] ) printf( "u[ 1 ] nan\n" );
233  //if ( u[ 2 ] != u[ 2 ] ) printf( "u[ 2 ] nan\n" );
234  //if ( u[ 3 ] != u[ 3 ] ) printf( "u[ 3 ] nan\n" );
235  //if ( u[ 4 ] != u[ 4 ] ) printf( "u[ 4 ] nan\n" );
236  //if ( u[ 5 ] != u[ 5 ] ) printf( "u[ 5 ] nan\n" );
237  //if ( u[ 6 ] != u[ 6 ] ) printf( "u[ 6 ] nan\n" );
238  //if ( u[ 7 ] != u[ 7 ] ) printf( "u[ 7 ] nan\n" );
239 
240  //if ( w[ 0 ] != w[ 0 ] ) printf( "w[ 0 ] nan\n" );
241  //if ( w[ 1 ] != w[ 1 ] ) printf( "w[ 1 ] nan\n" );
242  //if ( w[ 2 ] != w[ 2 ] ) printf( "w[ 2 ] nan\n" );
243  //if ( w[ 3 ] != w[ 3 ] ) printf( "w[ 3 ] nan\n" );
244  //if ( w[ 4 ] != w[ 4 ] ) printf( "w[ 4 ] nan\n" );
245  //if ( w[ 5 ] != w[ 5 ] ) printf( "w[ 5 ] nan\n" );
246  }
247 };
Definition: gaussian_d24x8.hpp:66
Definition: hmlp_internal.hpp:38
Definition: avx_type.h:4