HMLP: High-performance Machine Learning Primitives
semiring_mrxnr.hpp
1 #ifndef SEMIRING_MRXNR_HPP
2 #define SEMIRING_MRXNR_HPP
3 
4 #include <hmlp_internal.hpp>
5 
6 
7 template<
8 int MR, int NR,
9 typename OP1, typename OP2,
10 typename TA, typename TB, typename TC, typename TV>
12 {
13  const static size_t mr = MR;
14  const static size_t nr = NR;
15  const static size_t pack_mr = MR;
16  const static size_t pack_nr = NR;
17  const static size_t align_size = 32;
18 
19  OP1 op1;
20  OP2 op2;
21  TV initV;
22 
24  inline void operator()
25  (
26  int k,
27  TA *a,
28  TB *b,
29  int len,
30  TV **v_list, int ldv, TV *alpha_list,
32  ) const
33  {
34  TV regV[ MR * NR ] = { 0.0 };
35 
36  // semiring rank-k update
37  for ( int p = 0; p < k; p ++ )
38  {
39  #pragma unroll
40  for ( int j = 0; j < NR; j ++ )
41  #pragma simd
42  for ( int i = 0; i < MR; i ++ )
43  regV[ j * MR + i ] += a[ p * MR + i ] * b[ p * NR + j ];
44  }
45 
46  // store back
47  for ( int t = 0; t < len; t ++ )
48  {
49  #pragma unroll
50  for ( int j = 0; j < NR; j ++ )
51  {
52  #pragma simd
53  for ( int i = 0; i < MR; i ++ )
54  {
55  v_list[ t ][ j * ldv + i ] += alpha_list[ t ] * regV[ j * MR + i ];
56  }
57  }
58  }
59  };
60 
62  inline void operator()
63  (
64  dim_t k,
65  TA *a,
66  TB *b,
67  TV *v, int rs_c, int cs_c,
69  ) const
70  {
71  TV regV[ MR * NR ];
72 
73  if ( !aux->pc ) // Initialize
74  {
75  #pragma unroll
76  for ( int j = 0; j < NR; j ++ )
77  #pragma simd
78  for ( int i = 0; i < MR; i ++ )
79  regV[ j * MR + i ] = initV;
80  }
81  else // accumulate
82  {
83  #pragma unroll
84  for ( int j = 0; j < NR; j ++ )
85  #pragma simd
86  for ( int i = 0; i < MR; i ++ )
87  regV[ j * MR + i ] = v[ j * cs_c + i * rs_c ];
88  }
89 
90  // semiring rank-k update
91  for ( int p = 0; p < k; p ++ )
92  {
93  #pragma unroll
94  for ( int j = 0; j < NR; j ++ )
95  #pragma simd
96  for ( int i = 0; i < MR; i ++ )
97  regV[ j * MR + i ] =
98  op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
99 
100  }
101 
102  // store back
103  #pragma unroll
104  for ( int j = 0; j < NR; j ++ )
105  #pragma simd
106  for ( int i = 0; i < MR; i ++ )
107  v[ j * cs_c + i * rs_c ] = regV[ j * MR + i ];
108 
109  };
110 };
111 
112 
113 #endif
Definition: semiring_mrxnr.hpp:11
Definition: hmlp_internal.hpp:38