HMLP: High-performance Machine Learning Primitives
rank_k_d4x4.hpp
1 #include <stdio.h>
2 #include <hmlp_internal.hpp>
3 
4 // #define DEBUG_MICRO 1
5 
6 void bli_sgemm_opt_4x4
7 (
8  dim_t k,
9  float* restrict alpha,
10  float* restrict a,
11  float* restrict b,
12  float* restrict beta,
13  float* restrict c, inc_t rs_c, inc_t cs_c,
15 );
16 
17 void bli_dgemm_opt_4x4
18 (
19  dim_t k,
20  double* restrict alpha,
21  double* restrict a,
22  double* restrict b,
23  double* restrict beta,
24  double* restrict c, inc_t rs_c, inc_t cs_c,
26 );
27 
28 
29 //struct rank_k_ref_d8x4
30 //{
31 // inline void operator()
32 // (
33 // int k,
34 // double *a,
35 // double *b,
36 // double *c, int ldc,
37 // aux_s<double, double, double, double> *aux
38 // ) const
39 // {
40 // double c_reg[ 8 * 4 ] = { 0.0 };
41 //
42 // for ( int p = 0; p < k; p ++ )
43 // {
44 // #pragma unroll
45 // for ( int j = 0; j < 4; j ++ )
46 // {
47 // #pragma unroll
48 // for ( int i = 0; i < 8; i ++ )
49 // {
50 // c_reg[ j * 8 + i ] += a[ p * 8 + i ] * b[ p * 4 + j ];
51 // }
52 // }
53 // }
54 //
55 // if ( aux->pc )
56 // {
57 // #pragma unroll
58 // for ( int j = 0; j < 4; j ++ )
59 // {
60 // #pragma unroll
61 // for ( int i = 0; i < 8; i ++ )
62 // {
63 // c[ j * ldc + i ] += c_reg[ j * 8 + i ];
64 // }
65 // }
66 // }
67 // else
68 // {
69 // #pragma unroll
70 // for ( int j = 0; j < 4; j ++ )
71 // {
72 // #pragma unroll
73 // for ( int i = 0; i < 8; i ++ )
74 // {
75 // c[ j * ldc + i ] = c_reg[ j * 8 + i ];
76 // }
77 // }
78 // }
79 //
80 //#ifdef DEBUG_MICRO
81 // printf( "rank_k_ref_d8x4:" );
82 // for ( int i = 0; i < 8; i ++ )
83 // {
84 // for ( int j = 0; j < 4; j ++ )
85 // {
86 // printf( "%E ", c[ j * ldc + i ] );
87 // }
88 // printf( "\n" );
89 // }
90 //#endif
91 // }
92 //};
93 
94 
95 
97 {
98  // Strassen interface
99  inline void operator()
100  (
101  int k,
102  float *a,
103  float *b,
104  int len,
105  float **c, int ldc, float *alpha,
107  ) const
108  {
109  float c_reg[ 4 * 4 ] = { 0.0 };
110 
111  for ( int p = 0; p < k; p ++ )
112  {
113  #pragma unroll
114  for ( int j = 0; j < 4; j ++ )
115  {
116  #pragma unroll
117  for ( int i = 0; i < 4; i ++ )
118  {
119  c_reg[ j * 4 + i ] += a[ p * 4 + i ] * b[ p * 4 + j ];
120  }
121  }
122  }
123 
124  for ( int t = 0; t < len; t ++ )
125  {
126  #pragma unroll
127  for ( int j = 0; j < 4; j ++ )
128  {
129  #pragma unroll
130  for ( int i = 0; i < 4; i ++ )
131  {
132  c[ t ][ j * ldc + i ] += alpha[ t ] * c_reg[ j * 4 + i ];
133  }
134  }
135  }
136  }; // end inline void operator()
137 
138  inline void operator()
139  (
140  int k,
141  float *a,
142  float *b,
143  float *c, int ldc,
145  ) const
146  {
147  float alpha = 1.0;
148  float beta = aux->pc ? 1.0 : 0.0;
149  bli_sgemm_opt_4x4
150  (
151  k,
152  &alpha,
153  a,
154  b,
155  &beta,
156  c, 1, ldc,
157  aux
158  );
159  }; // end inline void operator()
160 
161 }; // end struct rank_k_asm_d4x4
162 
163 
164 
165 
166 
167 
168 
169 
170 
171 
172 
173 
175 {
176  // Strassen interface
177  inline void operator()
178  (
179  int k,
180  double *a,
181  double *b,
182  int len,
183  double **c, int ldc, double *alpha,
185  ) const
186  {
187  double c_reg[ 4 * 4 ] = { 0.0 };
188 
189  for ( int p = 0; p < k; p ++ )
190  {
191  #pragma unroll
192  for ( int j = 0; j < 4; j ++ )
193  {
194  #pragma unroll
195  for ( int i = 0; i < 4; i ++ )
196  {
197  c_reg[ j * 4 + i ] += a[ p * 4 + i ] * b[ p * 4 + j ];
198  }
199  }
200  }
201 
202  for ( int t = 0; t < len; t ++ )
203  {
204  #pragma unroll
205  for ( int j = 0; j < 4; j ++ )
206  {
207  #pragma unroll
208  for ( int i = 0; i < 4; i ++ )
209  {
210  c[ t ][ j * ldc + i ] += alpha[ t ] * c_reg[ j * 4 + i ];
211  }
212  }
213  }
214  }; // end inline void operator()
215 
216  inline void operator()
217  (
218  int k,
219  double *a,
220  double *b,
221  double *c, int ldc,
223  ) const
224  {
225  double alpha = 1.0;
226  double beta = aux->pc ? 1.0 : 0.0;
227  bli_dgemm_opt_4x4
228  (
229  k,
230  &alpha,
231  a,
232  b,
233  &beta,
234  c, 1, ldc,
235  aux
236  );
237  }; // end inline void operator()
238 
239 }; // end struct rank_k_asm_d4x4
Definition: rank_k_d4x4.hpp:174
Definition: hmlp_internal.hpp:38
Definition: rank_k_d4x4.hpp:96