1 |
|
#ifndef SEMIRING_MRXNR_HPP |
2 |
|
#define SEMIRING_MRXNR_HPP |
3 |
|
|
4 |
|
#include <hmlp_internal.hpp> |
5 |
|
|
6 |
|
|
7 |
|
template< |
8 |
|
int MR, int NR, |
9 |
|
typename OP1, typename OP2, |
10 |
|
typename TA, typename TB, typename TC, typename TV> |
11 |
|
struct semiring_mrxnr |
12 |
|
{ |
13 |
|
const static size_t mr = MR; |
14 |
|
const static size_t nr = NR; |
15 |
|
const static size_t pack_mr = MR; |
16 |
|
const static size_t pack_nr = NR; |
17 |
|
const static size_t align_size = 32; |
18 |
|
|
19 |
|
OP1 op1; |
20 |
|
OP2 op2; |
21 |
|
TV initV; |
22 |
|
|
23 |
|
/** Strassen interface: ignore op1, op2 and initV */ |
24 |
|
inline void operator() |
25 |
|
( |
26 |
|
int k, |
27 |
|
TA *a, |
28 |
|
TB *b, |
29 |
|
int len, |
30 |
|
TV **v_list, int ldv, TV *alpha_list, |
31 |
|
aux_s<TA, TB, TC, TV> *aux |
32 |
|
) const |
33 |
|
{ |
34 |
|
TV regV[ MR * NR ] = { 0.0 }; |
35 |
|
|
36 |
|
// semiring rank-k update |
37 |
|
for ( int p = 0; p < k; p ++ ) |
38 |
|
{ |
39 |
|
#pragma unroll |
40 |
|
for ( int j = 0; j < NR; j ++ ) |
41 |
|
#pragma simd |
42 |
|
for ( int i = 0; i < MR; i ++ ) |
43 |
|
regV[ j * MR + i ] += a[ p * MR + i ] * b[ p * NR + j ]; |
44 |
|
} |
45 |
|
|
46 |
|
// store back |
47 |
|
for ( int t = 0; t < len; t ++ ) |
48 |
|
{ |
49 |
|
#pragma unroll |
50 |
|
for ( int j = 0; j < NR; j ++ ) |
51 |
|
{ |
52 |
|
#pragma simd |
53 |
|
for ( int i = 0; i < MR; i ++ ) |
54 |
|
{ |
55 |
|
v_list[ t ][ j * ldv + i ] += alpha_list[ t ] * regV[ j * MR + i ]; |
56 |
|
} |
57 |
|
} |
58 |
|
} |
59 |
|
}; |
60 |
|
|
61 |
|
/** Non-Strassen interface */ |
62 |
|
inline void operator() |
63 |
|
( |
64 |
|
dim_t k, |
65 |
|
TA *a, |
66 |
|
TB *b, |
67 |
|
TV *v, int rs_c, int cs_c, |
68 |
|
aux_s<TA, TB, TC, TV> *aux |
69 |
|
) const |
70 |
|
{ |
71 |
|
TV regV[ MR * NR ]; |
72 |
|
|
73 |
|
if ( !aux->pc ) // Initialize |
74 |
|
{ |
75 |
|
#pragma unroll |
76 |
|
for ( int j = 0; j < NR; j ++ ) |
77 |
|
#pragma simd |
78 |
|
for ( int i = 0; i < MR; i ++ ) |
79 |
|
regV[ j * MR + i ] = initV; |
80 |
|
} |
81 |
|
else // accumulate |
82 |
|
{ |
83 |
|
#pragma unroll |
84 |
|
for ( int j = 0; j < NR; j ++ ) |
85 |
|
#pragma simd |
86 |
|
for ( int i = 0; i < MR; i ++ ) |
87 |
|
regV[ j * MR + i ] = v[ j * cs_c + i * rs_c ]; |
88 |
|
} |
89 |
|
|
90 |
|
// semiring rank-k update |
91 |
|
for ( int p = 0; p < k; p ++ ) |
92 |
|
{ |
93 |
|
#pragma unroll |
94 |
|
for ( int j = 0; j < NR; j ++ ) |
95 |
|
#pragma simd |
96 |
|
for ( int i = 0; i < MR; i ++ ) |
97 |
|
regV[ j * MR + i ] = |
98 |
|
op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) ); |
99 |
|
|
100 |
|
} |
101 |
|
|
102 |
|
// store back |
103 |
|
#pragma unroll |
104 |
|
for ( int j = 0; j < NR; j ++ ) |
105 |
|
#pragma simd |
106 |
|
for ( int i = 0; i < MR; i ++ ) |
107 |
|
v[ j * cs_c + i * rs_c ] = regV[ j * MR + i ]; |
108 |
|
|
109 |
|
}; |
110 |
|
}; |
111 |
|
|
112 |
|
|
113 |
|
#endif /** define SEMIRING_MRXNR_HPP */ |