// HMLP: High-performance Machine Learning Primitives
// rank_k_d6x8.hpp
#include <stdio.h>
#include <stdlib.h>

#include <type_traits>

#include <hmlp_internal.hpp>


6 BLIS_GEMM_KERNEL(bli_sgemm_asm_6x16,float);
7 BLIS_GEMM_KERNEL(bli_dgemm_asm_6x8,double);
8 
9 
11 {
12  const static size_t mr = 6;
13  const static size_t nr = 16;
14  const static size_t pack_mr = 6;
15  const static size_t pack_nr = 16;
16  const static size_t align_size = 32;
17  const static bool row_major = false;
18 
19  inline STRA_OPERATOR(float) const
20  {
21  printf( "no STRA_OPERATOR implementation\n" );
22  exit( 1 );
23  };
24 
25  inline GEMM_OPERATOR(float) const
26  {
27  float alpha = 1.0;
29  float beta = aux->pc ? 1.0 : 0.0;
31  bli_sgemm_asm_6x16
32  (
33  k,
34  &alpha,
35  a,
36  b,
37  &beta,
38  c, rs_c, cs_c,
39  aux
40  );
41  };
42 
43 };
46 struct rank_k_asm_d6x8
47 {
48  const static size_t mr = 6;
49  const static size_t nr = 8;
50  const static size_t pack_mr = 6;
51  const static size_t pack_nr = 8;
52  const static size_t align_size = 32;
53  const static bool row_major = false;
54 
55  inline STRA_OPERATOR(double) const
56  {
57  printf( "no STRA_OPERATOR implementation\n" );
58  exit( 1 );
59  };
60 
61  inline GEMM_OPERATOR(double) const
62  {
63  double alpha = 1.0;
65  double beta = aux->pc ? 1.0 : 0.0;
67  bli_dgemm_asm_6x8
68  (
69  k,
70  &alpha,
71  a,
72  b,
73  &beta,
74  c, rs_c, cs_c,
75  aux
76  );
77  };
78 
79 
80  template<typename TC>
81  inline void operator()
82  (
83  dim_t k,
84  double *a,
85  double *b,
86  TC *c,
87  double *v, inc_t rs_v, inc_t cs_v,
89  )
90  {
91  double alpha = 1.0;
93  double beta = aux->pc ? 1.0 : 0.0;
94 
96  double vtmp[ mr * nr ];
97 
98  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value )
99  {
100  if ( aux->pc )
101  {
102  for ( size_t j = 0; j < aux->jb; j ++ )
103  for ( size_t i = 0; i < aux->ib; i ++ )
104  vtmp[ j * mr + i ] = v[ j * cs_v + i * rs_v ];
105  }
106 
107  v = vtmp;
108  rs_v = 1;
109  cs_v = mr;
110  }
111 
113  bli_dgemm_asm_6x8
114  (
115  k,
116  &alpha,
117  a,
118  b,
119  &beta,
120  v, rs_v, cs_v,
121  reinterpret_cast<aux_s<double, double, double, double>*>( aux )
122  );
123 
128  if ( !is_same<TC, hmlp::MatrixLike<pack_mr, double, double>>::value ||
129  aux->ib != mr || aux->jb != nr )
130  {
131  //printf( "bug %d %d %d %d %d %d\n", aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb );
132  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, v );
133  }
134 
135  };
136 
137 };
// Definition: rank_k_d6x8.hpp:10
// GEMM_OPERATOR(float) const
// Definition: rank_k_d6x8.hpp:25
// GEMM_OPERATOR(double) const
// Definition: rank_k_d6x8.hpp:61
// Definition: hmlp_internal.hpp:38
// Definition: packing.hpp:198
// Definition: rank_k_d6x8.hpp:164