HMLP: High-performance Machine Learning Primitives
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Pages
rank_k_d6x8.hpp
1 #include <stdio.h>
2 #include <hmlp_internal.hpp>
3 
4 // #define DEBUG_MICRO 1
5 
6 void bli_sgemm_opt_8x12
7 (
8  dim_t k,
9  float* restrict alpha,
10  float* restrict a,
11  float* restrict b,
12  float* restrict beta,
13  float* restrict c, inc_t rs_c, inc_t cs_c,
15 );
16 
17 void bli_dgemm_opt_6x8
18 (
19  dim_t k,
20  double* restrict alpha,
21  double* restrict a,
22  double* restrict b,
23  double* restrict beta,
24  double* restrict c, inc_t rs_c, inc_t cs_c,
26 );
27 
29 {
30  inline void operator()
31  (
32  int k,
33  double *a,
34  double *b,
35  double *c, int ldc,
37  ) const
38  {
39  double c_reg[ 8 * 4 ] = { 0.0 };
40 
41  for ( int p = 0; p < k; p ++ )
42  {
43  #pragma unroll
44  for ( int j = 0; j < 4; j ++ )
45  {
46  #pragma unroll
47  for ( int i = 0; i < 8; i ++ )
48  {
49  c_reg[ j * 8 + i ] += a[ p * 8 + i ] * b[ p * 4 + j ];
50  }
51  }
52  }
53 
54  if ( aux->pc )
55  {
56  #pragma unroll
57  for ( int j = 0; j < 4; j ++ )
58  {
59  #pragma unroll
60  for ( int i = 0; i < 8; i ++ )
61  {
62  c[ j * ldc + i ] += c_reg[ j * 8 + i ];
63  }
64  }
65  }
66  else
67  {
68  #pragma unroll
69  for ( int j = 0; j < 4; j ++ )
70  {
71  #pragma unroll
72  for ( int i = 0; i < 8; i ++ )
73  {
74  c[ j * ldc + i ] = c_reg[ j * 8 + i ];
75  }
76  }
77  }
78 
79 #ifdef DEBUG_MICRO
80  printf( "rank_k_ref_d8x4:" );
81  for ( int i = 0; i < 8; i ++ )
82  {
83  for ( int j = 0; j < 4; j ++ )
84  {
85  printf( "%E ", c[ j * ldc + i ] );
86  }
87  printf( "\n" );
88  }
89 #endif
90  }
91 };
92 
93 
95 {
96  // Strassen interface
97  inline void operator()
98  (
99  int k,
100  float *a,
101  float *b,
102  int len,
103  float **c, int ldc, float *alpha,
105  ) const
106  {
107  float c_reg[ 8 * 12 ] = { 0.0 };
108 
109  for ( int p = 0; p < k; p ++ )
110  {
111  #pragma unroll
112  for ( int j = 0; j < 12; j ++ )
113  {
114  #pragma unroll
115  for ( int i = 0; i < 8; i ++ )
116  {
117  c_reg[ j * 8 + i ] += a[ p * 8 + i ] * b[ p * 12 + j ];
118  }
119  }
120  }
121 
122  for ( int t = 0; t < len; t ++ )
123  {
124  #pragma unroll
125  for ( int j = 0; j < 12; j ++ )
126  {
127  #pragma unroll
128  for ( int i = 0; i < 8; i ++ )
129  {
130  c[ t ][ j * ldc + i ] += alpha[ t ] * c_reg[ j * 8 + i ];
131  }
132  }
133  }
134  }; // end inline void operator()
135 
136  inline void operator()
137  (
138  int k,
139  float *a,
140  float *b,
141  float *c, int ldc,
143  ) const
144  {
145  float alpha = 1.0;
146  float beta = aux->pc ? 1.0 : 0.0;
147  bli_sgemm_opt_8x12
148  (
149  k,
150  &alpha,
151  a,
152  b,
153  &beta,
154  c, 1, ldc,
155  aux
156  );
157  }; // end inline void operator()
158 }; // end struct rank_k_asm_s8x12
159 
160 
161 
162 
163 
165 {
166  // Strassen interface
167  inline void operator()
168  (
169  int k,
170  double *a,
171  double *b,
172  int len,
173  double **c, int ldc, double *alpha,
175  ) const
176  {
177  double c_reg[ 6 * 8 ] = { 0.0 };
178 
179  for ( int p = 0; p < k; p ++ )
180  {
181  #pragma unroll
182  for ( int j = 0; j < 8; j ++ )
183  {
184  #pragma unroll
185  for ( int i = 0; i < 6; i ++ )
186  {
187  c_reg[ j * 8 + i ] += a[ p * 6 + i ] * b[ p * 8 + j ];
188  }
189  }
190  }
191 
192  for ( int t = 0; t < len; t ++ )
193  {
194  #pragma unroll
195  for ( int j = 0; j < 8; j ++ )
196  {
197  #pragma unroll
198  for ( int i = 0; i < 6; i ++ )
199  {
200  c[ t ][ j * ldc + i ] += alpha[ t ] * c_reg[ j * 6 + i ];
201  }
202  }
203  }
204  }; // end inline void operator()
205 
206  inline void operator()
207  (
208  int k,
209  double *a,
210  double *b,
211  double *c, int ldc,
213  ) const
214  {
215  double alpha = 1.0;
216  double beta = aux->pc ? 1.0 : 0.0;
217  bli_dgemm_opt_6x8
218  (
219  k,
220  &alpha,
221  a,
222  b,
223  &beta,
224  c, 1, ldc,
225  aux
226  );
227  }; // end inline void operator()
228 
229 }; // end struct rank_k_asm_d8x4
Definition: rank_k_d6x8.hpp:94
Definition: hmlp_internal.hpp:38
Definition: rank_k_d6x8.hpp:164
Definition: rank_k_d6x8.hpp:28