GCC Code Coverage Report
Directory: . Exec Total Coverage
File: kernel/reference/fused_mrxnr.hpp Lines: 0 34 0.0 %
Date: 2019-01-14 Branches: 0 40 0.0 %

Line Exec Source
1
#ifndef FUSED_MRXNR_HPP
2
#define FUSED_MRXNR_HPP
3
4
/**
5
 *  @brief This kernel takes opkernel, op1 and op2 to implement an MR-by-NR
6
 *         GKMM operation.
7
 *
8
 */
9
template<int MR, int NR,
10
typename OPKERNEL, typename OP1, typename OP2,
11
typename TA, typename TB, typename TC, typename TV>
12
struct gkmm_mrxnr
13
{
14
  const static size_t mr         = MR;
15
  const static size_t nr         = NR;
16
  const static size_t pack_mr    = MR;
17
  const static size_t pack_nr    = NR;
18
  const static size_t align_size = 32;
19
20
  OPKERNEL opkernel;
21
  OP1 op1;
22
  OP2 op2;
23
  TV initV;
24
25
  inline void operator()
26
  (
27
    int k,
28
    TA *a,
29
    TB *b,
30
    TC *c, int ldc,
31
    TV *v, int ldv,
32
    aux_s<TA, TB, TC, TV> *aux
33
  ) const
34
  {
35
    TV regV[ MR * NR ];
36
37
    if ( !aux->pc ) // Initialize
38
    {
39
      #pragma unroll
40
      for ( int j = 0; j < NR; j ++ )
41
        #pragma simd
42
        for ( int i = 0; i < MR; i ++ )
43
          regV[ j * MR + i ] = initV;
44
    }
45
    else // accumulate
46
    {
47
      #pragma unroll
48
      for ( int j = 0; j < NR; j ++ )
49
        #pragma simd
50
        for ( int i = 0; i < MR; i ++ )
51
          regV[ j * MR + i ] = v[ j * ldv + i ];
52
    }
53
54
    // semiring rank-k update
55
    for ( int p = 0; p < k; p ++ )
56
    {
57
      #pragma unroll
58
      for ( int j = 0; j < NR; j ++ )
59
        #pragma simd
60
        for ( int i = 0; i < MR; i ++ )
61
          regV[ j * MR + i ] =
62
            op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
63
    }
64
65
    // kernel transformation and store back
66
    #pragma unroll
67
    for ( int j = 0; j < NR; j ++ )
68
      #pragma simd
69
      for ( int i = 0; i < MR; i ++ )
70
        c[ j * ldc + i ] = opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b );
71
  };
72
73
74
  inline void operator()
75
  (
76
    int k,
77
    TA *a,
78
    TB *b,
79
    TV *v, int rs_c, int cs_c,
80
    aux_s<TA, TB, TC, TV> *aux
81
  ) const
82
  {
83
    TV regV[ MR * NR ];
84
85
    if ( !aux->pc ) // Initialize
86
    {
87
      #pragma unroll
88
      for ( int j = 0; j < NR; j ++ )
89
        #pragma simd
90
        for ( int i = 0; i < MR; i ++ )
91
          regV[ j * MR + i ] = initV;
92
    }
93
    else // accumulate
94
    {
95
      #pragma unroll
96
      for ( int j = 0; j < NR; j ++ )
97
        #pragma simd
98
        for ( int i = 0; i < MR; i ++ )
99
          regV[ j * MR + i ] = v[ j * cs_c + i * rs_c ];
100
    }
101
102
    // semiring rank-k update
103
    for ( int p = 0; p < k; p ++ )
104
    {
105
      #pragma unroll
106
      for ( int j = 0; j < NR; j ++ )
107
        #pragma simd
108
        for ( int i = 0; i < MR; i ++ )
109
          regV[ j * MR + i ] =
110
            op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
111
    }
112
113
    // kernel transformation and store back
114
    #pragma unroll
115
    for ( int j = 0; j < NR; j ++ )
116
      #pragma simd
117
      for ( int i = 0; i < MR; i ++ )
118
        v[ j * cs_c + i * rs_c ] = opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b );
119
  };
120
};
121
122
123
template<
124
int MR, int NR,
125
typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE,
126
typename TA, typename TB, typename TC, typename TV>
127
struct gkrm_mrxnr
128
{
129
  const static size_t mr         = MR;
130
  const static size_t nr         = NR;
131
  const static size_t pack_mr    = MR;
132
  const static size_t pack_nr    = NR;
133
  const static size_t align_size = 32;
134
135
  OPKERNEL opkernel;
136
  OP1 op1;
137
  OP2 op2;
138
  TV initV;
139
  OPKERNEL opreduce;
140
  TC initC;
141
142
  inline void operator()
143
  (
144
    int k,
145
    TA *a,
146
    TB *b,
147
    TC *c, int ldc, // ldc is redundant here
148
    TV *v, int ldv,
149
    aux_s<TA, TB, TC, TV> *aux
150
  ) const
151
  {
152
    TV regV[ MR * NR ];
153
    TC regC[ MR ];
154
155
    if ( !aux->pc ) // Initialize
156
    {
157
      #pragma unroll
158
      for ( int j = 0; j < NR; j ++ )
159
        #pragma simd
160
        for ( int i = 0; i < MR; i ++ )
161
          regV[ j * MR + i ] = initV;
162
    }
163
    else // accumulate
164
    {
165
      #pragma unroll
166
      for ( int j = 0; j < NR; j ++ )
167
        #pragma simd
168
        for ( int i = 0; i < MR; i ++ )
169
          regV[ j * MR + i ] = v[ j * ldv + i ];
170
    }
171
172
    // semiring rank-k update
173
    for ( int p = 0; p < k; p ++ )
174
    {
175
      #pragma unroll
176
      for ( int j = 0; j < NR; j ++ )
177
        #pragma simd
178
        for ( int i = 0; i < MR; i ++ )
179
          regV[ j * MR + i ] =
180
            op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
181
    }
182
183
    // Initialize
184
    #pragma simd
185
    for ( int i = 0; i < MR; i ++ )
186
      regC[ i ] = initC;
187
188
    // kernel transformation and reduction
189
    #pragma unroll
190
    for ( int j = 0; j < NR; j ++ )
191
      #pragma simd
192
      for ( int i = 0; i < MR; i ++ )
193
        regC[ i ] = opreduce( regC[ i ], opkernel( regV[ j * MR + i ], aux->i, aux->j, aux->b ), aux->i, aux->j, aux->b );
194
195
    // Here we need omp atomic update
196
    for ( int i = 0; i < MR; i ++ )
197
    {
198
      TC *cptr = c + i;
199
#ifdef USE_INTEL
200
      #pragma omp atomic update
201
#else
202
			#pragma omp critical
203
#endif
204
      *c = opreduce( *c, regC[ i ] );
205
    }
206
  };
207
208
}; /** end struct gkrm_mrxnr */
209
210
211
212
template<
213
int MR, int NR,
214
typename OPKERNEL, typename OP1, typename OP2,
215
typename TA, typename TB, typename TC, typename TPACKC, typename TV>
216
struct gnbx_mrxnr
217
{
218
  const static size_t mr         = MR;
219
  const static size_t nr         = NR;
220
  const static size_t pack_mr    = MR;
221
  const static size_t pack_nr    = NR;
222
  const static size_t align_size = 32;
223
224
  OPKERNEL opkernel;
225
  OP1 op1;
226
  OP2 op2;
227
  TV initV;
228
229
  inline void operator()
230
  (
231
    int k,
232
    TA *a,
233
    TB *b,
234
    TC *c,
235
    TV *v, int rs_v, int cs_v,
236
    aux_s<TA, TB, TC, TV> *aux
237
  ) const
238
  {
239
    TV     regV[ MR * NR ];
240
    TPACKC regC[ MR * NR ];
241
242
    if ( !aux->pc )
243
    {
244
      for ( int j = 0; j < NR; j ++ )
245
        for ( int i = 0; i < MR; i ++ )
246
          regV[ j * MR + i ] = initV;
247
    }
248
    else
249
    {
250
      for ( int j = 0; j < aux->jb; j ++ )
251
        for ( int i = 0; i < aux->ib; i ++ )
252
          regV[ j * MR + i ] = v[ j * cs_v + i * rs_v ];
253
    }
254
255
    /**
256
     *  Semiring rank-k update
257
     */
258
    for ( int p = 0; p < k; p ++ )
259
    {
260
      #pragma unroll
261
      for ( int j = 0; j < NR; j ++ )
262
        #pragma simd
263
        for ( int i = 0; i < MR; i ++ )
264
          regV[ j * MR + i ] = op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
265
    }
266
267
    #pragma unroll
268
    for ( int j = 0; j < NR; j ++ )
269
      #pragma simd
270
      for ( int i = 0; i < MR; i ++ )
271
        regC[ j * MR + i ] = opkernel( regV[ j * MR + i ], aux->i + i, aux->j + j, aux->b );
272
273
    /**
274
     *  Store back
275
     */
276
    c->Unpack( aux->m, aux->i, aux->ib, aux->n, aux->j, aux->jb, regC );
277
  };
278
279
}; /** end struct gnbx_mrxnr */
280
281
#endif /** define FUSED_MRXNR_HPP */