HMLP: High-performance Machine Learning Primitives
blas_lapack.hpp
1 
22 #ifndef HMLP_BLAS_LAPACK_H
23 #define HMLP_BLAS_LAPACK_H
24 
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <omp.h>
28 
29 #ifdef HMLP_USE_CUDA
30 #include <cuda_runtime.h>
31 #include <cublas_v2.h>
32 #include <cusparse_v2.h>
33 #endif
34 
namespace hmlp
{

/*
 *  BLAS / LAPACK wrapper declarations.
 *
 *  Every routine is overloaded for float and double; the definitions are in
 *  blas_lapack.cpp. Option arguments (trans*, uplo, side, diag, norm, jobz)
 *  are passed as C strings (e.g. "N", "T", "L", "U"), and every matrix X is
 *  paired with its leading dimension ldX.
 *  NOTE(review): column-major storage is assumed, as in reference
 *  BLAS/LAPACK -- confirm against blas_lapack.cpp.
 */


/** GEMM wrapper: C := alpha * op( A ) * op( B ) + beta * C, where op( X )
 *  is X or X^T as selected by transA / transB; op( A ) is m x k,
 *  op( B ) is k x n and C is m x n. */
void xgemm
(
  const char *transA, const char *transB,
  int m, int n, int k,
  float alpha, const float *A, int lda,
  const float *B, int ldb,
  float beta, float *C, int ldc
);

void xgemm
(
  const char *transA, const char *transB,
  int m, int n, int k,
  double alpha, const double *A, int lda,
  const double *B, int ldb,
  double beta, double *C, int ldc
);


/** Symmetric rank-k update of the n x n matrix C.
 *  NOTE(review): no Doxygen entry exists for this wrapper; presumably it
 *  forwards to BLAS ?SYRK (C := alpha * op( A ) * op( A )^T + beta * C on
 *  the uplo triangle) -- confirm in blas_lapack.cpp. */
void xsyrk
(
  const char *uplo, const char *trans,
  int n, int k,
  double alpha, double *A, int lda,
  double beta, double *C, int ldc
);

void xsyrk
(
  const char *uplo, const char *trans,
  int n, int k,
  float alpha, float *A, int lda,
  float beta, float *C, int ldc
);


/** TRSM wrapper: triangular solve with the triangular matrix A, applied from
 *  the side given by `side`; the m x n matrix B is overwritten by the
 *  solution (scaled by alpha). */
void xtrsm
(
  const char *side, const char *uplo,
  const char *transA, const char *diag,
  int m, int n,
  float alpha, float *A, int lda,
  float *B, int ldb
);

void xtrsm
(
  const char *side, const char *uplo,
  const char *transA, const char *diag,
  int m, int n,
  double alpha, double *A, int lda,
  double *B, int ldb
);


/** TRMM wrapper: B := alpha * op( A ) * B or alpha * B * op( A ) (per
 *  `side`) for the triangular matrix A; B is m x n and is overwritten. */
void xtrmm
(
  const char *side, const char *uplo,
  const char *transA, const char *diag,
  int m, int n,
  float alpha, float *A, int lda,
  float *B, int ldb
);

void xtrmm
(
  const char *side, const char *uplo,
  const char *transA, const char *diag,
  int m, int n,
  double alpha, double *A, int lda,
  double *B, int ldb
);


/** LASWP wrapper: applies the row interchanges recorded in ipiv (entries
 *  k1..k2, stride incx) to the n columns of A. */
void xlaswp
(
  int n, double *A, int lda,
  int k1, int k2, int *ipiv, int incx
);

void xlaswp
(
  int n, float *A, int lda,
  int k1, int k2, int *ipiv, int incx
);


/** Cholesky family.
 *  xpotrf: POTRF wrapper, factors the n x n symmetric positive-definite A
 *          in place (triangle selected by uplo).
 *  xpotrs: POTRS wrapper, solves A * X = B for nrhs right-hand sides using
 *          the factor produced by xpotrf; B is overwritten by X.
 *  xposv:  factor-and-solve (no Doxygen entry; presumably ?POSV -- confirm). */
void xpotrf( const char *uplo, int n, double *A, int lda );
void xpotrf( const char *uplo, int n, float *A, int lda );
void xpotrs( const char *uplo, int n, int nrhs, double *A, int lda, double *B, int ldb );
void xpotrs( const char *uplo, int n, int nrhs, float *A, int lda, float *B, int ldb );
void xposv( const char *uplo, int n, int nrhs, double *A, int lda, double *B, int ldb );
void xposv( const char *uplo, int n, int nrhs, float *A, int lda, float *B, int ldb );


/** LU family.
 *  xgetrf: GETRF wrapper, LU factorization of the m x n matrix A with
 *          partial pivoting; pivot indices are written to ipiv.
 *  xgetrs: GETRS wrapper, solves op( A ) * X = B with the factors from
 *          xgetrf. `n` is the order of the square matrix A (LAPACK's N);
 *          it was previously declared as `m`, which parameter names in a
 *          declaration do not affect at link time. B holds nrhs columns and
 *          is overwritten by X. */
void xgetrf( int m, int n, double *A, int lda, int *ipiv );
void xgetrf( int m, int n, float *A, int lda, int *ipiv );
void xgetrs( const char *trans, int n, int nrhs, double *A, int lda, int *ipiv, double *B, int ldb );
void xgetrs( const char *trans, int n, int nrhs, float *A, int lda, int *ipiv, float *B, int ldb );


/** QR family. work / lwork is the caller-provided workspace and its length.
 *  xgeqrf: GEQRF wrapper, QR factorization of A; reflector scalars go to tau.
 *  xorgqr: ORGQR wrapper, forms the explicit Q from xgeqrf output.
 *  xormqr: ORMQR wrapper, applies Q (or Q^T, per trans) to C from `side`.
 *  xgeqp3: GEQP3 wrapper, QR with column pivoting; permutation in jpvt.
 *  xgeqp4: GEQP4 wrapper (not a reference LAPACK routine; see
 *          blas_lapack.cpp for its exact semantics).
 *  xgels:  GELS wrapper, least-squares / minimum-norm solve of op( A ) X = B;
 *          B is overwritten by the solution. */
void xgeqrf( int m, int n, double *A, int lda, double *tau, double *work, int lwork );
void xgeqrf( int m, int n, float *A, int lda, float *tau, float *work, int lwork );
void xorgqr( int m, int n, int k, double *A, int lda, double *tau, double *work, int lwork );
void xorgqr( int m, int n, int k, float *A, int lda, float *tau, float *work, int lwork );
void xormqr( const char *side, const char *trans,
  int m, int n, int k, float *A, int lda, float *tau, float *C, int ldc, float *work, int lwork );
void xormqr( const char *side, const char *trans,
  int m, int n, int k, double *A, int lda, double *tau, double *C, int ldc, double *work, int lwork );
void xgeqp3( int m, int n, float *A, int lda, int *jpvt, float *tau, float *work, int lwork );
void xgeqp3( int m, int n, double *A, int lda, int *jpvt, double *tau, double *work, int lwork );

void xgeqp4( int m, int n, float *A, int lda, int *jpvt, float *tau, float *work, int lwork );
void xgeqp4( int m, int n, double *A, int lda, int *jpvt, double *tau, double *work, int lwork );
void xgels( const char *trans, int m, int n, int nrhs, float *A, int lda, float *B, int ldb, float *work, int lwork );
void xgels( const char *trans, int m, int n, int nrhs, double *A, int lda, double *B, int ldb, double *work, int lwork );


/** GECON wrapper: estimates the reciprocal condition number of the n x n
 *  matrix A (given its norm anorm) and stores it in *rcond. */
void xgecon( const char *norm, int n, float *A, int lda, float anorm, float *rcond, float *work, int *iwork );
void xgecon( const char *norm, int n, double *A, int lda, double anorm, double *rcond, double *work, int *iwork );


/** Symmetric tridiagonal eigensolver with diagonal D and off-diagonal E;
 *  eigenvectors go to Z when requested by jobz.
 *  NOTE(review): presumably wraps ?STEV (no Doxygen description) -- confirm. */
void xstev( const char *jobz, int n, double *D, double *E, double *Z, int ldz, double *work );
void xstev( const char *jobz, int n, float *D, float *E, float *Z, int ldz, float *work );


/** Level-1 BLAS: DOT wrapper (dot product of dx and dy with strides incx /
 *  incy) and NRM2 wrapper (Euclidean norm of x with stride incx). */
double xdot( int n, const double *dx, int incx, const double *dy, int incy );
float xdot( int n, const float *dx, int incx, const float *dy, int incy );

double xnrm2( int n, double *x, int incx );
float xnrm2( int n, float *x, int incx );



#ifdef HMLP_USE_CUDA

/*
 *  cuBLAS-based overloads, declared only when HMLP_USE_CUDA is defined.
 *  They take a cuBLAS handle and cublasOperation_t flags instead of the
 *  C-string options used above.
 *  NOTE(review): matrix pointers are presumably device pointers -- confirm
 *  in blas_lapack.cpp.
 */

/** GEMM on the GPU through the given cuBLAS handle (single precision). */
void xgemm
(
  cublasHandle_t &handle,
  cublasOperation_t transa, cublasOperation_t transb,
  int m, int n, int k,
  float alpha,
  float *A, int lda,
  float *B, int ldb, float beta,
  float *C, int ldc
);

/** GEMM on the GPU through the given cuBLAS handle (double precision). */
void xgemm
(
  cublasHandle_t &handle,
  cublasOperation_t transa, cublasOperation_t transb,
  int m, int n, int k,
  double alpha,
  double *A, int lda,
  double *B, int ldb, double beta,
  double *C, int ldc
);

/** Batched GEMM over batchSize (A, B, C) triples passed as pointer arrays,
 *  all sharing m, n, k and leading dimensions (single precision).
 *  NOTE(review): presumably forwards to cublasSgemmBatched -- confirm. */
void xgemm_batched
(
  cublasHandle_t &handle,
  cublasOperation_t transa, cublasOperation_t transb,
  int m, int n, int k,
  float alpha,
  float *Aarray[], int lda,
  float *Barray[], int ldb, float beta,
  float *Carray[], int ldc,
  int batchSize
);

/** Batched GEMM, double precision counterpart of the overload above.
 *  NOTE(review): presumably forwards to cublasDgemmBatched -- confirm. */
void xgemm_batched
(
  cublasHandle_t &handle,
  cublasOperation_t transa, cublasOperation_t transb,
  int m, int n, int k,
  double alpha,
  double *Aarray[], int lda,
  double *Barray[], int ldb, double beta,
  double *Carray[], int ldc,
  int batchSize
);


/** Column-pivoted QR taking a cuBLAS handle (single precision); same
 *  argument layout as the host xgeqp3. What it calls is not visible here. */
void xgeqp3
(
  cublasHandle_t &handle,
  int m, int n,
  float *A, int lda,
  int *jpvt,
  float *tau,
  float *work, int lwork
);


/** Column-pivoted QR taking a cuBLAS handle (double precision). */
void xgeqp3
(
  cublasHandle_t &handle,
  int m, int n,
  double *A, int lda,
  int *jpvt,
  double *tau,
  double *work, int lwork
);

#endif

} // namespace hmlp
250 #endif
double xnrm2(int n, double *x, int incx)
DNRM2 wrapper.
Definition: blas_lapack.cpp:83
void xstev(const char *jobz, int n, double *D, double *E, double *Z, int ldz, double *work)
Definition: blas_lapack.cpp:1127
void xgeqrf(int m, int n, double *A, int lda, double *tau, double *work, int lwork)
DGEQRF wrapper.
Definition: blas_lapack.cpp:694
void xgeqp3(int m, int n, double *A, int lda, int *jpvt, double *tau, double *work, int lwork)
DGEQP3 wrapper.
Definition: blas_lapack.cpp:873
void xgetrf(int m, int n, double *A, int lda, int *ipiv)
DGETRF wrapper.
Definition: blas_lapack.cpp:546
void xorgqr(int m, int n, int k, double *A, int lda, double *tau, double *work, int lwork)
DORGQR wrapper.
Definition: blas_lapack.cpp:757
void xormqr(const char *side, const char *trans, int m, int n, int k, double *A, int lda, double *tau, double *C, int ldc, double *work, int lwork)
DORMQR wrapper.
Definition: blas_lapack.cpp:811
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void xgels(const char *trans, int m, int n, int nrhs, double *A, int lda, double *B, int ldb, double *work, int lwork)
DGELS wrapper.
Definition: blas_lapack.cpp:997
void xgetrs(const char *trans, int m, int nrhs, double *A, int lda, int *ipiv, double *B, int ldb)
DGETRS wrapper.
Definition: blas_lapack.cpp:577
double xdot(int n, const double *dx, int incx, const double *dy, int incy)
DDOT wrapper.
Definition: blas_lapack.cpp:50
void xlaswp(int n, double *A, int lda, int k1, int k2, int *ipiv, int incx)
DLASWP wrapper.
Definition: blas_lapack.cpp:450
void xgecon(const char *norm, int n, double *A, int lda, double anorm, double *rcond, double *work, int *iwork)
DGECON wrapper.
Definition: blas_lapack.cpp:631
void xpotrs(const char *uplo, int n, int nrhs, double *A, int lda, double *B, int ldb)
DPOTRS wrapper.
Definition: blas_lapack.cpp:512
void xgeqp4(int m, int n, double *A, int lda, int *jpvt, double *tau, double *work, int lwork)
DGEQP4 wrapper.
Definition: blas_lapack.cpp:935
void xpotrf(const char *uplo, int n, double *A, int lda)
DPOTRF wrapper.
Definition: blas_lapack.cpp:480
void xtrsm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRSM wrapper.
Definition: blas_lapack.cpp:315
void xtrmm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRMM wrapper.
Definition: blas_lapack.cpp:385
Definition: gofmm.hpp:83