HMLP: High-performance Machine Learning Primitives
matrix.hpp
1 #ifndef _PVFMM_MATRIX_HPP_
2 #define _PVFMM_MATRIX_HPP_
3 
4 #include <cstdint>
5 #include <cstdlib>
6 
7 #include <pvfmm/vector.hpp>
8 #include <pvfmm/common.hpp>
9 
10 namespace pvfmm {
11 
12 template <class ValueType> class Vector;
13 template <class ValueType> class Permutation;
14 
15 template <class ValueType> class Matrix {
16 
17  public:
18  Matrix();
19 
20  Matrix(Long dim1, Long dim2, Iterator<ValueType> data_ = NULL, bool own_data_ = true);
21 
22  Matrix(const Matrix<ValueType>& M);
23 
24  ~Matrix();
25 
26  void Swap(Matrix<ValueType>& M);
27 
28  void ReInit(Long dim1, Long dim2, Iterator<ValueType> data_ = NULL, bool own_data_ = true);
29 
30  void Write(const char* fname) const;
31 
32  void Read(const char* fname);
33 
34  Long Dim(Long i) const;
35 
36  void SetZero();
37 
38  Iterator<ValueType> Begin();
39 
40  ConstIterator<ValueType> Begin() const;
41 
42  Matrix<ValueType>& operator=(const Matrix<ValueType>& M);
43 
44  Matrix<ValueType>& operator+=(const Matrix<ValueType>& M);
45 
46  Matrix<ValueType>& operator-=(const Matrix<ValueType>& M);
47 
48  Matrix<ValueType> operator+(const Matrix<ValueType>& M2) const;
49 
50  Matrix<ValueType> operator-(const Matrix<ValueType>& M2) const;
51 
52  ValueType& operator()(Long i, Long j);
53 
54  const ValueType& operator()(Long i, Long j) const;
55 
56  Iterator<ValueType> operator[](Long i);
57 
58  ConstIterator<ValueType> operator[](Long i) const;
59 
60  Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
61 
62  static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
63 
64  static void GEMM(Matrix<ValueType>& M_r, const Permutation<ValueType>& P, const Matrix<ValueType>& M, ValueType beta = 0.0);
65 
66  static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& M, const Permutation<ValueType>& P, ValueType beta = 0.0);
67 
68  // cublasgemm wrapper
69  static void CUBLASGEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
70 
71  void RowPerm(const Permutation<ValueType>& P);
72  void ColPerm(const Permutation<ValueType>& P);
73 
74  Matrix<ValueType> Transpose() const;
75 
76  static void Transpose(Matrix<ValueType>& M_r, const Matrix<ValueType>& M);
77 
78  // Original matrix is destroyed.
79  void SVD(Matrix<ValueType>& tU, Matrix<ValueType>& tS, Matrix<ValueType>& tVT);
80 
81  // Original matrix is destroyed.
82  Matrix<ValueType> pinv(ValueType eps = -1);
83 
84  private:
85  StaticArray<Long, 2> dim;
86  Iterator<ValueType> data_ptr;
87  bool own_data;
88 };
89 
90 template <class ValueType> std::ostream& operator<<(std::ostream& output, const Matrix<ValueType>& M);
91 
98 template <class ValueType> class Permutation {
99 
100  public:
101  Permutation() {}
102 
103  Permutation(Long size);
104 
105  static Permutation<ValueType> RandPerm(Long size);
106 
107  Matrix<ValueType> GetMatrix() const;
108 
109  Long Dim() const;
110 
111  Permutation<ValueType> Transpose();
112 
113  Permutation<ValueType> operator*(const Permutation<ValueType>& P) const;
114 
115  Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
116 
117  Vector<Long> perm;
118  Vector<ValueType> scal;
119 };
120 
121 template <class ValueType> Matrix<ValueType> operator*(const Matrix<ValueType>& M, const Permutation<ValueType>& P);
122 
123 template <class ValueType> std::ostream& operator<<(std::ostream& output, const Permutation<ValueType>& P);
124 
125 } // end namespace
126 
127 #include <pvfmm/matrix.txx>
128 
129 #endif //_PVFMM_MATRIX_HPP_
Definition: matrix.hpp:13
Definition: cheb_utils.hpp:12
Definition: matrix.hpp:15
Definition: matrix.hpp:12