HMLP: High-performance Machine Learning Primitives
rank_k_d12x16.hpp
1 #include <stdio.h>
2 #include <hmlp_internal.hpp>
3 #include <packing.hpp>
4 
6 BLIS_GEMM_KERNEL(bli_sgemm_opt_12x32_l2,float);
7 BLIS_GEMM_KERNEL(bli_dgemm_opt_12x16_l2,double);
8 
9 
11 {
12  //const static size_t mr = 12;
13  //const static size_t nr = 32;
14  //const static size_t pack_mr = 12;
15  //const static size_t pack_nr = 32;
16  const static size_t mr = 32;
17  const static size_t nr = 12;
18  const static size_t pack_mr = 32;
19  const static size_t pack_nr = 12;
20  const static size_t align_size = 64;
21  const static bool row_major = true;
22 
23  inline STRA_OPERATOR(float) const
24  {
25  printf( "no STRA_OPERATOR implementation\n" );
26  exit( 1 );
27  };
28 
29  inline GEMM_OPERATOR(float) const
30  {
31  float alpha = 1.0;
33  float beta = aux->pc ? 1.0 : 0.0;
35  //bli_sgemm_opt_12x32_l2
36  //(
37  // k,
38  // &alpha,
39  // a,
40  // b,
41  // &beta,
42  // c, rs_c, cs_c,
43  // aux
44  //);
45  bli_sgemm_opt_12x32_l2
46  (
47  k,
48  &alpha,
49  b,
50  a,
51  &beta,
52  c, cs_c, rs_c,
53  aux
54  );
55  };
56 
57  template<typename TC>
58  inline void operator()
59  (
60  dim_t k,
61  float *a,
62  float *b,
63  TC *c,
64  float *v, inc_t rs_v, inc_t cs_v,
66  )
67  {
68  float alpha = 1.0;
70  float beta = aux->pc ? 1.0 : 0.0;
71 
73  float vtmp[ mr * nr ];
74 
75  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, float, float>>::value )
76  {
77  if ( aux->pc )
78  {
79  for ( size_t j = 0; j < aux->jb; j ++ )
80  for ( size_t i = 0; i < aux->ib; i ++ )
81  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
82  }
83 
84  v = vtmp;
85  rs_v = 1;
86  cs_v = mr;
87  }
88 
90  //bli_sgemm_opt_12x32_l2
91  //(
92  // k,
93  // &alpha,
94  // a,
95  // b,
96  // &beta,
97  // v, rs_v, cs_v,
98  // reinterpret_cast<aux_s<float, float, float, float>*>( aux )
99  //);
100  bli_sgemm_opt_12x32_l2
101  (
102  k,
103  &alpha,
104  b,
105  a,
106  &beta,
107  v, cs_v, rs_v,
108  reinterpret_cast<aux_s<float, float, float, float>*>( aux )
109  );
110 
115  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, float, float>>::value ||
116  aux->ib != mr || aux->jb != nr )
117  {
118  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
119  }
120 
121  };
122 
123 
124 };
128 {
129  //const static size_t mr = 12;
130  //const static size_t nr = 16;
131  //const static size_t pack_mr = 12;
132  //const static size_t pack_nr = 16;
133  const static size_t mr = 16;
134  const static size_t nr = 12;
135  const static size_t pack_mr = 16;
136  const static size_t pack_nr = 12;
137  const static size_t align_size = 64;
138  const static bool row_major = true;
139 
140  inline STRA_OPERATOR(double) const
141  {
142  printf( "no STRA_OPERATOR implementation\n" );
143  exit( 1 );
144  };
145 
146  inline GEMM_OPERATOR(double) const
147  {
148  double alpha = 1.0;
150  double beta = aux->pc ? 1.0 : 0.0;
152  //bli_dgemm_opt_6x32_l2
153  //(
154  // k,
155  // &alpha,
156  // a,
157  // b,
158  // &beta,
159  // c, rs_c, cs_c,
160  // aux
161  //);
162  bli_dgemm_opt_12x16_l2
163  (
164  k,
165  &alpha,
166  b,
167  a,
168  &beta,
169  c, cs_c, rs_c,
170  aux
171  );
172  };
173 
174  template<typename TC>
175  inline void operator()
176  (
177  dim_t k,
178  double *a,
179  double *b,
180  TC *c,
181  double *v, inc_t rs_v, inc_t cs_v,
183  )
184  {
185  double alpha = 1.0;
187  double beta = aux->pc ? 1.0 : 0.0;
188 
190  double vtmp[ mr * nr ];
191 
192  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value )
193  {
194  if ( aux->pc )
195  {
196  for ( size_t j = 0; j < aux->jb; j ++ )
197  for ( size_t i = 0; i < aux->ib; i ++ )
198  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
199  }
200 
201  v = vtmp;
202  rs_v = 1;
203  cs_v = mr;
204  }
205 
207  //bli_dgemm_opt_12x16_l2
208  //(
209  // k,
210  // &alpha,
211  // a,
212  // b,
213  // &beta,
214  // v, rs_v, cs_v,
215  // reinterpret_cast<aux_s<double, double, double, double>*>( aux )
216  //);
217  bli_dgemm_opt_12x16_l2
218  (
219  k,
220  &alpha,
221  b,
222  a,
223  &beta,
224  v, cs_v, rs_v,
225  reinterpret_cast<aux_s<double, double, double, double>*>( aux )
226  );
227 
232  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value ||
233  aux->ib != mr || aux->jb != nr )
234  {
235  //printf( "bug %d %d %d %d %d %d\n", aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb );
236  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
237  }
238 
239  };
240 
241 };
GEMM_OPERATOR(double) const
Definition: rank_k_d12x16.hpp:146
Definition: rank_k_d12x16.hpp:10
Definition: hmlp_internal.hpp:38
Definition: packing.hpp:198
GEMM_OPERATOR(float) const
Definition: rank_k_d12x16.hpp:29
Definition: rank_k_d12x16.hpp:127