HMLP: High-performance Machine Learning Primitives
rank_k_d8x6.hpp
1 #include <stdio.h>
2 #include <hmlp_internal.hpp>
3 #include <packing.hpp>
4 
// Declarations for the two micro-kernels invoked further down in this file:
// bli_sgemm_asm_16x6 (float) and bli_dgemm_asm_8x6 (double).
// NOTE(review): BLIS_GEMM_KERNEL is defined outside this file; presumably it
// expands to a function prototype for the named kernel -- confirm.
6 BLIS_GEMM_KERNEL(bli_sgemm_asm_16x6,float);
7 BLIS_GEMM_KERNEL(bli_dgemm_asm_8x6,double);
8 
// NOTE(review): this is a using-directive at file scope inside a header
// (.hpp), so every translation unit that includes this file also gets the
// hmlp names made visible for unqualified lookup. The code below relies on
// unqualified names (e.g. MatrixLike, aux_s), so removing it needs care.
9 using namespace hmlp;
10 
12 {
// Tile shape of the single-precision kernel: operator() sizes its scratch
// tile as mr * nr, uses mr as the scratch column stride, and compares
// aux->ib / aux->jb against mr / nr to detect partial tiles.
static const size_t mr = 16;
static const size_t nr = 6;
// pack_mr is the first template argument of the MatrixLike type that
// operator() tests TC against.
static const size_t pack_mr = 16;
// pack_nr, align_size and row_major are not referenced in the visible part
// of this file; presumably they are read by the packing / driver code.
// NOTE(review): align_size is presumably in bytes -- confirm.
static const size_t pack_nr = 6;
static const size_t align_size = 32;
static const bool row_major = false;
19 
20  inline STRA_OPERATOR(float) const
21  {
22  printf( "no STRA_OPERATOR implementation\n" );
23  exit( 1 );
24  };
25 
26  inline GEMM_OPERATOR(float) const
27  {
28  float alpha = 1.0;
30  float beta = aux->pc ? 1.0 : 0.0;
32  bli_sgemm_asm_16x6( k, &alpha, a, b, &beta, c, rs_c, cs_c, aux );
33  };
34 
35 
/**
 *  Templated float micro-kernel wrapper. Runs bli_sgemm_asm_16x6 on a tile
 *  addressed through v (or on a local scratch tile) and, when required,
 *  hands the resulting tile to c->Unpack().
 *
 *  k              forwarded unchanged to the kernel
 *  a, b           forwarded unchanged to the kernel
 *  c              object whose Unpack() receives the computed tile
 *  v, rs_v, cs_v  tile pointer with its row / column strides
 *
 *  NOTE(review): `aux` is used throughout the body but is not declared in
 *  the parameter list as it appears here, and the list ends with a trailing
 *  comma. The embedded listing numbers jump from 43 to 45, so the line that
 *  declares `aux` appears to be missing from this listing -- restore it from
 *  the original header before compiling.
 */
36  template<typename TC>
37  inline void operator()
38  (
39  dim_t k,
40  float *a,
41  float *b,
42  TC *c,
43  float *v, inc_t rs_v, inc_t cs_v,
45  )
46  {
// Product scale is always 1; the tile scale is 1 when aux->pc is non-zero
// and 0 otherwise.
47  float alpha = 1.0;
49  float beta = aux->pc ? 1.0 : 0.0;
50 
// Local mr-by-nr scratch tile, indexed below with unit row stride and
// column stride mr.
52  float vtmp[ mr * nr ];
53 
// When TC is anything other than MatrixLike<pack_mr, float, float>, the
// kernel works on the scratch tile instead of the caller's v.
// NOTE(review): is_same<...>::value is a compile-time constant tested with
// an ordinary `if`, so both paths are compiled for every TC.
54  if ( !is_same<TC, MatrixLike<pack_mr, float, float>>::value )
55  {
// With aux->pc set (beta == 1), copy the aux->ib by aux->jb entries of v
// into the scratch tile first. With aux->pc clear the scratch tile stays
// uninitialized.
// NOTE(review): that presumably relies on the kernel not reading the tile
// when beta == 0 -- confirm.
56  if ( aux->pc )
57  {
58  for ( size_t j = 0; j < aux->jb; j ++ )
59  for ( size_t i = 0; i < aux->ib; i ++ )
60  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
61  }
62 
// From here on v and its strides describe the scratch tile.
63  v = vtmp;
64  rs_v = 1;
65  cs_v = mr;
66  }
67 
// The kernel receives aux reinterpreted as aux_s<float, float, float, float>*.
// NOTE(review): this assumes the real type of aux is layout-compatible with
// that instantiation -- confirm.
69  bli_sgemm_asm_16x6
70  (
71  k,
72  &alpha,
73  a,
74  b,
75  &beta,
76  v, rs_v, cs_v,
77  reinterpret_cast<aux_s<float, float, float, float>*>( aux )
78  );
79 
// Unpack is called when the scratch tile was used (TC is not the MatrixLike
// type) or when the tile is partial (aux->ib != mr or aux->jb != nr).
84  if ( !is_same<TC, MatrixLike<pack_mr, float, float>>::value ||
85  aux->ib != mr || aux->jb != nr )
86  {
87  //printf( "bug %d %d %d %d %d %d\n", aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb );
88  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
89  }
90 
91  };
92 
93 };
97 {
// Tile shape of the double-precision kernel: operator() sizes its scratch
// tile as mr * nr, uses mr as the scratch column stride, and compares
// aux->ib / aux->jb against mr / nr to detect partial tiles.
static const size_t mr = 8;
static const size_t nr = 6;
// pack_mr is the first template argument of the MatrixLike type that
// operator() tests TC against.
static const size_t pack_mr = 8;
// pack_nr, align_size and row_major are not referenced in the visible part
// of this file; presumably they are read by the packing / driver code.
// NOTE(review): align_size is presumably in bytes -- confirm.
static const size_t pack_nr = 6;
static const size_t align_size = 32;
static const bool row_major = false;
104 
105  inline STRA_OPERATOR(double) const
106  {
107  printf( "no STRA_OPERATOR implementation\n" );
108  exit( 1 );
109  };
110 
111  inline GEMM_OPERATOR(double) const
112  {
113  double alpha = 1.0;
115  double beta = aux->pc ? 1.0 : 0.0;
117  bli_dgemm_asm_8x6
118  (
119  k,
120  &alpha,
121  a,
122  b,
123  &beta,
124  c, rs_c, cs_c,
125  aux
126  );
127  };
128 
/**
 *  Templated double micro-kernel wrapper. Runs bli_dgemm_asm_8x6 on a tile
 *  addressed through v (or on a local scratch tile) and, when required,
 *  hands the resulting tile to c->Unpack().
 *
 *  k              forwarded unchanged to the kernel
 *  a, b           forwarded unchanged to the kernel
 *  c              object whose Unpack() receives the computed tile
 *  v, rs_v, cs_v  tile pointer with its row / column strides
 *
 *  NOTE(review): `aux` is used throughout the body but is not declared in
 *  the parameter list as it appears here, and the list ends with a trailing
 *  comma. The embedded listing numbers jump from 136 to 138, so the line
 *  that declares `aux` appears to be missing from this listing -- restore it
 *  from the original header before compiling.
 */
129  template<typename TC>
130  inline void operator()
131  (
132  dim_t k,
133  double *a,
134  double *b,
135  TC *c,
136  double *v, inc_t rs_v, inc_t cs_v,
138  )
139  {
// Product scale is always 1; the tile scale is 1 when aux->pc is non-zero
// and 0 otherwise.
140  double alpha = 1.0;
142  double beta = aux->pc ? 1.0 : 0.0;
143 
// Local mr-by-nr scratch tile, indexed below with unit row stride and
// column stride mr.
145  double vtmp[ mr * nr ];
146 
// When TC is anything other than hmlp::MatrixLike<pack_mr, double, double>,
// the kernel works on the scratch tile instead of the caller's v.
// NOTE(review): is_same<...>::value is a compile-time constant tested with
// an ordinary `if`, so both paths are compiled for every TC.
147  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value )
148  {
// With aux->pc set (beta == 1), copy the aux->ib by aux->jb entries of v
// into the scratch tile first. With aux->pc clear the scratch tile stays
// uninitialized.
// NOTE(review): that presumably relies on the kernel not reading the tile
// when beta == 0 -- confirm.
149  if ( aux->pc )
150  {
151  for ( size_t j = 0; j < aux->jb; j ++ )
152  for ( size_t i = 0; i < aux->ib; i ++ )
153  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
154  }
155 
// From here on v and its strides describe the scratch tile.
156  v = vtmp;
157  rs_v = 1;
158  cs_v = mr;
159  }
160 
// The kernel receives aux reinterpreted as
// aux_s<double, double, double, double>*.
// NOTE(review): this assumes the real type of aux is layout-compatible with
// that instantiation -- confirm.
162  bli_dgemm_asm_8x6
163  (
164  k,
165  &alpha,
166  a,
167  b,
168  &beta,
169  v, rs_v, cs_v,
170  reinterpret_cast<aux_s<double, double, double, double>*>( aux )
171  );
172 
// Unpack is called when the scratch tile was used (TC is not the MatrixLike
// type) or when the tile is partial (aux->ib != mr or aux->jb != nr).
177  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value ||
178  aux->ib != mr || aux->jb != nr )
179  {
180  //printf( "bug %d %d %d %d %d %d\n", aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb );
181  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
182  }
183 
184  };
185 
186 };
Definition: rank_k_d8x6.hpp:96
Definition: rank_k_d8x6.hpp:11
Definition: hmlp_internal.hpp:38
Definition: packing.hpp:198
GEMM_OPERATOR(float) const
Definition: rank_k_d8x6.hpp:26
GEMM_OPERATOR(double) const
Definition: rank_k_d8x6.hpp:111
Definition: gofmm.hpp:83