HMLP: High-performance Machine Learning Primitives
gsks_d8x4.hpp
#include <stdio.h>
#include <math.h>

#include <avx_type.h>
#include <hmlp.h>
#include <hmlp_internal.hpp>

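// Micro-kernels for HMLP's GSKS primitive (kernel summation). This file
// provides an 8x4 double-precision kernel written with 256-bit AVX
// intrinsics, plus a single-precision stub that is not implemented yet.
// The computation below (squared distances -> exp -> weighted sum)
// corresponds to a Gaussian kernel summation.
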
// Single-precision kernel: currently an unimplemented stub. (Struct name
// assumed from HMLP's naming convention.)
struct gsks_gaussian_int_s8x4
{
  inline GSKS_OPERATOR(float) const
  {
    printf( "not implemented yet\n" );
    exit( 1 );
  };
};

// Double-precision 8x4 Gaussian kernel summation micro-kernel. (Struct name
// assumed from HMLP's naming convention.)
struct gsks_gaussian_int_d8x4
{
  const size_t mr         = 8;
  const size_t nr         = 4;
  const size_t pack_mr    = 8;
  const size_t pack_nr    = 4;
  const size_t align_size = 32;
  const bool   row_major  = false;

  // The commented block below documents the signature expanded by
  // GSKS_OPERATOR(double):
  //
  //inline void operator()
  //(
  //  kernel_s<double> *ker,
  //  int k,
  //  int rhs,
  //  double *u,
  //  double *a, double *aa,
  //  double *b, double *bb,
  //  double *w,
  //  double *c, int ldc,
  //  aux_s<double, double, double, double> *aux
  //) const

  inline GSKS_OPERATOR(double) const
  {
    int i, rhs_left;
    double neg2  = -2.0;
    double dzero =  0.0;
    double alpha = ker->scal;

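    // Register blocking: c03_j holds rows 0..3 of output column j, and
    // c47_j holds rows 4..7, so the eight accumulators cover the full
    // 8x4 micro-tile.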
    v4df_t c03_0, c03_1, c03_2, c03_3;
    v4df_t c47_0, c47_1, c47_2, c47_3;
    v4df_t tmpc03_0, tmpc03_1, tmpc03_2, tmpc03_3;
    v4df_t tmpc47_0, tmpc47_1, tmpc47_2, tmpc47_3;
    v4df_t u03, u47;
    v4df_t a03, a47, A03, A47; // prefetched A
    v4df_t b0, b1, b2, b3, B0; // prefetched B
    v4df_t c_tmp, aa_tmp, bb_tmp, w_tmp;

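    // Rank-k update: the included component accumulates the 8x4 block of
    // inner products a_i . b_j into the c03_*/c47_* registers.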
    #include "component/rank_k_int_d8x4.hpp"

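    // Prefetch aa and bb; per the GSKS packing convention these presumably
    // hold the squared norms of the 8 rows of A and the 4 columns of B.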
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( aa ) );
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( bb ) );

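    // When the k dimension is blocked (aux->pc != 0), fold in the partial
    // inner products that earlier panels stored in c.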
    if ( aux->pc )
    {
      tmpc03_0.v = _mm256_load_pd( (double*)( c ) );
      tmpc47_0.v = _mm256_load_pd( (double*)( c + 4 ) );

      tmpc03_1.v = _mm256_load_pd( (double*)( c + 1 * ldc ) );
      tmpc47_1.v = _mm256_load_pd( (double*)( c + 1 * ldc + 4 ) );

      tmpc03_2.v = _mm256_load_pd( (double*)( c + 2 * ldc ) );
      tmpc47_2.v = _mm256_load_pd( (double*)( c + 2 * ldc + 4 ) );

      tmpc03_3.v = _mm256_load_pd( (double*)( c + 3 * ldc ) );
      tmpc47_3.v = _mm256_load_pd( (double*)( c + 3 * ldc + 4 ) );

      c03_0.v = _mm256_add_pd( tmpc03_0.v, c03_0.v );
      c47_0.v = _mm256_add_pd( tmpc47_0.v, c47_0.v );

      c03_1.v = _mm256_add_pd( tmpc03_1.v, c03_1.v );
      c47_1.v = _mm256_add_pd( tmpc47_1.v, c47_1.v );

      c03_2.v = _mm256_add_pd( tmpc03_2.v, c03_2.v );
      c47_2.v = _mm256_add_pd( tmpc47_2.v, c47_2.v );

      c03_3.v = _mm256_add_pd( tmpc03_3.v, c03_3.v );
      c47_3.v = _mm256_add_pd( tmpc47_3.v, c47_3.v );
    }

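    // Assemble squared pairwise distances from the inner products using
    //   ||a_i - b_j||^2 = ||a_i||^2 - 2 a_i.b_j + ||b_j||^2.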
    // Scale by -2: c = -2 * ( a_i . b_j )
    aa_tmp.v = _mm256_broadcast_sd( &neg2 );
    c03_0.v  = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c03_1.v  = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c03_2.v  = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c03_3.v  = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_0.v  = _mm256_mul_pd( aa_tmp.v, c47_0.v );
    c47_1.v  = _mm256_mul_pd( aa_tmp.v, c47_1.v );
    c47_2.v  = _mm256_mul_pd( aa_tmp.v, c47_2.v );
    c47_3.v  = _mm256_mul_pd( aa_tmp.v, c47_3.v );

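    // Add the row squared norms aa[0..7] (the same vector for every column).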
    aa_tmp.v = _mm256_load_pd( (double*)aa );
    c03_0.v  = _mm256_add_pd( aa_tmp.v, c03_0.v );
    c03_1.v  = _mm256_add_pd( aa_tmp.v, c03_1.v );
    c03_2.v  = _mm256_add_pd( aa_tmp.v, c03_2.v );
    c03_3.v  = _mm256_add_pd( aa_tmp.v, c03_3.v );

    aa_tmp.v = _mm256_load_pd( (double*)( aa + 4 ) );
    c47_0.v  = _mm256_add_pd( aa_tmp.v, c47_0.v );
    c47_1.v  = _mm256_add_pd( aa_tmp.v, c47_1.v );
    c47_2.v  = _mm256_add_pd( aa_tmp.v, c47_2.v );
    c47_3.v  = _mm256_add_pd( aa_tmp.v, c47_3.v );

    // Prefetch u
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( u ) );

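    // Add the column squared norms bb[0..3], one broadcast per column;
    // after this, c03_*/c47_* hold the squared pairwise distances.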
    bb_tmp.v = _mm256_broadcast_sd( (double*)bb );
    c03_0.v  = _mm256_add_pd( bb_tmp.v, c03_0.v );
    c47_0.v  = _mm256_add_pd( bb_tmp.v, c47_0.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 1 ) );
    c03_1.v  = _mm256_add_pd( bb_tmp.v, c03_1.v );
    c47_1.v  = _mm256_add_pd( bb_tmp.v, c47_1.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 2 ) );
    c03_2.v  = _mm256_add_pd( bb_tmp.v, c03_2.v );
    c47_2.v  = _mm256_add_pd( bb_tmp.v, c47_2.v );

    bb_tmp.v = _mm256_broadcast_sd( (double*)( bb + 3 ) );
    c03_3.v  = _mm256_add_pd( bb_tmp.v, c03_3.v );
    c47_3.v  = _mm256_add_pd( bb_tmp.v, c47_3.v );

    // Clamp any negative distances produced by floating-point round-off.
    c_tmp.v = _mm256_broadcast_sd( &dzero );
    c03_0.v = _mm256_max_pd( c_tmp.v, c03_0.v );
    c03_1.v = _mm256_max_pd( c_tmp.v, c03_1.v );
    c03_2.v = _mm256_max_pd( c_tmp.v, c03_2.v );
    c03_3.v = _mm256_max_pd( c_tmp.v, c03_3.v );
    c47_0.v = _mm256_max_pd( c_tmp.v, c47_0.v );
    c47_1.v = _mm256_max_pd( c_tmp.v, c47_1.v );
    c47_2.v = _mm256_max_pd( c_tmp.v, c47_2.v );
    c47_3.v = _mm256_max_pd( c_tmp.v, c47_3.v );

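    // Scale the distances by alpha = ker->scal; for a Gaussian kernel this
    // is typically -1 / (2 h^2), so the exp below yields
    // exp( -||a_i - b_j||^2 / (2 h^2) ).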
    aa_tmp.v = _mm256_broadcast_sd( &alpha );
    c03_0.v  = _mm256_mul_pd( aa_tmp.v, c03_0.v );
    c03_1.v  = _mm256_mul_pd( aa_tmp.v, c03_1.v );
    c03_2.v  = _mm256_mul_pd( aa_tmp.v, c03_2.v );
    c03_3.v  = _mm256_mul_pd( aa_tmp.v, c03_3.v );
    c47_0.v  = _mm256_mul_pd( aa_tmp.v, c47_0.v );
    c47_1.v  = _mm256_mul_pd( aa_tmp.v, c47_1.v );
    c47_2.v  = _mm256_mul_pd( aa_tmp.v, c47_2.v );
    c47_3.v  = _mm256_mul_pd( aa_tmp.v, c47_3.v );

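    // Load the current output accumulators u[0..7]; the next micro-panel
    // of u and the weights w are prefetched below.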
    u03.v = _mm256_load_pd( (double*)u );
    u47.v = _mm256_load_pd( (double*)( u + 4 ) );

    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( u + 8 ) );
    __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"( w ) );

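    // Vectorized elementwise exponential: c = exp( c ).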
    #include "component/exp_int_d8x4.hpp"

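    // Weighted sum: u03/u47 += c * w, i.e. u[i] += sum_j K(a_i, b_j) w[j].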
    #include "component/weighted_sum_int_d8x4.hpp"

  };
};