// Include guard for this header, followed by the template header and
// compile-time constants of the first micro-kernel in this file.
//
// Template parameters (roles inferred from how the kernel body uses them):
//   MR, NR   - compile-time tile extents; the kernel loops run i in [0, MR)
//              and j in [0, NR) over an accumulator indexed as j * MR + i.
//   OPKERNEL - presumably the type of the `opkernel` functor applied to each
//              accumulated value before it is stored -- TODO confirm against
//              the member declarations, which are not visible here.
//   OP1, OP2 - presumably the types of the `op1` (accumulate) and `op2`
//              (combine one element of a with one element of b) functors
//              -- TODO confirm.
//   TA, TB, TC, TV - presumably the element types of a, b, c and v
//              -- TODO confirm against the operator() signature.
1 #ifndef FUSED_MRXNR_HPP 2 #define FUSED_MRXNR_HPP 9 template<
int MR,
int NR,
10 typename OPKERNEL,
typename OP1,
typename OP2,
11 typename TA,
typename TB,
typename TC,
typename TV>
// Tile extents re-exported as constants.
14 const static size_t mr = MR;
15 const static size_t nr = NR;
// Packing extents; in this kernel they equal the tile extents.
16 const static size_t pack_mr = MR;
17 const static size_t pack_nr = NR;
// NOTE(review): presumably the required buffer alignment in bytes -- the unit
// is not established by anything visible here; confirm with the users of
// align_size.
18 const static size_t align_size = 32;
// Micro-kernel entry point using leading dimensions (ldv for v, ldc for c).
//
// The parameter list and the declarations of regV / initV are not visible
// here. The visible body works on an accumulator regV in which element
// (i, j) lives at regV[ j * MR + i ], and proceeds in these steps:
//   1. fill regV with initV, or load it from v (the condition selecting
//      between the two loops is not visible -- TODO confirm);
//   2. for each p in [0, k), fold op2( a[ p * MR + i ], b[ p * NR + j ] )
//      into regV through op1;
//   3. pass every accumulated value through opkernel and store it in c.
25 inline void operator()
// Fill the whole accumulator with initV.
40 for (
int j = 0; j < NR; j ++ )
42 for (
int i = 0; i < MR; i ++ )
43 regV[ j * MR + i ] = initV;
// Load the accumulator from v; column j of the tile starts at v[ j * ldv ].
48 for (
int j = 0; j < NR; j ++ )
50 for (
int i = 0; i < MR; i ++ )
51 regV[ j * MR + i ] = v[ j * ldv + i ];
// Accumulation over p: a is read as a[ p * MR + i ] and b as b[ p * NR + j ],
// i.e. MR (resp. NR) consecutive entries per value of p.
55 for (
int p = 0; p < k; p ++ )
58 for (
int j = 0; j < NR; j ++ )
60 for (
int i = 0; i < MR; i ++ )
// The return value of op1 is discarded here, so op1 presumably updates its
// first argument in place (taken by reference) -- TODO confirm.
62 op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
// Apply opkernel to each accumulated value and write the result to c with
// leading dimension ldc.
67 for (
int j = 0; j < NR; j ++ )
69 for (
int i = 0; i < MR; i ++ )
// NOTE(review): every element receives the same aux->i and aux->j (no + i /
// + j offset), unlike the kernel later in this file that passes aux->i + i
// and aux->j + j -- confirm this is intended.
70 c[ j * ldc + i ] = opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b );
// Micro-kernel entry point using general strides: element (i, j) of the tile
// is addressed as v[ j * cs_c + i * rs_c ].
//
// Only part of the parameter list is visible. The visible body mirrors the
// leading-dimension overload: initialise or load the accumulator regV, fold
// the k products into it, then apply opkernel. In this overload both the
// load and the final store go through v.
74 inline void operator()
// v is both read (accumulator load) and written (final store) below.
79 TV *v,
// Stride applied to the i index and to the j index, respectively.
int rs_c,
int cs_c,
// Fill the whole accumulator with initV. The condition that selects between
// this loop and the load loop below is not visible -- TODO confirm.
88 for (
int j = 0; j < NR; j ++ )
90 for (
int i = 0; i < MR; i ++ )
91 regV[ j * MR + i ] = initV;
// Load the accumulator from v using the (rs_c, cs_c) strides.
96 for (
int j = 0; j < NR; j ++ )
98 for (
int i = 0; i < MR; i ++ )
99 regV[ j * MR + i ] = v[ j * cs_c + i * rs_c ];
// Accumulation over p: a is read as a[ p * MR + i ], b as b[ p * NR + j ].
103 for (
int p = 0; p < k; p ++ )
106 for (
int j = 0; j < NR; j ++ )
108 for (
int i = 0; i < MR; i ++ )
// op1's return value is discarded; presumably it updates its first argument
// in place -- TODO confirm.
110 op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
// Apply opkernel to each accumulated value and store the result back into v.
115 for (
int j = 0; j < NR; j ++ )
117 for (
int i = 0; i < MR; i ++ )
// NOTE(review): the same aux->i and aux->j are passed for every (i, j).
118 v[ j * cs_c + i * rs_c ] = opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b );
// Template header (tail) and compile-time constants of the second
// micro-kernel in this file. Compared with the first kernel it takes one
// extra functor type, OPREDUCE; the body uses `opreduce` to fold the
// opkernel results of one tile row i into a single value regC[ i ].
// The leading `template< int MR, int NR, ...` part is not visible here.
125 typename OPKERNEL,
typename OP1,
typename OP2,
// Presumably the type of the `opreduce` functor -- TODO confirm.
typename OPREDUCE,
126 typename TA,
typename TB,
typename TC,
typename TV>
// Tile extents re-exported as constants.
129 const static size_t mr = MR;
130 const static size_t nr = NR;
// Packing extents; equal to the tile extents.
131 const static size_t pack_mr = MR;
132 const static size_t pack_nr = NR;
// NOTE(review): presumably an alignment in bytes -- unit not visible here.
133 const static size_t align_size = 32;
// Micro-kernel entry point that reduces across the NR direction.
//
// The parameter list and local declarations are not visible. The visible
// body performs:
//   1. initialise regV with initV, or load it from v with leading dimension
//      ldv (the selecting condition is not visible -- TODO confirm);
//   2. for each p in [0, k), fold op2( a[ p * MR + i ], b[ p * NR + j ] )
//      into regV through op1;
//   3. for each i, fold opkernel( regV[ j * MR + i ], ... ) over all j into
//      regC[ i ] through opreduce;
//   4. fold the regC values into the output through opreduce.
142 inline void operator()
// Fill the whole accumulator with initV.
158 for (
int j = 0; j < NR; j ++ )
160 for (
int i = 0; i < MR; i ++ )
161 regV[ j * MR + i ] = initV;
// Load the accumulator from v; column j starts at v[ j * ldv ].
166 for (
int j = 0; j < NR; j ++ )
168 for (
int i = 0; i < MR; i ++ )
169 regV[ j * MR + i ] = v[ j * ldv + i ];
// Accumulation over p.
173 for (
int p = 0; p < k; p ++ )
176 for (
int j = 0; j < NR; j ++ )
178 for (
int i = 0; i < MR; i ++ )
// op1's return value is discarded; presumably it updates its first argument
// in place -- TODO confirm.
180 op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
// Loop over i whose body is not visible here; presumably it initialises
// regC[ i ] before the reduction below -- TODO confirm.
185 for (
int i = 0; i < MR; i ++ )
// Reduction across j: regC[ i ] accumulates the opkernel result of every
// column of row i.
190 for (
int j = 0; j < NR; j ++ )
192 for (
int i = 0; i < MR; i ++ )
193 regC[ i ] = opreduce( regC[ i ], opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b ), aux->i, aux->j, aux->b );
// Fold the per-row partial results into the output.
196 for (
int i = 0; i < MR; i ++ )
// NOTE(review): three things to confirm against the complete source:
//   - opreduce is called with two arguments here but with five arguments in
//     the reduction loop above;
//   - the target *c does not depend on i, so all MR partial results are
//     folded into the same location;
//   - OpenMP `atomic update` accepts only statements such as `x binop= expr`
//     or `x = x binop expr`; an assignment from a function call is not one
//     of those forms.
200 #pragma omp atomic update 204 *c = opreduce( *c, regC[ i ] );
// Template header (tail) and compile-time constants of the third
// micro-kernel in this file. Compared with the first kernel it takes one
// extra type, TPACKC, which is the element type of the local MR * NR output
// buffer regC declared in the kernel body.
// The leading `template< int MR, int NR, ...` part is not visible here.
214 typename OPKERNEL,
typename OP1,
typename OP2,
215 typename TA,
typename TB,
typename TC,
// Element type of the regC buffer handed to c->Unpack( ... ).
typename TPACKC,
typename TV>
// Tile extents re-exported as constants.
218 const static size_t mr = MR;
219 const static size_t nr = NR;
// Packing extents; equal to the tile extents.
220 const static size_t pack_mr = MR;
221 const static size_t pack_nr = NR;
// NOTE(review): presumably an alignment in bytes -- unit not visible here.
222 const static size_t align_size = 32;
// Micro-kernel entry point that hands its result to c through Unpack.
//
// Only part of the parameter list is visible. The visible body performs:
//   1. fill the accumulator regV with initV;
//   2. load an aux->ib by aux->jb sub-tile of regV from v using the
//      (rs_v, cs_v) strides (whether this load is conditional is not
//      visible -- TODO confirm);
//   3. for each p in [0, k), combine a[ p * MR + i ] and b[ p * NR + j ] with
//      op2 and fold the result into regV with op1;
//   4. apply opkernel to every accumulated value, storing into regC;
//   5. call c->Unpack with the tile position/extent fields of aux and regC.
229 inline void operator()
// Source of the optional accumulator load below.
235 TV *v,
// Stride applied to the i index and to the j index of v, respectively.
int rs_v,
int cs_v,
// Local buffer for the opkernel results; element (i, j) at regC[ j * MR + i ].
240 TPACKC regC[ MR * NR ];
// Fill the whole MR x NR accumulator with initV.
244 for (
int j = 0; j < NR; j ++ )
245 for (
int i = 0; i < MR; i ++ )
246 regV[ j * MR + i ] = initV;
// Load only the aux->ib by aux->jb part from v; entries outside that range
// keep whatever value they had before this loop.
250 for (
int j = 0; j < aux->jb; j ++ )
251 for (
int i = 0; i < aux->ib; i ++ )
252 regV[ j * MR + i ] = v[ j * cs_v + i * rs_v ];
// Accumulation over p.
258 for (
int p = 0; p < k; p ++ )
261 for (
int j = 0; j < NR; j ++ )
263 for (
int i = 0; i < MR; i ++ )
// Unlike the other kernels in this file, op1's return value is assigned back
// to the accumulator here.
264 regV[ j * MR + i ] = op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
// Apply opkernel to all MR x NR values; each call receives its own
// coordinates aux->i + i and aux->j + j.
268 for (
int j = 0; j < NR; j ++ )
270 for (
int i = 0; i < MR; i ++ )
271 regC[ j * MR + i ] = opkernel( regV[ j * MR + i ], aux->i + i, aux->j + j, aux->b );
// Hand the full regC buffer to c together with aux->m, aux->i, aux->ib,
// aux->n, aux->j and aux->jb. Presumably Unpack copies only the ib x jb valid
// part into the destination -- its definition is not visible; TODO confirm.
276 c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, regC );
Definition: fused_mrxnr.hpp:216
This kernel takes opkernel, op1 and op2 to implement an MR-by-NR GKMM operation.
Definition: fused_mrxnr.hpp:12
Definition: fused_mrxnr.hpp:127
Definition: hmlp_internal.hpp:38