HMLP: High-performance Machine Learning Primitives
rank_k_d6x32.hpp
1 #include <stdio.h>
2 #include <hmlp_internal.hpp>
3 
4 
6 BLIS_GEMM_KERNEL(bli_sgemm_opt_12x32_l2,float);
7 BLIS_GEMM_KERNEL(bli_dgemm_opt_6x32_l2,double);
8 
9 
10 struct rank_k_opt_s12x32
11 {
12  const static size_t mr = 32;
13  const static size_t nr = 12;
14  const static size_t pack_mr = 32;
15  const static size_t pack_nr = 12;
16  const static size_t align_size = 64;
17  const static bool row_major = true;
18 
19  inline STRA_OPERATOR(float) const
20  {
21  printf( "no STRA_OPERATOR implementation\n" );
22  exit( 1 );
23  };
24 
25  inline GEMM_OPERATOR(float) const
26  {
27  //printf( "bli_sgemm_opt_12x32\n" ); fflush( stdout );
28  float alpha = 1.0;
30  float beta = aux->pc ? 1.0 : 0.0;
32  //bli_sgemm_opt_12x32_l2
33  //(
34  // k,
35  // &alpha,
36  // a,
37  // b,
38  // &beta,
39  // c, rs_c, cs_c,
40  // aux
41  //);
42  bli_sgemm_opt_12x32_l2
43  (
44  k,
45  &alpha,
46  b,
47  a,
48  &beta,
49  c, cs_c, rs_c,
50  aux
51  );
52  };
53 
54  template<typename TC>
55  inline void operator()
56  (
57  dim_t k,
58  float *a,
59  float *b,
60  TC *c,
61  float *v, inc_t rs_v, inc_t cs_v,
63  )
64  {
65  //printf( "bli_sgemm_opt_12x32 case2\n" ); fflush( stdout );
66  float alpha = 1.0;
68  float beta = aux->pc ? 1.0 : 0.0;
69 
71  float vtmp[ mr * nr ];
72 
73  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, float, float>>::value )
74  {
75  if ( aux->pc )
76  {
77  for ( size_t j = 0; j < aux->jb; j ++ )
78  for ( size_t i = 0; i < aux->ib; i ++ )
79  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
80  }
81 
82  v = vtmp;
83  rs_v = 1;
84  cs_v = mr;
85  }
86 
88  //bli_sgemm_opt_12x32_l2
89  //(
90  // k,
91  // &alpha,
92  // a,
93  // b,
94  // &beta,
95  // v, rs_v, cs_v,
96  // reinterpret_cast<aux_s<float, float, float, float>*>( aux )
97  //);
98  bli_sgemm_opt_12x32_l2
99  (
100  k,
101  &alpha,
102  b,
103  a,
104  &beta,
105  v, cs_v, rs_v,
106  reinterpret_cast<aux_s<float, float, float, float>*>( aux )
107  );
108 
113  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, float, float>>::value ||
114  aux->ib != mr || aux->jb != nr )
115  {
116  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
117  }
118 
119  };
120 
121 
122 };
126 {
127  //const static size_t mr = 6;
128  //const static size_t nr = 32;
129  //const static size_t pack_mr = 6;
130  //const static size_t pack_nr = 32;
131  const static size_t mr = 32;
132  const static size_t nr = 6;
133  const static size_t pack_mr = 32;
134  const static size_t pack_nr = 6;
135  const static size_t align_size = 64;
136  const static bool row_major = true;
137 
138  inline STRA_OPERATOR(double) const
139  {
140  printf( "no STRA_OPERATOR implementation\n" );
141  exit( 1 );
142  };
143 
144  inline GEMM_OPERATOR(double) const
145  {
146  //printf( "bli_dgemm_opt_6x32_l2\n" ); fflush( stdout );
147  double alpha = 1.0;
149  double beta = aux->pc ? 1.0 : 0.0;
151  //bli_dgemm_opt_6x32_l2
152  //(
153  // k,
154  // &alpha,
155  // a,
156  // b,
157  // &beta,
158  // c, rs_c, cs_c,
159  // aux
160  //);
161  bli_dgemm_opt_6x32_l2
162  (
163  k,
164  &alpha,
165  b,
166  a,
167  &beta,
168  c, cs_c, rs_c,
169  aux
170  );
171  };
172 
173  template<typename TC>
174  inline void operator()
175  (
176  dim_t k,
177  double *a,
178  double *b,
179  TC *c,
180  double *v, inc_t rs_v, inc_t cs_v,
182  )
183  {
184  //printf( "bli_dgemm_opt_6x32_l2 case2\n" ); fflush( stdout );
185  double alpha = 1.0;
187  double beta = aux->pc ? 1.0 : 0.0;
188 
190  double vtmp[ mr * nr ];
191 
192  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value )
193  {
194  if ( aux->pc )
195  {
196  for ( size_t j = 0; j < aux->jb; j ++ )
197  for ( size_t i = 0; i < aux->ib; i ++ )
198  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
199  }
200 
201  v = vtmp;
202  rs_v = 1;
203  cs_v = mr;
204  }
205 
207  //bli_dgemm_opt_6x32_l2
208  //(
209  // k,
210  // &alpha,
211  // a,
212  // b,
213  // &beta,
214  // v, rs_v, cs_v,
215  // reinterpret_cast<aux_s<double, double, double, double>*>( aux )
216  //);
217  bli_dgemm_opt_6x32_l2
218  (
219  k,
220  &alpha,
221  b,
222  a,
223  &beta,
224  v, cs_v, rs_v,
225  reinterpret_cast<aux_s<double, double, double, double>*>( aux )
226  );
227 
232  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value ||
233  aux->ib != mr || aux->jb != nr )
234  {
235  //printf( "bug %d %d %d %d %d %d\n", aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb );
236  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
237  }
238 
239  };
240 
241 };
Definition: rank_k_d12x16.hpp:10
Definition: rank_k_d6x32.hpp:125
GEMM_OPERATOR(double) const
Definition: rank_k_d6x32.hpp:144
Definition: hmlp_internal.hpp:38
Definition: packing.hpp:198
GEMM_OPERATOR(float) const
Definition: rank_k_d6x32.hpp:25