HMLP: High-performance Machine Learning Primitives
All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Pages
fused_mrxnr.hpp
1 #ifndef FUSED_MRXNR_HPP
2 #define FUSED_MRXNR_HPP
3 
9 template<int MR, int NR,
10 typename OPKERNEL, typename OP1, typename OP2,
11 typename TA, typename TB, typename TC, typename TV>
12 struct gkmm_mrxnr
13 {
14  const static size_t mr = MR;
15  const static size_t nr = NR;
16  const static size_t pack_mr = MR;
17  const static size_t pack_nr = NR;
18  const static size_t align_size = 32;
19 
20  OPKERNEL opkernel;
21  OP1 op1;
22  OP2 op2;
23  TV initV;
24 
25  inline void operator()
26  (
27  int k,
28  TA *a,
29  TB *b,
30  TC *c, int ldc,
31  TV *v, int ldv,
33  ) const
34  {
35  TV regV[ MR * NR ];
36 
37  if ( !aux->pc ) // Initialize
38  {
39  #pragma unroll
40  for ( int j = 0; j < NR; j ++ )
41  #pragma simd
42  for ( int i = 0; i < MR; i ++ )
43  regV[ j * MR + i ] = initV;
44  }
45  else // accumulate
46  {
47  #pragma unroll
48  for ( int j = 0; j < NR; j ++ )
49  #pragma simd
50  for ( int i = 0; i < MR; i ++ )
51  regV[ j * MR + i ] = v[ j * ldv + i ];
52  }
53 
54  // semiring rank-k update
55  for ( int p = 0; p < k; p ++ )
56  {
57  #pragma unroll
58  for ( int j = 0; j < NR; j ++ )
59  #pragma simd
60  for ( int i = 0; i < MR; i ++ )
61  regV[ j * MR + i ] =
62  op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
63  }
64 
65  // kernel transformation and store back
66  #pragma unroll
67  for ( int j = 0; j < NR; j ++ )
68  #pragma simd
69  for ( int i = 0; i < MR; i ++ )
70  c[ j * ldc + i ] = opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b );
71  };
72 
73 
74  inline void operator()
75  (
76  int k,
77  TA *a,
78  TB *b,
79  TV *v, int rs_c, int cs_c,
81  ) const
82  {
83  TV regV[ MR * NR ];
84 
85  if ( !aux->pc ) // Initialize
86  {
87  #pragma unroll
88  for ( int j = 0; j < NR; j ++ )
89  #pragma simd
90  for ( int i = 0; i < MR; i ++ )
91  regV[ j * MR + i ] = initV;
92  }
93  else // accumulate
94  {
95  #pragma unroll
96  for ( int j = 0; j < NR; j ++ )
97  #pragma simd
98  for ( int i = 0; i < MR; i ++ )
99  regV[ j * MR + i ] = v[ j * cs_c + i * rs_c ];
100  }
101 
102  // semiring rank-k update
103  for ( int p = 0; p < k; p ++ )
104  {
105  #pragma unroll
106  for ( int j = 0; j < NR; j ++ )
107  #pragma simd
108  for ( int i = 0; i < MR; i ++ )
109  regV[ j * MR + i ] =
110  op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
111  }
112 
113  // kernel transformation and store back
114  #pragma unroll
115  for ( int j = 0; j < NR; j ++ )
116  #pragma simd
117  for ( int i = 0; i < MR; i ++ )
118  v[ j * cs_c + i * rs_c ] = opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b );
119  };
120 };
121 
122 
123 template<
124 int MR, int NR,
125 typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
126 typename TA, typename TB, typename TC, typename TV>
127 struct gkrm_mrxnr
128 {
129  const static size_t mr = MR;
130  const static size_t nr = NR;
131  const static size_t pack_mr = MR;
132  const static size_t pack_nr = NR;
133  const static size_t align_size = 32;
134 
135  OPKERNEL opkernel;
136  OP1 op1;
137  OP2 op2;
138  TV initV;
139  OPKERNEL opreduce;
140  TC initC;
141 
142  inline void operator()
143  (
144  int k,
145  TA *a,
146  TB *b,
147  TC *c, int ldc, // ldc is redundant here
148  TV *v, int ldv,
150  ) const
151  {
152  TV regV[ MR * NR ];
153  TC regC[ MR ];
154 
155  if ( !aux->pc ) // Initialize
156  {
157  #pragma unroll
158  for ( int j = 0; j < NR; j ++ )
159  #pragma simd
160  for ( int i = 0; i < MR; i ++ )
161  regV[ j * MR + i ] = initV;
162  }
163  else // accumulate
164  {
165  #pragma unroll
166  for ( int j = 0; j < NR; j ++ )
167  #pragma simd
168  for ( int i = 0; i < MR; i ++ )
169  regV[ j * MR + i ] = v[ j * ldv + i ];
170  }
171 
172  // semiring rank-k update
173  for ( int p = 0; p < k; p ++ )
174  {
175  #pragma unroll
176  for ( int j = 0; j < NR; j ++ )
177  #pragma simd
178  for ( int i = 0; i < MR; i ++ )
179  regV[ j * MR + i ] =
180  op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
181  }
182 
183  // Initialize
184  #pragma simd
185  for ( int i = 0; i < MR; i ++ )
186  regC[ i ] = initC;
187 
188  // kernel transformation and reduction
189  #pragma unroll
190  for ( int j = 0; j < NR; j ++ )
191  #pragma simd
192  for ( int i = 0; i < MR; i ++ )
193  regC[ i ] = opreduce( regC[ i ], opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b ), aux->i, aux->j, aux->b );
194 
195  // Here we need omp atomic update
196  for ( int i = 0; i < MR; i ++ )
197  {
198  TC *cptr = c + i;
199 #ifdef USE_INTEL
200  #pragma omp atomic update
201 #else
202  #pragma omp critical
203 #endif
204  *c = opreduce( *c, regC[ i ] );
205  }
206  };
207 
208 };
212 template<
213 int MR, int NR,
214 typename OPKERNEL, typename OP1, typename OP2,
215 typename TA, typename TB, typename TC, typename TPACKC, typename TV>
216 struct gnbx_mrxnr
217 {
218  const static size_t mr = MR;
219  const static size_t nr = NR;
220  const static size_t pack_mr = MR;
221  const static size_t pack_nr = NR;
222  const static size_t align_size = 32;
223 
224  OPKERNEL opkernel;
225  OP1 op1;
226  OP2 op2;
227  TV initV;
228 
229  inline void operator()
230  (
231  int k,
232  TA *a,
233  TB *b,
234  TC *c,
235  TV *v, int rs_v, int cs_v,
237  ) const
238  {
239  TV regV[ MR * NR ];
240  TPACKC regC[ MR * NR ];
241 
242  if ( !aux->pc )
243  {
244  for ( int j = 0; j < NR; j ++ )
245  for ( int i = 0; i < MR; i ++ )
246  regV[ j * MR + i ] = initV;
247  }
248  else
249  {
250  for ( int j = 0; j < aux->jb; j ++ )
251  for ( int i = 0; i < aux->ib; i ++ )
252  regV[ j * MR + i ] = v[ j * cs_v + i * rs_v ];
253  }
254 
258  for ( int p = 0; p < k; p ++ )
259  {
260  #pragma unroll
261  for ( int j = 0; j < NR; j ++ )
262  #pragma simd
263  for ( int i = 0; i < MR; i ++ )
264  regV[ j * MR + i ] = op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
265  }
266 
267  #pragma unroll
268  for ( int j = 0; j < NR; j ++ )
269  #pragma simd
270  for ( int i = 0; i < MR; i ++ )
271  regC[ j * MR + i ] = opkernel( regV[ j * MR + i ], aux->i + i, aux->j + j, aux->b );
272 
276  c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, regC );
277  };
278 
279 };
281 #endif
Definition: fused_mrxnr.hpp:216
This kernel takes opkernel, op1 and op2 to implement an MR-by-NR GKMM operation.
Definition: fused_mrxnr.hpp:12
Definition: fused_mrxnr.hpp:127
Definition: hmlp_internal.hpp:38