GCC Code Coverage Report
Directory: . Exec Total Coverage
File: kernel/reference/semiring_mrxnr.hpp Lines: 0 17 0.0 %
Date: 2019-01-14 Branches: 0 20 0.0 %

Line Exec Source
1
#ifndef SEMIRING_MRXNR_HPP
2
#define SEMIRING_MRXNR_HPP
3
4
#include <hmlp_internal.hpp>
5
6
7
template<
8
int MR, int NR,
9
typename OP1, typename OP2,
10
typename TA, typename TB, typename TC, typename TV>
11
struct semiring_mrxnr
12
{
13
  const static size_t mr         = MR;
14
  const static size_t nr         = NR;
15
  const static size_t pack_mr    = MR;
16
  const static size_t pack_nr    = NR;
17
  const static size_t align_size = 32;
18
19
  OP1 op1;
20
  OP2 op2;
21
  TV initV;
22
23
  /** Strassen interface: ignore op1, op2 and initV */
24
  inline void operator()
25
  (
26
    int k,
27
    TA *a,
28
    TB *b,
29
    int len,
30
    TV **v_list, int ldv, TV *alpha_list,
31
    aux_s<TA, TB, TC, TV> *aux
32
  ) const
33
  {
34
    TV regV[ MR * NR ] = { 0.0 };
35
36
    // semiring rank-k update
37
    for ( int p = 0; p < k; p ++ )
38
    {
39
      #pragma unroll
40
      for ( int j = 0; j < NR; j ++ )
41
        #pragma simd
42
        for ( int i = 0; i < MR; i ++ )
43
          regV[ j * MR + i ] += a[ p * MR + i ] * b[ p * NR + j ];
44
    }
45
46
    // store back
47
    for ( int t = 0; t < len; t ++ )
48
    {
49
      #pragma unroll
50
      for ( int j = 0; j < NR; j ++ )
51
      {
52
        #pragma simd
53
        for ( int i = 0; i < MR; i ++ )
54
        {
55
          v_list[ t ][ j * ldv + i ] += alpha_list[ t ] * regV[ j * MR + i ];
56
        }
57
      }
58
    }
59
  };
60
61
  /** Non-Strassen interface */
62
  inline void operator()
63
  (
64
    dim_t k,
65
    TA *a,
66
    TB *b,
67
    TV *v, int rs_c, int cs_c,
68
    aux_s<TA, TB, TC, TV> *aux
69
  ) const
70
  {
71
    TV regV[ MR * NR ];
72
73
    if ( !aux->pc ) // Initialize
74
    {
75
      #pragma unroll
76
      for ( int j = 0; j < NR; j ++ )
77
        #pragma simd
78
        for ( int i = 0; i < MR; i ++ )
79
          regV[ j * MR + i ] = initV;
80
    }
81
    else // accumulate
82
    {
83
      #pragma unroll
84
      for ( int j = 0; j < NR; j ++ )
85
        #pragma simd
86
        for ( int i = 0; i < MR; i ++ )
87
          regV[ j * MR + i ] = v[ j * cs_c + i * rs_c ];
88
    }
89
90
    // semiring rank-k update
91
    for ( int p = 0; p < k; p ++ )
92
    {
93
      #pragma unroll
94
      for ( int j = 0; j < NR; j ++ )
95
        #pragma simd
96
        for ( int i = 0; i < MR; i ++ )
97
          regV[ j * MR + i ] =
98
            op1( regV[ j * MR + i ], op2( a[ p * MR + i ], b[ p * NR + j ] ) );
99
100
    }
101
102
    // store back
103
    #pragma unroll
104
    for ( int j = 0; j < NR; j ++ )
105
      #pragma simd
106
      for ( int i = 0; i < MR; i ++ )
107
        v[ j * cs_c + i * rs_c ] = regV[ j * MR + i ];
108
109
  };
110
};
111
112
113
#endif /** define SEMIRING_MRXNR_HPP */