HMLP: High-performance Machine Learning Primitives
lowrank.hpp
1 
23 #ifndef LOWRANK_HPP
24 #define LOWRANK_HPP
25 
26 #include <assert.h>
27 #include <typeinfo>
28 #include <algorithm>
29 #include <random>
30 
31 
32 #include <hmlp.h>
33 #include <hmlp_base.hpp>
34 
35 
36 using namespace std;
37 using namespace hmlp;
38 
39 
40 namespace hmlp
41 {
42 namespace lowrank
43 {
44 
45 
46 
47 
48 
49 
50 
52 //template<typename T>
53 //void id
54 //(
55 // int m, int n, int maxs,
56 // std::vector<T> A,
57 // std::vector<size_t> &skels, hmlp::Data<T> &proj, std::vector<int> &jpvt
58 //)
59 //{
60 // int nb = 512;
61 // int lwork = 2 * n + ( n + 1 ) * nb;
62 // std::vector<T> work( lwork );
63 // std::vector<T> tau( std::min( m, n ) );
64 // std::vector<T> S, Z;
65 // std::vector<T> A_tmp = A;
66 //
67 // // Early return
68 // if ( n <= maxs )
69 // {
70 // skels.resize( n );
71 // proj.resize( n, n, 0.0 );
72 // for ( int i = 0; i < n; i ++ )
73 // {
74 // skels[i] = i;
75 // proj[ i * proj.row() + i ] = 1.0;
76 // }
77 // return;
78 // }
79 //
80 // // Initilize jpvt to zeros. Otherwise, GEQP3 will permute A.
81 // jpvt.resize( n, 0 );
82 //
83 // // Traditional pivoting QR (GEQP3)
84 // hmlp::xgeqp3
85 // (
86 // m, n,
87 // A_tmp.data(), m,
88 // jpvt.data(),
89 // tau.data(),
90 // work.data(), lwork
91 // );
92 // //printf( "end xgeqp3\n" );
93 //
94 // jpvt.resize( maxs );
95 // skels.resize( maxs );
96 //
97 // // Now shift jpvt from 1-base to 0-base index.
98 // for ( int j = 0; j < jpvt.size(); j ++ )
99 // {
100 // jpvt[ j ] = jpvt[ j ] - 1;
101 // skels[ j ] = jpvt[ j ];
102 // }
103 //
104 // // TODO: Here we only need several things to get rid of xgels.
105 // //
106 // // 0. R11 = zeros( s )
107 // // 1. get R11 = up_tiangular( A_tmp( 1:s, 1:s ) )
108 // // 2. get proj( 1:s, jpvt( 1:n ) ) = A_tmp( 1:s 1:n )
109 // // 3. xtrsm( "L", "U", "N", "N", s, n, 1.0, R11.data(), s, proj.data(), s )
110 //
111 //
112 //
113 //
114 //
115 // Z.resize( m * jpvt.size() );
116 //
117 // for ( int j = 0; j < jpvt.size(); j ++ )
118 // {
119 // // reuse Z
120 // for ( int i = 0; i < m; i ++ )
121 // {
122 // Z[ j * m + i ] = A[ jpvt[ j ] * m + i ];
123 // }
124 // }
125 // auto A_skel = Z;
126 //
127 //
128 // S = A;
129 // // P (overwrite S) = pseudo-inverse( Z ) * S
130 // hmlp::xgels
131 // (
132 // "N",
133 // m, jpvt.size(), n,
134 // Z.data(), m,
135 // S.data(), m,
136 // work.data(), lwork
137 // );
138 //
139 //
140 // // Fill in proj
141 // proj.resize( jpvt.size(), n );
142 // for ( int j = 0; j < n; j ++ )
143 // {
144 // for ( int i = 0; i < jpvt.size(); i ++ )
145 // {
146 // proj[ j * jpvt.size() + i ] = S[ j * m + i ];
147 // }
148 // }
149 //
150 //
151 //
152 //#ifdef DEBUG_SKEL
153 // double nrm = hmlp_norm( m, n, A.data(), m );
154 //
155 // hmlp::xgemm
156 // (
157 // "N", "N",
158 // m, n, jpvt.size(),
159 // -1.0, A_skel.data(), m,
160 // S.data(), m,
161 // 1.0, A.data(), m
162 // );
163 //
164 // double err = hmlp_norm( m, n, A.data(), m );
165 // printf( "m %d n %d k %lu absolute l2 error %E related l2 error %E\n",
166 // m, n, jpvt.size(),
167 // err, err / nrm );
168 //#endif
169 //
170 //}; // end id()
171 //
172 
173 
174 
175 
180 template<typename T>
181 void id
182 (
183  bool use_adaptive_ranks, bool secure_accuracy,
184  int m, int n, int maxs, T stol,
185  Data<T> A,
186  vector<size_t> &skels, Data<T> &proj, vector<int> &jpvt
187 )
188 {
189  int s;
190  int nb = 512;
191  int lwork = 2 * n + ( n + 1 ) * nb;
192  std::vector<T> work( lwork );
193  std::vector<T> tau( std::min( m, n ) );
194  hmlp::Data<T> S, Z;
195  hmlp::Data<T> A_tmp = A;
196 
198  //assert( m >= n );
199 
200  // Initilize jpvt to zeros. Otherwise, GEQP3 will permute A.
201  jpvt.clear();
202  jpvt.resize( n, 0 );
203 
205 //#ifdef HMLP_USE_CUDA
206 // auto *dev = hmlp_get_device( 0 );
207 // cublasHandle_t &handle =
208 // reinterpret_cast<hmlp::gpu::Nvidia*>( dev )->gethandle( 0 );
209 // hmlp::xgeqp3
210 // (
211 // handle,
212 // m, n,
213 // A_tmp.data(), m,
214 // jpvt.data(),
215 // tau.data(),
216 // work.data(), lwork
217 // );
218 //#else
220  //hmlp::xgeqp3
221  (
222  m, n,
223  A_tmp.data(), m,
224  jpvt.data(),
225  tau.data(),
226  work.data(), lwork
227  );
228 //#endif
229  //printf( "end xgeqp3\n" );
230 
232  for ( int j = 0; j < jpvt.size(); j ++ ) jpvt[ j ] = jpvt[ j ] - 1;
233 
235  for ( s = 1; s < n; s ++ )
236  {
237  if ( s > maxs || std::abs( A_tmp[ s * m + s ] ) < stol ) break;
238  //if ( s > maxs || std::abs( A_tmp[ s * m + s ] ) / std::abs( A_tmp[ 0 ] ) < stol ) break;
239  }
240 
242  if ( !use_adaptive_ranks ) s = std::min( maxs, n );
243 
245  if ( s > maxs )
246  {
247  //if ( LEVELRESTRICTION ) /** abort */
248  if ( secure_accuracy )
249  {
250  skels.clear();
251  proj.resize( 0, 0 );
252  jpvt.resize( 0 );
253  return;
254  }
255  else
256  {
257  s = maxs;
258  }
259  }
260 
262  skels.resize( s );
263  for ( int j = 0; j < skels.size(); j ++ ) skels[ j ] = jpvt[ j ];
264 
265 
266  // TODO: Here we only need several things to get rid of xgels.
267  //
268  // 0. R11 = zeros( s )
269  // 1. get R11 = up_tiangular( A_tmp( 1:s, 1:s ) )
270  // 2. get proj( 1:s, jpvt( 1:n ) ) = A_tmp( 1:s 1:n )
271  // 3. xtrsm( "L", "U", "N", "N", s, n, 1.0, R11.data(), s, proj.data(), s )
272 
273 
275  if ( true )
276  {
278  proj.clear();
279  proj.resize( s, n, 0.0 );
280 
281  for ( int j = 0; j < n; j ++ )
282  {
283  for ( int i = 0; i < s; i ++ )
284  {
285  if ( j < s )
286  {
287  if ( j >= i ) proj[ j * s + i ] = A_tmp[ j * m + i ];
288  else proj[ j * s + i ] = 0.0;
289  }
290  else
291  {
292  proj[ j * s + i ] = A_tmp[ j * m + i ];
293  }
294  }
295  }
296  }
297  else
298  {
299  Z.resize( m, skels.size() );
300 
301  for ( int j = 0; j < skels.size(); j ++ )
302  {
303  for ( int i = 0; i < m; i ++ )
304  {
305  Z[ j * m + i ] = A[ skels[ j ] * m + i ];
306  }
307  }
308  auto A_skel = Z;
309 
310  S = A;
311  // P (overwrite S) = pseudo-inverse( Z ) * S
313  (
314  "N",
315  m, skels.size(), n,
316  Z.data(), m,
317  S.data(), m,
318  work.data(), lwork
319  );
320 
321  // Fill in proj
322  proj.resize( skels.size(), n );
323  for ( int j = 0; j < n; j ++ )
324  {
325  for ( int i = 0; i < skels.size(); i ++ )
326  {
327  proj[ j * skels.size() + i ] = S[ j * m + i ];
328  }
329  }
330 
331 #ifdef DEBUG_SKEL
332  double nrm = hmlp_norm( m, n, A.data(), m );
333 
335  (
336  "N", "N",
337  m, n, skels.size(),
338  -1.0, A_skel.data(), m,
339  S.data(), m,
340  1.0, A.data(), m
341  );
342 
343  double err = hmlp_norm( m, n, A.data(), m );
344  printf( "m %d n %d k %lu absolute l2 error %E related l2 error %E\n",
345  m, n, skels.size(),
346  err, err / nrm );
347 #endif
348 
349  }
350 
351 }; // end id()
352 
353 
354 
358 template<bool ONESHOT = false,typename T>
359 void nystrom( size_t m, size_t n, size_t r,
360  std::vector<T> &A, std::vector<T> &C,
361  std::vector<T> &U, std::vector<T> &V )
362 {
364  if ( C.size() != r * r )
365  {
368  }
369  else
370  {
372  }
373 
375  if ( ONESHOT )
376  {
377  }
378 
379 
380 };
386 template<typename T>
387 void pmid
388 (
389  int m,
390  int n,
391  int maxs,
392  std::vector<T> A,
393  std::vector<int> &jpiv, std::vector<T> &P
394 )
395 {
396  int rank = maxs + 10;
397  int lwork = 512 * n;
398 
399  printf( "maxs %d\n", maxs );
400 
401  if ( rank > n ) rank = n;
402 
403 
404  std::vector<T> S = A;
405  std::vector<T> O( n * rank );
406  std::vector<T> Z( m * rank );
407  std::vector<T> tau( std::min( m, n ), 0.0 );
408  std::vector<T> work( lwork, 0.0 );
409 
410  std::default_random_engine generator;
411  std::normal_distribution<T> gaussian( 0.0, 1.0 );
412 
413  // generate O n-by-(maxs+10) random matrix (need to be Gaussian samples)
414  #pragma omp parallel for
415  for ( int i = 0; i < n * rank; i ++ )
416  {
417  O[ i ] = gaussian( generator );
418  }
419 
420 #ifdef DEBUG_SKEL
421  printf( "O\n" );
422  hmlp_printmatrix( n, rank, O.data(), n );
423  printf( "A\n" );
424  hmlp_printmatrix( m, n, A.data(), m );
425 #endif
426 
427 
428  // Z = 0.0 * Z + 1.0 * A * O
430  (
431  "N", "N",
432  m, rank, n,
433  1.0, A.data(), m,
434  O.data(), n,
435  0.0, Z.data(), m
436  );
437  printf( "here xgemm\n" );
438 
439 
440 #ifdef DEBUG_SKEL
441  printf( "Z\n" );
442  hmlp_printmatrix( m, rank, Z.data(), m );
443 #endif
444 
445 
446 
447 
448 
449 
450  // [Q,~] = qr(Z,0), so I need the orthogonal matrix
452  (
453  m, rank,
454  Z.data(), m,
455  tau.data(),
456  work.data(), lwork
457  );
458  printf( "here xgeqrf\n" );
459 
460 #ifdef DEBUG_SKEL
461  printf( "Z\n" );
462  hmlp_printmatrix( m, rank, Z.data(), m );
463 #endif
464 
465  // S = Q' * A
467  (
468  "L", "T",
469  m, n, rank,
470  Z.data(), m,
471  tau.data(),
472  S.data(), m,
473  work.data(), lwork
474  );
475  printf( "here xormqr\n" );
476 
477 
478  for ( int i = 0; i < rank; i ++ )
479  {
480  for ( int j = 0; j < n; j ++ )
481  {
482  S[ j * m + i ] = fabs( S[ j * m + i ] );
483  }
484  }
485 
486 
487 
488 
489 
490 
491  // abs( S(1:rank,1:n) ) and select the largest entry per row.
492  while ( jpiv.size() < maxs )
493  {
494  for ( int i = 0; i < rank; i ++ )
495  {
496  std::pair<T,int> pivot( 0.0, -1 );
497 
498  for ( int j = 0; j < n; j ++ )
499  {
500  if ( S[ j * m + i ] > pivot.first )
501  {
502  pivot = std::make_pair( S[ j * m + i ], j );
503  }
504  }
505  if ( pivot.second != -1 )
506  {
507  jpiv.push_back( pivot.second );
508  }
509  }
510 
511  std::sort( jpiv.begin(), jpiv.end() );
512  auto last = std::unique( jpiv.begin(), jpiv.end() );
513  jpiv.erase( last, jpiv.end() );
514 
515  printf( "Total %lu pivots\n", jpiv.size() );
516 
517  // zero out S
518  for ( int j = 0; j < jpiv.size(); j ++ )
519  {
520  for ( int i = 0; i < rank; i ++ )
521  {
522  S[ jpiv[ j ] * m + i ] = 0.0;
523  }
524  }
525  }
526 
527  jpiv.resize( maxs );
528 
529 #ifdef DEBUG_SKEL
530  printf( "jpjv:\n" );
531  for ( int j = 0; j < jpiv.size(); j ++ )
532  {
533  printf( "%12d ", jpiv[ j ] );
534  }
535 #endif
536  // std::sort( ipiv.begin(), ipiv.end() );
537  // auto last = std::unique( ipiv.begin(), ipiv.end() );
538  // ipiv.erase( last, ipiv.end() );
539 
540  // printf( "Total %lu pivots\n", ipiv.size() );
541 
542 
543 
544 
545 
546  Z.resize( m * jpiv.size() );
547 
548  for ( int j = 0; j < jpiv.size(); j ++ )
549  {
550  // reuse Z
551  for ( int i = 0; i < m; i ++ )
552  {
553  Z[ j * m + i ] = A[ jpiv[ j ] * m + i ];
554  }
555  }
556  auto A_skel = Z;
557  //P.resize( ipiv.size() * n );
558 
559 
560 #ifdef DEBUG_SKEL
561  printf( "jpjv:\n" );
562  for ( int j = 0; j < jpiv.size(); j ++ )
563  {
564  printf( "%12d ", jpiv[ j ] );
565  }
566  printf( "\n" );
567  printf( "Z = [\n" );
568  hmlp_printmatrix( m, jpiv.size(), Z.data(), m );
569 #endif
570 
571 
572  S = A;
573  // P (overwrite S) = pseudo-inverse( Z ) * S
575  (
576  "N",
577  m, jpiv.size(), n,
578  Z.data(), m,
579  S.data(), m,
580  work.data(), lwork
581  );
582 
583 
584 #ifdef DEBUG_SKEL
585  printf( "S\n" );
586  hmlp_printmatrix<true, true>( m, n, S.data(), m );
587 
588  double nrm = hmlp_norm( m, n, A.data(), m );
589 
591  (
592  "N", "N",
593  m, n, jpiv.size(),
594  -1.0, A_skel.data(), m,
595  S.data(), m,
596  1.0, A.data(), m
597  );
598 
599  double err = hmlp_norm( m, n, A.data(), m );
600 
601  printf( "absolute l2 error %E related l2 error %E\n", err, err / nrm );
602 #endif
603 
604 
605 #ifdef DEBUG_SKEL
606  printf( "A\n" );
607  hmlp_printmatrix<true, true>( m, n, A.data(), m );
608 #endif
609 
610 }; // end pmid()
611 
612 template<class Node>
613 void skeletonize( Node *node )
614 {
615  auto lchild = node->lchild;
616  auto rchild = node->rchild;
617 
618  // random sampling or important sampling for rows.
619  std::vector<size_t> amap;
620 
621  std::vector<size_t> bmap;
622 
623  //bmap = lchild
624 
625 
626  printf( "id %d l %d n %d\n", node->treelist_id, node->l, node->n );
627 
628 }; // end skeletonize()
629 
630 
631 
632 
633 //template<typename CONTEXT>
634 //class Task : public hmlp::Task
635 //{
636 // public:
637 //
638 // /* function ptr */
639 // void (*function)(Task<CONTEXT>*);
640 //
641 // /* argument ptr */
642 // CONTEXT *arg;
643 //
644 // void Set( CONTEXT *user_arg )
645 // {
646 // name = std::string( "Skeletonization" );
647 // arg = user_arg;
648 // }
649 //
650 // void Execute( Worker* user_worker )
651 // {
652 // printf( "SkeletonizeTask Execute 2\n" );
653 // }
654 //
655 // private:
656 //
657 //}; // end class Task
658 
659 //template<class Node>
660 //class Task : public Task
661 //{
662 // public:
663 //
664 // Node *arg;
665 //
666 // void Set( Node *user_arg )
667 // {
668 // name = std::string( "Skeletonization" );
669 // arg = user_arg;
670 // };
671 //
672 // void Execute( Worker* user_worker )
673 // {
674 // //printf( "SkeletonizeTask Execute 2\n" );
675 // skeletonize( arg );
676 // };
677 //
678 // private:
679 //};
680 
681 
682 
683 
684 
685 
686 
687 };
688 };
690 #endif // define LOWRANK_HPP
void xgeqrf(int m, int n, double *A, int lda, double *tau, double *work, int lwork)
DGEQRF wrapper.
Definition: blas_lapack.cpp:694
void xormqr(const char *side, const char *trans, int m, int n, int k, double *A, int lda, double *tau, double *C, int ldc, double *work, int lwork)
DORMQR wrapper.
Definition: blas_lapack.cpp:811
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void xgels(const char *trans, int m, int n, int nrhs, double *A, int lda, double *B, int ldb, double *work, int lwork)
DGELS wrapper.
Definition: blas_lapack.cpp:997
void xgeqp4(int m, int n, double *A, int lda, int *jpvt, double *tau, double *work, int lwork)
DGEQP4 wrapper.
Definition: blas_lapack.cpp:935
Definition: gofmm.hpp:83