33 #include <hmlp_base.hpp> 183 bool use_adaptive_ranks,
bool secure_accuracy,
184 int m,
int n,
int maxs, T stol,
186 vector<size_t> &skels,
Data<T> &proj, vector<int> &jpvt
191 int lwork = 2 * n + ( n + 1 ) * nb;
192 std::vector<T> work( lwork );
193 std::vector<T> tau( std::min( m, n ) );
232 for (
int j = 0; j < jpvt.size(); j ++ ) jpvt[ j ] = jpvt[ j ] - 1;
235 for ( s = 1; s < n; s ++ )
237 if ( s > maxs || std::abs( A_tmp[ s * m + s ] ) < stol )
break;
242 if ( !use_adaptive_ranks ) s = std::min( maxs, n );
248 if ( secure_accuracy )
263 for (
int j = 0; j < skels.size(); j ++ ) skels[ j ] = jpvt[ j ];
279 proj.resize( s, n, 0.0 );
281 for (
int j = 0; j < n; j ++ )
283 for (
int i = 0; i < s; i ++ )
287 if ( j >= i ) proj[ j * s + i ] = A_tmp[ j * m + i ];
288 else proj[ j * s + i ] = 0.0;
292 proj[ j * s + i ] = A_tmp[ j * m + i ];
299 Z.resize( m, skels.size() );
301 for (
int j = 0; j < skels.size(); j ++ )
303 for (
int i = 0; i < m; i ++ )
305 Z[ j * m + i ] = A[ skels[ j ] * m + i ];
322 proj.resize( skels.size(), n );
323 for (
int j = 0; j < n; j ++ )
325 for (
int i = 0; i < skels.size(); i ++ )
327 proj[ j * skels.size() + i ] = S[ j * m + i ];
332 double nrm = hmlp_norm( m, n, A.data(), m );
338 -1.0, A_skel.data(), m,
343 double err = hmlp_norm( m, n, A.data(), m );
344 printf(
"m %d n %d k %lu absolute l2 error %E related l2 error %E\n",
358 template<
bool ONESHOT = false,
typename T>
359 void nystrom(
size_t m,
size_t n,
size_t r,
360 std::vector<T> &A, std::vector<T> &C,
361 std::vector<T> &U, std::vector<T> &V )
364 if ( C.size() != r * r )
393 std::vector<int> &jpiv, std::vector<T> &P
396 int rank = maxs + 10;
399 printf(
"maxs %d\n", maxs );
401 if ( rank > n ) rank = n;
404 std::vector<T> S = A;
405 std::vector<T> O( n * rank );
406 std::vector<T> Z( m * rank );
407 std::vector<T> tau( std::min( m, n ), 0.0 );
408 std::vector<T> work( lwork, 0.0 );
410 std::default_random_engine generator;
411 std::normal_distribution<T> gaussian( 0.0, 1.0 );
414 #pragma omp parallel for 415 for (
int i = 0; i < n * rank; i ++ )
417 O[ i ] = gaussian( generator );
422 hmlp_printmatrix( n, rank, O.data(), n );
424 hmlp_printmatrix( m, n, A.data(), m );
437 printf(
"here xgemm\n" );
442 hmlp_printmatrix( m, rank, Z.data(), m );
458 printf(
"here xgeqrf\n" );
462 hmlp_printmatrix( m, rank, Z.data(), m );
475 printf(
"here xormqr\n" );
478 for (
int i = 0; i < rank; i ++ )
480 for (
int j = 0; j < n; j ++ )
482 S[ j * m + i ] = fabs( S[ j * m + i ] );
492 while ( jpiv.size() < maxs )
494 for (
int i = 0; i < rank; i ++ )
496 std::pair<T,int> pivot( 0.0, -1 );
498 for (
int j = 0; j < n; j ++ )
500 if ( S[ j * m + i ] > pivot.first )
502 pivot = std::make_pair( S[ j * m + i ], j );
505 if ( pivot.second != -1 )
507 jpiv.push_back( pivot.second );
511 std::sort( jpiv.begin(), jpiv.end() );
512 auto last = std::unique( jpiv.begin(), jpiv.end() );
513 jpiv.erase( last, jpiv.end() );
515 printf(
"Total %lu pivots\n", jpiv.size() );
518 for (
int j = 0; j < jpiv.size(); j ++ )
520 for (
int i = 0; i < rank; i ++ )
522 S[ jpiv[ j ] * m + i ] = 0.0;
531 for (
int j = 0; j < jpiv.size(); j ++ )
533 printf(
"%12d ", jpiv[ j ] );
546 Z.resize( m * jpiv.size() );
548 for (
int j = 0; j < jpiv.size(); j ++ )
551 for (
int i = 0; i < m; i ++ )
553 Z[ j * m + i ] = A[ jpiv[ j ] * m + i ];
562 for (
int j = 0; j < jpiv.size(); j ++ )
564 printf(
"%12d ", jpiv[ j ] );
568 hmlp_printmatrix( m, jpiv.size(), Z.data(), m );
586 hmlp_printmatrix<true, true>( m, n, S.data(), m );
588 double nrm = hmlp_norm( m, n, A.data(), m );
594 -1.0, A_skel.data(), m,
599 double err = hmlp_norm( m, n, A.data(), m );
601 printf(
"absolute l2 error %E related l2 error %E\n", err, err / nrm );
607 hmlp_printmatrix<true, true>( m, n, A.data(), m );
613 void skeletonize( Node *node )
615 auto lchild = node->lchild;
616 auto rchild = node->rchild;
619 std::vector<size_t> amap;
621 std::vector<size_t> bmap;
626 printf(
"id %d l %d n %d\n", node->treelist_id, node->l, node->n );
690 #endif // define LOWRANK_HPP
void xgeqrf(int m, int n, double *A, int lda, double *tau, double *work, int lwork)
DGEQRF wrapper.
Definition: blas_lapack.cpp:694
void xormqr(const char *side, const char *trans, int m, int n, int k, double *A, int lda, double *tau, double *C, int ldc, double *work, int lwork)
DORMQR wrapper.
Definition: blas_lapack.cpp:811
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void xgels(const char *trans, int m, int n, int nrhs, double *A, int lda, double *B, int ldb, double *work, int lwork)
DGELS wrapper.
Definition: blas_lapack.cpp:997
void xgeqp4(int m, int n, double *A, int lda, int *jpvt, double *tau, double *work, int lwork)
DGEQP4 wrapper.
Definition: blas_lapack.cpp:935