HMLP: High-performance Machine Learning Primitives
igofmm.hpp
1 
23 #ifndef IGOFMM_HPP
24 #define IGOFMM_HPP
25 
26 
28 using namespace std;
29 using namespace hmlp;
30 
31 
32 
33 
34 namespace hmlp
35 {
36 namespace gofmm
37 {
38 
39 
69 template<typename T>
70 class Factor
71 {
72  public:
73 
74  Factor() {};
75 
76  void SetupFactor
77  (
78  bool issymmetric, bool do_ulv_factorization,
79  bool isleaf, bool isroot,
81  size_t n, size_t nl, size_t nr,
83  size_t s, size_t sl, size_t sr
84  )
85  {
86  this->issymmetric = issymmetric;
87  this->do_ulv_factorization = do_ulv_factorization;
88  this->isleaf = isleaf;
89  this->isroot = isroot;
90  this->n = n; this->nl = nl; this->nr = nr;
91  this->s = s; this->sl = sl; this->sr = sr;
92  };
93 
94  void SetupFactor
95  (
96  bool issymmetric, bool do_ulv_factorization,
97  bool isleaf, bool isroot,
98  size_t n, size_t nl, size_t nr,
99  size_t s, size_t sl, size_t sr,
101  Data<T> &U,
103  Data<T> &V
104  )
105  {
106  SetupFactor( issymmetric, do_ulv_factorization,
107  isleaf, isroot, n, nl, nr, s, sl, sr );
108  };
109 
110  bool DoULVFactorization()
111  {
112  return do_ulv_factorization;
113  };
114 
115  bool IsSymmetric() { return issymmetric; };
116 
117 
118 
119 
120 
121  void CheckCondition()
122  {
123  assert( do_ulv_factorization && issymmetric );
124  T max_diag = 0.0;
125  T min_diag = 0.0;
126 
127  for ( size_t i = 0; i < Z.row(); i ++ )
128  {
129  T abs_diag = std::abs( Z( i, i ) );
130 
131  if ( !i )
132  {
133  max_diag = abs_diag;
134  min_diag = abs_diag;
135  }
136 
137  if ( abs_diag > max_diag ) max_diag = abs_diag;
138  if ( abs_diag < min_diag ) min_diag = abs_diag;
139  }
140 
141  printf( "condiditon( Z ): min_diag %3.1E max_diag %3.1E ratio %3.1E\n",
142  min_diag, max_diag, min_diag / max_diag );
143  };
144 
145  void Factorize( Data<T> &Kaa )
146  {
147  assert( isleaf );
148  assert( Kaa.row() == n ); assert( Kaa.col() == n );
149 
151  Z = Kaa;
152 
154  ipiv.resize( n, 0 );
155 
157  T nrm1 = 0.0;
158  for ( auto &z : Z ) nrm1 += z;
159 
161  xgetrf( n, n, Z.data(), n, ipiv.data() );
162 
164  T rcond1 = 0.0;
165  Data<T> work( Z.row(), 4 );
166  vector<int> iwork( Z.row() );
167  xgecon( "1", Z.row(), Z.data(), Z.row(), nrm1,
168  &rcond1, work.data(), iwork.data() );
169  if ( 1.0 / rcond1 > 1E+6 )
170  printf( "Warning! large 1-norm condition number %3.1E, nrm1( Z ) %3.1E\n",
171  1.0 / rcond1, nrm1 );
172  };
180  {
182  Z = A;
183  ChangeBasis( Z );
184 
186  Zv.Set( false, Z );
187  Zv.Partition2x2( Ztl, Ztr,
188  Zbl, Zbr, s, s, BOTTOMRIGHT );
189 
190  //printf( "Ztl %lux%lu Ztr %lux%lu\n", Ztl.row(), Ztl.col(), Ztr.row(), Ztr.col() ); fflush( stdout );
191  //printf( "Zbl %lux%lu Zbr %lux%lu\n", Zbl.row(), Zbl.col(), Zbr.row(), Zbr.col() ); fflush( stdout );
192 
194  ipiv.resize( Ztl.row(), 0 );
196  xgetrf( Ztl.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
198  xtrsm( "Right", "Upper", "No transpose", "Non-unit", Zbl.row(), Zbl.col(),
199  1.0, Ztl.data(), Ztl.ld(), Zbl.data(), Zbl.ld() );
201  xgemm( "No transpose", "No transpose", Zbr.row(), Zbr.col(), Ztl.col(),
202  -1.0, Zbl.data(), Zbl.ld(),
203  Ztr.data(), Ztr.ld(),
204  1.0, Zbr.data(), Zbr.ld() );
205 
206  };
231  void Factorize
232  (
234  Data<T> &Ul,
236  Data<T> &Ur,
238  Data<T> &Vl,
240  Data<T> &Vr
241  )
242  {
243  assert( !isleaf );
244  //assert( Ul.row() == nl ); assert( Ul.col() == sl );
245  //assert( Ur.row() == nr ); assert( Ur.col() == sr );
246  //assert( Vl.row() == nl ); assert( Vl.col() == sl );
247  //assert( Vr.row() == nr ); assert( Vr.col() == sr );
248 
250  if ( issymmetric )
251  {
252  assert( Crl.row() == sr ); assert( Crl.col() == sl );
253  }
254  else
255  {
256  assert( Clr.row() == sl ); assert( Clr.col() == sr );
257  assert( Crl.row() == sr ); assert( Crl.col() == sl );
258  }
259 
266  Z.resize( 0, 0 );
267  Z.resize( sl + sr, sl + sr, 0.0 );
268  for ( size_t i = 0; i < sl + sr; i ++ ) Z[ i * Z.row() + i ] = 1.0;
269 
270 
271 
272  if ( do_ulv_factorization )
273  {
278  if ( issymmetric )
279  {
281  hmlp::Data<T> Zbl = Crl;
282 
283  //printf( "Crl\n" );
284  //Crl.Print();
285 
286 
288  xtrmm
289  (
290  "Right", "Upper", "Transpose", "Non-unit",
291  Zbl.row(), Zbl.col(),
292  1.0, Ul.data(), Ul.row(),
293  Zbl.data(), Zbl.row()
294  );
295  //printf( "Ul.row() %lu Zbl.row() %lu Zbl.col() %lu\n",
296  // Ul.row(), Zbl.row(), Zbl.col() );
297 
299  xtrmm
300  (
301  "Left", "Upper", "Non-transpose", "Non-unit",
302  Zbl.row(), Zbl.col(),
303  1.0, Ur.data(), Ur.row(),
304  Zbl.data(), Zbl.row()
305  );
306  //printf( "Ur.row() %lu Zbl.row() %lu Zbl.col() %lu\n",
307  // Ur.row(), Zbl.row(), Zbl.col() );
308 
310  for ( size_t j = 0; j < sl; j ++ )
311  for ( size_t i = 0; i < sr; i ++ )
312  {
313  Z( sl + i, j ) = Zbl( i, j );
314  Z( j, sl + i ) = Zbl( i, j );
315  }
316 
318  if ( 1 )
319  {
320  xpotrf( "Lower", Z.row(), Z.data(), Z.row() );
321  //CheckCondition();
322  }
323  else
324  {
326  ipiv.resize( Z.row(), 0 );
327  xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
328  }
329  }
330  else
331  {
333  ipiv.resize( Z.row(), 0 );
334  }
335  }
336  else
337  {
339  ipiv.resize( Z.row(), 0 );
340 
345  std::vector<T> VltUl( sl * sl, 0.0 );
346  std::vector<T> VrtUr( sr * sr, 0.0 );
347 
349  xgemm( "T", "N", sl, sl, nl,
350  1.0, Vl.data(), nl,
351  Ul.data(), nl,
352  0.0, VltUl.data(), sl );
353 
355  xgemm( "T", "N", sr, sr, nr,
356  1.0, Vr.data(), nr,
357  Ur.data(), nr,
358  0.0, VrtUr.data(), sr );
359 
361  xgemm( "N", "N", sr, sl, sl,
362  1.0, Crl.data(), sr,
363  VltUl.data(), sl,
364  0.0, Z.data() + sl, sl + sr );
365 
366 
367  if ( issymmetric )
368  {
370  xgemm( "T", "N", sl, sr, sr,
371  1.0, Crl.data(), sr,
372  VrtUr.data(), sr,
373  0.0, Z.data() + ( sl + sr ) * sl, sl + sr );
374  }
375  else
376  {
377  printf( "bug\n" ); exit( 1 );
379  xgemm( "N", "N", sl, sr, sr,
380  1.0, Clr.data(), sl,
381  VrtUr.data(), sr,
382  0.0, Z.data() + ( sl + sr ) * sl, sl + sr );
383  }
384 
386  T nrm1 = 0.0;
387  for ( size_t i = 0; i < Z.size(); i ++ )
388  nrm1 += std::abs( Z[ i ] );
389 
391  xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
392 
394  this->Ul = &Ul;
395  this->Ur = &Ur;
396  this->Vl = &Vl;
397  this->Vr = &Vr;
398 
400  T rcond1 = 0.0;
401  hmlp::Data<T> work( Z.row(), 4 );
402  std::vector<int> iwork( Z.row() );
403  xgecon( "1", Z.row(), Z.data(), Z.row(), nrm1,
404  &rcond1, work.data(), iwork.data() );
405  if ( 1.0 / rcond1 > 1E+6 )
406  printf( "Warning! large 1-norm condition number %3.1E\n",
407  1.0 / rcond1 ); fflush( stdout );
408  }
409 
410  };
415  View<T> &Zl, View<T> &Zr,
417  Data<T> &Ul, Data<T> &Ur,
419  Data<T> &Vl, Data<T> &Vr )
420  {
421  Z.resize( 0, 0 );
422  Z.resize( sl + sr, sl + sr, 0.0 );
423 
425  Zv.Set( false, Z );
426  Zv.Partition2x2( Ztl, Ztr,
427  Zbl, Zbr, sl, sl, TOPLEFT );
428 
429  //printf( "Ztl %lux%lu Ztr %lux%lu\n", Ztl.row(), Ztl.col(), Ztr.row(), Ztr.col() ); fflush( stdout );
430  //printf( "Zbl %lux%lu Zbr %lux%lu\n", Zbl.row(), Zbl.col(), Zbr.row(), Zbr.col() ); fflush( stdout );
431 
432 
433  Zbl.CopyValuesFrom( Crl );
435  xtrmm( "Right", "Upper", "Transpose", "Non-unit", Zbl.row(), Zbl.col(),
436  1.0, Ul.data(), Ul.row(), Zbl.data(), Zbl.ld() );
438  xtrmm( "Left", "Upper", "Non-transpose", "Non-unit", Zbl.row(), Zbl.col(),
439  1.0, Ur.data(), Ur.row(), Zbl.data(), Zbl.ld() );
440 
441  Ztl.CopyValuesFrom( Zl );
442  Zbr.CopyValuesFrom( Zr );
443 
444  for ( size_t j = 0; j < sl; j ++ )
445  for ( size_t i = 0; i < sr; i ++ )
446  Ztr( j, i ) = Zbl( i, j );
447 
448  PartialFactorize( Z );
449 
450  };
456  void Multiply( View<T> &bl, View<T> &br )
457  {
458  assert( !isleaf && bl.col() == br.col() );
459 
460  size_t nrhs = bl.col();
461 
462  std::vector<T> ta( ( sl + sr ) * nrhs );
463  std::vector<T> tl( sl * nrhs );
464  std::vector<T> tr( sr * nrhs );
465 
467  xgemm( "T", "N", sl, nrhs, nl,
468  1.0, Vl->data(), nl,
469  bl.data(), bl.ld(),
470  0.0, tl.data(), sl );
472  xgemm( "T", "N", sr, nrhs, nr,
473  1.0, Vr->data(), nr,
474  br.data(), br.ld(),
475  0.0, tr.data(), sr );
476 
478  xgemm( "N", "N", sr, nrhs, sl,
479  1.0, Crl.data(), sr,
480  tl.data(), sl,
481  0.0, ta.data() + sl, sl + sr );
482 
483  if ( issymmetric )
484  {
486  xgemm( "T", "N", sl, nrhs, sr,
487  1.0, Crl.data(), sr,
488  tr.data(), sr,
489  0.0, ta.data(), sl + sr );
490  }
491  else
492  {
493  printf( "bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
495  xgemm( "N", "N", sl, nrhs, sr,
496  1.0, Clr.data(), sl,
497  tr.data(), sr,
498  0.0, ta.data(), sl + sr );
499  }
500 
502  xgemm( "N", "N", nl, nrhs, sl,
503  -1.0, Ul->data(), nl,
504  ta.data(), sl + sr,
505  1.0, bl.data(), bl.ld() );
506 
508  xgemm( "N", "N", nr, nrhs, sr,
509  -1.0, Ur->data(), nr,
510  ta.data() + sl, sl + sr,
511  1.0, br.data(), br.ld() );
512  };
513 
517  void Solve( View<T> &rhs )
518  {
520  assert( isleaf );
521  assert( !do_ulv_factorization );
522  assert( rhs.data() && Z.data() );
523  assert( ipiv.data() );
524 
525  //rhs.Print();
526 
527  size_t nrhs = rhs.col();
528 
530  xgetrs( "Non-transpose", rhs.row(), nrhs,
531  Z.data(), Z.row(), ipiv.data(),
532  rhs.data(), rhs.ld() );
533 
534  };
541  void Solve( View<T> &bl, View<T> &br )
542  {
543  size_t nrhs = bl.col();
544 
545  //bl.Print();
546  //br.Print();
547 
549  assert( !do_ulv_factorization );
550  assert( bl.col() == br.col() );
551  assert( bl.row() == nl );
552  assert( br.row() == nr );
553  assert( Ul && Ur && Vl && Vr );
554 
556 // hmlp::Data<T> ta( sl + sr, nrhs );
557 // hmlp::Data<T> tl( sl, nrhs );
558 // hmlp::Data<T> tr( sr, nrhs );
559 
560  vector<T> ta( ( sl + sr ) * nrhs );
561  vector<T> tl( sl * nrhs );
562  vector<T> tr( sr * nrhs );
563 
564 
566  //hmlp::View<T> xa( ta ), xl, xr;
567 
569  //xa.Partition2x1
570  //(
571  // xl,
572  // xr, sl
573  //);
574 
575 
577  xgemm( "T", "N", sl, nrhs, nl,
578  1.0, Vl->data(), nl,
579  bl.data(), bl.ld(),
580  0.0, tl.data(), sl );
582  xgemm( "T", "N", sr, nrhs, nr,
583  1.0, Vr->data(), nr,
584  br.data(), br.ld(),
585  0.0, tr.data(), sr );
586 
587 
589  xgemm( "N", "N", sr, nrhs, sl,
590  1.0, Crl.data(), sr,
591  tl.data(), sl,
592  0.0, ta.data() + sl, sl + sr );
593 
594  if ( issymmetric )
595  {
597  xgemm( "T", "N", sl, nrhs, sr,
598  1.0, Crl.data(), sr,
599  tr.data(), sr,
600  0.0, ta.data(), sl + sr );
601  }
602  else
603  {
604  printf( "bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
606  xgemm( "N", "N", sl, nrhs, sr,
607  1.0, Clr.data(), sl,
608  tr.data(), sr,
609  0.0, ta.data(), sl + sr );
610  }
611 
613  xgetrs( "N", sl + sr, nrhs,
614  Z.data(), Z.row(), ipiv.data(),
615  ta.data(), sl + sr );
616 
618  xgemm( "N", "N", nl, nrhs, sl,
619  -1.0, Ul->data(), nl,
620  ta.data(), sl + sr,
621  1.0, bl.data(), bl.ld() );
622 
624  xgemm( "N", "N", nr, nrhs, sr,
625  -1.0, Ur->data(), nr,
626  ta.data() + sl, sl + sr,
627  1.0, br.data(), br.ld() );
628 
629  };
634  void Telescope
635  (
636  bool DO_INVERSE,
638  Data<T> &Pa,
640  Data<T> &Palr
641  )
642  {
643  assert( isleaf );
645  Pa.resize( n, s, 0.0 );
646 
648  //hmlp::View<T> Xa;
649 
650  //Xa.Set( Pa );
651 
652  assert( Palr.row() == s ); assert( Palr.col() == n );
653 
655  for ( size_t j = 0; j < Pa.col(); j ++ )
656  for ( size_t i = 0; i < Pa.row(); i ++ )
657  Pa( i, j ) = Palr( j, i );
658 
659  if ( DO_INVERSE )
660  {
661  if ( do_ulv_factorization )
662  {
663  xtrsm( "Left", "Lower", "No transpose", "Non-unit",
664  Pa.row(), Pa.col(),
665  1.0, Z.data(), Z.row(), Pa.data(), Pa.row() );
666  }
667  else
668  {
669  assert( ipiv.size() );
671  xgetrs( "Non-transpose",
672  n, s, Z.data(), n, ipiv.data(), Pa.data(), n );
673  }
674  }
675 
676 
677  //printf( "call solve from telescope\n" ); fflush( stdout );
678  //if ( DO_INVERSE ) Solve<true>( Xa );
679  //printf( "call solve from telescope (exist)\n" ); fflush( stdout );
680 
681  };
686  void Telescope
687  (
688  bool DO_INVERSE,
690  Data<T> &Pa,
692  Data<T> &Palr,
694  Data<T> &Pl,
696  Data<T> &Pr
697  )
698  {
699  assert( !isleaf );
700  assert( n == nl + nr );
701  assert( Pl.col() == sl );
702  assert( Pr.col() == sr );
703  assert( Palr.row() == s ); assert( Palr.col() == ( sl + sr ) );
704 
706  Pa.resize( 0, 0 );
707 
709  //hmlp::View<T> Xa;
710 
711  //Xa.Set( Pa );
712  //assert( Xa.row() == Pa.row() );
713  //assert( Xa.col() == Pa.col() );
714 
715  if ( do_ulv_factorization )
716  {
717  Pa.resize( sl + sr, s, 0.0 );
718 
720  for ( size_t j = 0; j < Pa.col(); j ++ )
721  for ( size_t i = 0; i < Pa.row(); i ++ )
722  Pa[ j * Pa.row() + i ] = Palr[ i * Palr.row() + j ];
723 
725  xtrmm( "Left", "Upper", "No Transpose", "Non-unit", sl, s,
726  1.0, Pl.data(), Pl.row(),
727  Pa.data(), Pa.row() );
728  // printf( "Pl.row() %lu Pa.row() %lu Pa.col() %lu\n",
729  // Pl.row(), Pa.row(), Pa.col() );
730 
732  xtrmm( "Left", "Upper", "No Transpose", "Non-unit", sr, s,
733  1.0, Pr.data() , Pr.row(),
734  Pa.data() + sl, Pa.row() );
735  // printf( "Pr.row() %lu Pa.row() %lu Pa.col() %lu\n",
736  // Pr.row(), Pa.row(), Pa.col() );
737 
739  if ( DO_INVERSE )
740  {
741  if ( 1 )
742  {
743  xtrsm( "Left", "Lower", "No transpose", "Non-unit",
744  Pa.row(), Pa.col(),
745  1.0, Z.data(), Z.row(),
746  Pa.data(), Pa.row() );
747  }
748  else
749  {
750  xlaswp( Pa.col(), Pa.data(), Pa.row(),
751  1, Pa.row(), ipiv.data(), 1 );
752  xtrsm( "Left", "Lower", "No transpose", "Unit", Pa.row(), Pa.col(),
753  1.0, Z.data(), Z.row(),
754  Pa.data(), Pa.row() );
755  }
756  //printf( "Z.row() %lu Z.col() %lu\n", Z.row(), Z.col() );
757  }
758 
759  }
760  else
761  {
762  Pa.resize( nl + nr, s, 0.0 );
763 
765  //hmlp::View<T> Xl, Xr;
766 
768  //Xa.Partition2x1
769  //(
770  // Xl,
771  // Xr, nl
772  //);
773 
774  //assert( Xl.row() == nl );
775  //assert( Xr.row() == nr );
776  //assert( Xl.col() == s );
777  //assert( Xr.col() == s );
778 
779 
781  xgemm( "N", "T", nl, s, sl,
782  1.0, Pl.data(), nl,
783  Palr.data(), s,
784  0.0, Pa.data(), n );
786  xgemm( "N", "T", nr, s, sr,
787  1.0, Pr.data(), nr,
788  Palr.data() + s * sl, s,
789  0.0, Pa.data() + nl, n );
790 
791 
792 
793  //if ( DO_INVERSE ) Solve<true>( Xl, Xr );
794  //printf( "end inner solve from telescope\n" ); fflush( stdout );
795 
796  if ( DO_INVERSE )
797  {
798  Data<T> x( sl + sr, s );
799  Data<T> xl( sl, s );
800  Data<T> xr( sr, s );
801 
803  xgemm( "T", "N", sl, s, nl,
804  1.0, Vl->data(), nl,
805  Pa.data(), n,
806  0.0, xl.data(), sl );
808  xgemm( "T", "N", sr, s, nr,
809  1.0, Vr->data(), nr,
810  Pa.data() + nl, n,
811  0.0, xr.data(), sr );
812 
815  xgemm( "T", "N", sl, s, sr,
816  1.0, Crl.data(), sr,
817  xr.data(), sr,
818  0.0, x.data(), sl + sr );
819  xgemm( "N", "N", sr, s, sl,
820  1.0, Crl.data(), sr,
821  xl.data(), sl,
822  0.0, x.data() + sl, sl + sr );
823 
825  xgetrs( "N", x.row(), x.col(),
826  Z.data(), Z.row(), ipiv.data(),
827  x.data(), x.row() );
828 
830  xgemm( "N", "N", nl, s, sl,
831  -1.0, Ul->data(), nl,
832  x.data(), sl + sr,
833  1.0, Pa.data(), n );
835  xgemm( "N", "N", nr, s, sr,
836  -1.0, Ur->data(), nr,
837  x.data() + sl, sl + sr,
838  1.0, Pa.data() + nl, n );
839  }
840  }
841  };
842 
845  {
847  tau.resize( std::min( U.row(), U.col() ) );
849  Data<T> work( U.col() * 512, 1 );
851  xgeqrf( U.row(), U.col(), U.data(), U.row(),
852  tau.data(), work.data(), work.size() );
854  Q = U;
856  Q.resize( U.row(), U.row() );
858  xorgqr( Q.row(), Q.col(), U.col(), Q.data(), Q.row(), tau.data(),
859  work.data(), work.size() );
860 
861 
862 
864  Qv.Set( false, Q );
865  Qv.Partition1x2( Q1, Q2, tau.size(), LEFT );
867  Data<T> C = Q;
868  Data<T> D = Q;
869 
870  xgemm( "Transpose", "No Transpose", C.row(), C.col(), Q.row(),
871  1.0, Q.data(), Q.row(),
872  Q.data(), Q.row(),
873  0.0, C.data(), C.row() );
874 
875  xgemm( "No Transpose", "Transpose", D.row(), D.col(), Q.row(),
876  1.0, Q.data(), Q.row(),
877  Q.data(), Q.row(),
878  0.0, D.data(), D.row() );
879 
880  for ( size_t j = 0; j < Q.col(); j ++ )
881  {
882  for ( size_t i = 0; i < Q.row(); i ++ )
883  {
884  if ( i == j ) assert( std::fabs( C( i, j ) - 1 ) < 1E-5 );
885  else assert( std::fabs( C( i, j ) - 0 ) < 1E-5 );
886  }
887  }
888  for ( size_t j = 0; j < Q.col(); j ++ )
889  {
890  for ( size_t i = 0; i < Q.row(); i ++ )
891  {
892  if ( i == j ) assert( std::fabs( D( i, j ) - 1 ) < 1E-5 );
893  else assert( std::fabs( D( i, j ) - 0 ) < 1E-5 );
894  }
895  }
896 
897  };
898 
899 
900 
901 
903  void ChangeBasis( SideType side, Data<T> &B )
904  {
906  if ( !Q.size() ) return;
907 
909  Data<T> A = B;
910 
912  View<T> Av( false, A );
913  View<T> Bv( false, B );
914  View<T> Bl, Br, Bt, Bb;
915 
917  switch ( side )
918  {
919  case LEFT:
920  {
922  Bv.Partition2x1( Bt,
923  Bb, Q2.col(), TOP );
924  //printf("Bt %lux%lu Bb %lux%lu\n", Bt.row(), Bt.col(),
925  // Bb.row(), Bb.col() ); fflush( stdout );
926 
928  xgemm( "Transpose", "No Transpose", Bt.row(), Bt.col(), Q2.row(),
929  1.0, Q2.data(), Q2.ld(),
930  Av.data(), Av.ld(),
931  0.0, Bt.data(), Bt.ld() );
933  xgemm( "Transpose", "No Transpose", Bb.row(), Bb.col(), Q1.row(),
934  1.0, Q1.data(), Q1.ld(),
935  Av.data(), Av.ld(),
936  0.0, Bb.data(), Bb.ld() );
937  break;
938  }
939  case RIGHT:
940  {
942  Bv.Partition1x2( Bl, Br, Q2.col(), LEFT );
943 
944  //printf("Bl %lux%lu Br %lux%lu\n", Bl.row(), Bl.col(),
945  // Br.row(), Br.col() ); fflush( stdout );
946 
948  xgemm( "No Transpose", "No Transpose", Bl.row(), Bl.col(), Q2.row(),
949  1.0, Av.data(), Av.ld(),
950  Q2.data(), Q2.ld(),
951  0.0, Bl.data(), Bl.ld() );
953  xgemm( "No Transpose", "No Transpose", Br.row(), Br.col(), Q1.row(),
954  1.0, Av.data(), Av.ld(),
955  Q1.data(), Q1.ld(),
956  0.0, Br.data(), Br.ld() );
957  break;
958  }
959  default:
960  {
962  throw "Value of (SideType) side is not recognized.";
963  }
964  }
965  //printf( "end ChangeBasis\n" ); fflush( stdout );
966  };
970  void ChangeBasis( Data<T> &A )
971  {
972  ChangeBasis( LEFT, A );
973  ChangeBasis( RIGHT, A );
974  };
978  void ULVForward()
979  {
981  if ( isleaf ) B = bview.toData();
983  ChangeBasis( LEFT, B );
985  xlaswp( Bf.col(), Bf.data(), Bf.ld(), 1, Bf.row(), ipiv.data(), 1 );
987  xtrsm( "Left", "Lower", "No transpose", "Unit", Bf.row(), Bf.col(),
988  1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() );
990  xgemm( "No Transpose", "No Transpose", Bc.row(), Bc.col(), Bf.row(),
991  -1.0, Zbl.data(), Zbl.ld(), Bf.data(), Bf.ld(), 1.0, Bc.data(), Bc.ld() );
992  //printf( "Bc %lux%lu Bp %lux%lu\n", Bc.row(), Bc.col(), Bp.row(), Bp.col() ); fflush( stdout );
994  Bp.CopyValuesFrom( Bc );
995  };
998  void ULVBackward()
999  {
1001  Bc.CopyValuesFrom( Bp );
1003  xgemm( "No Transpose", "No Transpose", Bf.row(), Bf.col(), Bc.row(),
1004  -1.0, Ztr.data(), Ztr.ld(), Bc.data(), Bc.ld(), 1.0, Bf.data(), Bf.ld() );
1006  xtrsm( "Left", "Upper", "No transpose", "Non-unit", Bf.row(), Bf.col(),
1007  1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() );
1008  if ( Q.size() )
1009  {
1011  Data<T> A = B;
1012  xgemm( "No Transpose", "No Transpose", A.row(), A.col(), Bf.row(),
1013  1.0, Q2.data(), Q2.ld(), Bf.data(), Bf.ld(), 0.0, A.data(), A.row() );
1014  xgemm( "No Transpose", "No Transpose", A.row(), A.col(), Bc.row(),
1015  1.0, Q1.data(), Q1.ld(), Bc.data(), Bc.ld(), 1.0, A.data(), A.row() );
1017  if ( isleaf ) bview.CopyValuesFrom( A );
1018  else Bv.CopyValuesFrom( A );
1019  }
1020  };
1027  bool isleaf = false;
1028 
1029  bool isroot = false;
1030 
1031  size_t n = 0;
1032 
1033  size_t nl = 0;
1034 
1035  size_t nr = 0;
1036 
1037  size_t s = 0;
1038 
1039  size_t sl = 0;
1040 
1041  size_t sr = 0;
1042 
1043 
1047  View<T> Zv;
1048  View<T> Ztl, Ztr, Zbl, Zbr;
1049 
1051  vector<int> ipiv;
1052 
1055 
1057  Data<T> Crl, Clr;
1058 
1061 
1063  Data<T> *Ul = NULL;
1064  Data<T> *Ur = NULL;
1065  Data<T> *Vl = NULL;
1066  Data<T> *Vr = NULL;
1067 
1070  View<T> Qv, Q1, Q2;
1071 
1073  vector<T> tau;
1074 
1077  View<T> Bv, Bp, Bsibling, Bf, Bc;
1078 
1079  private:
1081  bool issymmetric = true;
1082 
1083  bool do_ulv_factorization = false;
1084 
1085 };
template<typename NODE, typename T>
void SetupFactor( NODE *node )
{
#ifdef DEBUG_IGOFMM
  printf( "begin SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
#endif

  /**
   *  Gather this node's sizes (and its children's, for internal nodes) and
   *  forward them to node->data.SetupFactor(). Children sizes are 0 on a
   *  leaf; isroot is derived from node->l == 0.
   */
  const bool issymmetric = node->setup->IsSymmetric();
  const bool do_ulv_factorization = node->setup->do_ulv_factorization;
  const bool isleaf = node->isleaf;

  const size_t n = node->n;
  const size_t s = node->data.skels.size();

  const size_t nl = isleaf ? 0 : node->lchild->n;
  const size_t nr = isleaf ? 0 : node->rchild->n;
  const size_t sl = isleaf ? 0 : node->lchild->data.skels.size();
  const size_t sr = isleaf ? 0 : node->rchild->data.skels.size();

  node->data.SetupFactor( issymmetric, do_ulv_factorization,
      isleaf, !node->l, n, nl, nr, s, sl, sr );

#ifdef DEBUG_IGOFMM
  printf( "end SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
#endif
};
1133 template<typename NODE, typename T>
1134 class SetupFactorTask : public Task
1135 {
1136  public:
1137 
1138  NODE *arg = NULL;
1139 
1140  void Set( NODE *user_arg )
1141  {
1142  arg = user_arg;
1143  name = string( "sf" );
1144  label = to_string( arg->treelist_id );
1145  cost = 1.0;
1146  };
1147 
1149  {
1150  double flops = 0.0, mops = 0.0;
1151  event.Set( label + name, flops, mops );
1152  };
1153 
1154  void DependencyAnalysis()
1155  {
1156  arg->DependencyAnalysis( W, this );
1157  this->TryEnqueue();
1158  };
1159 
1160  void Execute( Worker* user_worker )
1161  {
1162  SetupFactor<NODE, T>( arg );
1163  };
1164 
1165 };
1170 template<typename NODE>
1171 void SolverTreeView( NODE *node )
1172 {
1173  auto &data = node->data;
1174  auto *setup = node->setup;
1175  auto &input = *(setup->input);
1176  auto &output = *(setup->output);
1178  if ( node->isleaf ) data.B.resize( data.n, input.col() );
1179  else data.B.resize( data.sl + data.sr, input.col() );
1180 
1182  data.Bv.Set( data.B );
1183  data.Bv.Partition2x1( data.Bf,
1184  data.Bc, data.s, BOTTOM );
1185 
1187  if ( !node->parent ) data.bview.Set( output );
1188 
1190  if ( !node->isleaf )
1191  {
1192  auto &ldata = node->lchild->data;
1193  auto &rdata = node->rchild->data;
1195  data.bview.Partition2x1( ldata.bview,
1196  rdata.bview, data.nl, TOP );
1197  data.Bv.Partition2x1( ldata.Bp,
1198  rdata.Bp, data.sl, TOP );
1199  }
1200 };
1206 template<typename NODE>
1207 class SolverTreeViewTask : public Task
1208 {
1209  public:
1210 
1211  NODE *arg = NULL;
1212 
1213  void Set( NODE *user_arg )
1214  {
1215  arg = user_arg;
1216  name = string( "TreeView" );
1217  label = to_string( arg->treelist_id );
1218  cost = 1.0;
1219  };
1220 
1222  {
1223  double flops = 0.0, mops = 0.0;
1224  event.Set( label + name, flops, mops );
1225  };
1226 
1228  void DependencyAnalysis() { arg->DependOnParent( this ); };
1229 
1230  void Execute( Worker* user_worker ) { SolverTreeView( arg ); };
1231 
1232 };
1240 template<bool FORWARD, typename NODE>
1242 {
1243  public:
1244 
1245  NODE *arg;
1246 
1247  void Set( NODE *user_arg )
1248  {
1249  name = std::string( "MatrixPermutation" );
1250  arg = user_arg;
1251  cost = 1.0;
1252  };
1253 
1255  {
1256  double flops = 0.0, mops = 0.0;
1257  event.Set( label + name, flops, mops );
1258  };
1259 
1262  {
1263  if ( FORWARD )
1264  {
1265  arg->DependencyAnalysis( RW, this );
1266  }
1267  else
1268  {
1269  this->Enqueue();
1270  }
1271  };
1272 
1273  void Execute( Worker* user_worker )
1274  {
1275  //printf( "PermuteMatrix %lu\n", arg->treelist_id );
1276  auto *node = arg;
1277  auto &gids = node->gids;
1278  auto &input = *(node->setup->input);
1279  auto &output = *(node->setup->output);
1280  auto &A = node->data.bview;
1281 
1282  assert( A.row() == gids.size() );
1283  assert( A.col() == input.col() );
1284 
1285  //for ( size_t i = 0; i < gids.size(); i ++ )
1286  // printf( "%lu ", gids[ i ] );
1287  //printf( "\n" );
1288 
1290  for ( size_t j = 0; j < input.col(); j ++ )
1291  for ( size_t i = 0; i < gids.size(); i ++ )
1293  if ( FORWARD ) A( i, j ) = input( gids[ i ], j );
1295  else input( gids[ i ], j ) = A( i, j );
1296 
1297  //for ( size_t j = 0; j < 1; j ++ )
1298  // for ( size_t i = 0; i < gids.size(); i ++ )
1299  // printf( "%E ", A( i, j ) );
1300  //printf( "\n" );
1301 
1302  //printf( "end PermuteMatrix %lu\n", arg->treelist_id );
1303  };
1304 
1305 };
/**
 *  @brief Apply the factorized operator at one node.
 *
 *  Internal node: calls data.Apply<true>( bl, br ) on the children's views.
 *  The `template` disambiguator is required there because the type of
 *  `data` depends on NODE; without it, `Apply<true>(...)` parses as
 *  comparisons.
 *
 *  Leaf: assembles Kaa = K( gids, gids ) + lambda * I.
 */
template<typename NODE, typename T>
void Apply( NODE *node )
{
  auto &data = node->data;
  auto &setup = node->setup;
  auto &K = *setup->K;

  if ( node->isleaf )
  {
    auto lambda = setup->lambda;
    auto &amap = node->gids;
    auto Kaa = K( amap, amap );
    for ( size_t i = 0; i < Kaa.row(); i ++ )
      Kaa[ i * Kaa.row() + i ] += lambda;
    /** NOTE(review): Kaa is assembled but never applied to data.bview; this leaf branch has no effect. */
  }
  else
  {
    auto &bl = node->lchild->data.bview;
    auto &br = node->rchild->data.bview;
    data.template Apply<true>( bl, br );
  }
};
1339 //template<typename NODE, typename T>
1340 //void ULVForwardSolve( NODE *node ) { node->data.ULVForward(); };
1341 
1342 
1343 
1344 template<typename NODE, typename T>
1346 {
1347  public:
1348 
1349  NODE *arg = NULL;
1350 
1351  void Set( NODE *user_arg )
1352  {
1353  arg = user_arg;
1354  name = string( "ulvforward" );
1355  label = to_string( arg->treelist_id );
1356  cost = 1.0;
1357  };
1358 
1359  //void DependencyAnalysis()
1360  //{
1361  // arg->DependencyAnalysis( RW, this );
1362  // /** depend on two children */
1363  // if ( !arg->isleaf )
1364  // {
1365  // arg->lchild->DependencyAnalysis( R, this );
1366  // arg->rchild->DependencyAnalysis( R, this );
1367  // }
1368  // /** dispatch the task if there is no dependency */
1369  // this->TryEnqueue();
1370  //};
1371 
1372  void DependencyAnalysis() { arg->DependOnChildren( this ); };
1373 
1374 
1375  void Execute( Worker* user_worker ) { arg->data.ULVForward(); };
1376 
1377 };
1382 template<typename NODE, typename T>
1384 {
1385  public:
1386 
1387  NODE *arg;
1388 
1389  void Set( NODE *user_arg )
1390  {
1391  arg = user_arg;
1392  name = string( "ulvbackward" );
1393  label = std::to_string( arg->treelist_id );
1394  cost = 1.0;
1395 
1396  //printf( "Set treelist_id %lu\n", arg->treelist_id ); fflush( stdout );
1397  };
1398 
1399  //void DependencyAnalysis()
1400  //{
1401  // /** depend on parent */
1402  // if ( arg->parent )
1403  // arg->parent->DependencyAnalysis( hmlp::ReadWriteType::R, this );
1404  // arg->DependencyAnalysis( hmlp::ReadWriteType::RW, this );
1405  // /** dispatch the task if there is no dependency */
1406  // this->TryEnqueue();
1407  //};
1408 
1409 
1410  void DependencyAnalysis() { arg->DependOnParent( this ); };
1411 
1412  void Execute( Worker* user_worker ) { arg->data.ULVBackward(); };
1413 
1414 };
/**
 *  @brief Per-node step of the non-ULV (SMW/LU) solve, in place.
 *
 *  Leaf: data.Solve( bview ), the local LU solve.
 *  Internal node: data.Solve( lchild bview, rchild bview ), the SMW
 *  correction over both children's right-hand sides.
 *
 *  The unused setup / *setup->K locals were removed; they performed an
 *  unnecessary dereference of setup->K.
 */
template<typename NODE, typename T>
void Solve( NODE *node )
{
  auto &data = node->data;

  if ( node->isleaf )
  {
    auto &b = data.bview;
    data.Solve( b );
  }
  else
  {
    auto &bl = node->lchild->data.bview;
    auto &br = node->rchild->data.bview;
    data.Solve( bl, br );
  }
};
1470 template<typename NODE, typename T>
1471 class SolveTask : public Task
1472 {
1473  public:
1474 
1475  NODE *arg = NULL;
1476 
1477  void Set( NODE *user_arg )
1478  {
1479  arg = user_arg;
1480  name = string( "sl" );
1481  label = to_string( arg->treelist_id );
1482  cost = 1.0;
1483 
1484  //printf( "Set treelist_id %lu\n", arg->treelist_id ); fflush( stdout );
1485  };
1486 
1488  {
1489  double flops = 0.0, mops = 0.0;
1490  event.Set( label + name, flops, mops );
1491  };
1492 
1493  void DependencyAnalysis()
1494  {
1495  arg->DependencyAnalysis( RW, this );
1496  if ( !arg->isleaf )
1497  {
1498  arg->lchild->DependencyAnalysis( R, this );
1499  arg->rchild->DependencyAnalysis( R, this );
1500  }
1501  };
1502 
1503  void Execute( Worker* user_worker )
1504  {
1505  Solve<NODE, T>( arg );
1506  };
1507 
1508 };
1514 template<typename T, typename TREE>
1515 void Solve( TREE &tree, Data<T> &input )
1516 {
1517  using NODE = typename TREE::NODE;
1518 
1519  const bool AUTO_DEPENDENCY = true;
1520  const bool USE_RUNTIME = true;
1521 
1523  auto *output = new Data<T>( input.row(), input.col() );
1524 
1525  SolverTreeViewTask<NODE> treeviewtask;
1526  MatrixPermuteTask<true, NODE> forwardpermutetask;
1527  MatrixPermuteTask<false, NODE> inversepermutetask;
1529  SolveTask<NODE, T> solvetask1;
1531  ULVForwardSolveTask<NODE, T> ulvforwardsolvetask;
1532  ULVBackwardSolveTask<NODE, T> ulvbackwardsolvetask;
1533 
1535  tree.setup.input = &input;
1536  tree.setup.output = output;
1537 
1538  if ( tree.setup.do_ulv_factorization )
1539  {
1541  tree.DependencyCleanUp();
1542  tree.TraverseDown( treeviewtask );
1543  tree.TraverseLeafs( forwardpermutetask );
1544  tree.TraverseUp( ulvforwardsolvetask );
1545  tree.TraverseDown( ulvbackwardsolvetask );
1546  if ( USE_RUNTIME ) hmlp_run();
1547 
1549  tree.DependencyCleanUp();
1550  tree.TraverseLeafs( inversepermutetask );
1551  if ( USE_RUNTIME ) hmlp_run();
1552  }
1553  else
1554  {
1556  tree.DependencyCleanUp();
1557  tree.TraverseDown( treeviewtask );
1558  tree.TraverseLeafs( forwardpermutetask );
1559  tree.TraverseUp( solvetask1 );
1560  if ( USE_RUNTIME ) hmlp_run();
1562  tree.DependencyCleanUp();
1563  tree.TraverseLeafs( inversepermutetask );
1564  if ( USE_RUNTIME ) hmlp_run();
1565  }
1566 
1568  delete output;
1569 
1570 };
1580 template<typename NODE, typename T>
1581 void LowRankError( NODE *node )
1582 {
1583  auto &data = node->data;
1584  auto &setup = node->setup;
1585  auto &K = *setup->K;
1586 
1587  if ( !node->isleaf )
1588  {
1589  auto Krl = K( node->rchild->gids, node->lchild->gids );
1590 
1591  auto nrm2 = hmlp_norm( Krl.row(), Krl.col(),
1592  Krl.data(), Krl.row() );
1593 
1594 
1595  hmlp::Data<T> VrCrl( data.nr, data.sl );
1596 
1598  xgemm( "N", "N", data.nr, data.sl, data.sr,
1599  1.0, data.Vr->data(), data.nr,
1600  data.Crl.data(), data.sr,
1601  0.0, VrCrl.data(), data.nr );
1602 
1604  xgemm( "N", "T", data.nr, data.nl, data.sl,
1605  -1.0, VrCrl.data(), data.nr,
1606  data.Vl->data(), data.nl,
1607  1.0, Krl.data(), data.nr );
1608 
1609  auto err = hmlp_norm( Krl.row(), Krl.col(),
1610  Krl.data(), Krl.row() );
1611 
1612  printf( "%4lu ||Krl -VrCrlVl|| %3.1E\n",
1613  node->treelist_id, std::sqrt( err / nrm2 ) );
1614  }
1615 
1616 };
1623 template<typename NODE, typename T>
1624 void Factorize( NODE *node )
1625 {
1626  auto &data = node->data;
1627  auto &setup = node->setup;
1628  auto &K = *setup->K;
1629  auto &proj = data.proj;
1630 
1631  auto do_ulv_factorization = setup->do_ulv_factorization;
1632 
1633  if ( node->isleaf )
1634  {
1635  auto lambda = setup->lambda;
1636  auto &amap = node->gids;
1637 
1639  Data<T> Kaa = K( amap, amap );
1640 
1642  for ( size_t i = 0; i < Kaa.row(); i ++ ) Kaa( i, i ) += lambda;
1643 
1644  if ( do_ulv_factorization )
1645  {
1647  data.Telescope( false, data.U, proj );
1649  data.Orthogonalization();
1651  data.PartialFactorize( Kaa );
1652  }
1653  else
1654  {
1656  data.Factorize( Kaa );
1658  data.Telescope( true, data.U, proj );
1660  data.Telescope( false, data.V, proj );
1661  }
1662  }
1663  else
1664  {
1665  auto &Ul = node->lchild->data.U;
1666  auto &Vl = node->lchild->data.V;
1667  auto &Zl = node->lchild->data.Zbr;
1668  auto &Ur = node->rchild->data.U;
1669  auto &Vr = node->rchild->data.V;
1670  auto &Zr = node->rchild->data.Zbr;
1671 
1673  auto &amap = node->lchild->data.skels;
1674  auto &bmap = node->rchild->data.skels;
1675 
1677  node->data.Crl = K( bmap, amap );
1678 
1679  if ( do_ulv_factorization )
1680  {
1681  if ( !node->data.isroot )
1682  {
1683  data.Telescope( false, data.U, proj, Ul, Ur );
1684  data.Orthogonalization();
1685  }
1686  data.PartialFactorize( Zl, Zr, Ul, Ur, Vl, Vr );
1687  }
1688  else
1689  {
1691  data.Factorize( Ul, Ur, Vl, Vr );
1693  if ( !node->data.isroot )
1694  {
1696  data.Telescope( true, data.U, proj, Ul, Ur );
1698  data.Telescope( false, data.V, proj, Vl, Vr );
1699  }
1700  }
1701  }
1702 
1703 
1704 
1705 
1706 
1707 
1708 
1709 
1710 // /** SMW factorization (LU or Cholesky) */
1711 // data.Factorize<true>( Ul, Ur, Vl, Vr );
1712 //
1713 // /** telescope U and V */
1714 // if ( !node->data.isroot )
1715 // {
1716 // if ( do_ulv_factorization )
1717 // {
1718 // data.Telescope( true, data.U, proj, Ul, Ur );
1719 // data.Orthogonalization();
1720 // }
1721 // else
1722 // {
1723 // /** U = inv( I + UCV' ) * [ Ul; Ur ] * proj' */
1724 // data.Telescope( true, data.U, proj, Ul, Ur );
1725 // /** V = [ Vl; Vr ] * proj' */
1726 // data.Telescope( false, data.V, proj, Vl, Vr );
1727 // }
1728 // }
1729 // else
1730 // {
1731 // /** output Crl from children */
1732 //
1733 // //size_t L = 3;
1734 //
1735 // auto *cl = node->lchild;
1736 // auto *cr = node->rchild;
1737 // auto *c1 = cl->lchild;
1738 // auto *c2 = cl->rchild;
1739 // auto *c3 = cr->lchild;
1740 // auto *c4 = cr->rchild;
1741 //
1742 // //hmlp::Data<T> C21 = K( c2->data.skels, c1->data.skels );
1743 // //hmlp::Data<T> C31 = K( c3->data.skels, c1->data.skels );
1744 // //hmlp::Data<T> C41 = K( c4->data.skels, c1->data.skels );
1745 // //hmlp::Data<T> C32 = K( c3->data.skels, c2->data.skels );
1746 // //hmlp::Data<T> C42 = K( c4->data.skels, c2->data.skels );
1747 // //hmlp::Data<T> C43 = K( c4->data.skels, c3->data.skels );
1748 //
1749 // //C21.WriteFile( "C21.m" );
1750 // //C31.WriteFile( "C31.m" );
1751 // //C41.WriteFile( "C41.m" );
1752 // //C32.WriteFile( "C32.m" );
1753 // //C42.WriteFile( "C42.m" );
1754 // //C43.WriteFile( "C43.m" );
1755 //
1756 //
1757 // //hmlp::Data<T> V11( c1->data.V.col(), c1->data.V.col() );
1758 // //hmlp::Data<T> V22( c2->data.V.col(), c2->data.V.col() );
1759 // //hmlp::Data<T> V33( c3->data.V.col(), c3->data.V.col() );
1760 // //hmlp::Data<T> V44( c4->data.V.col(), c4->data.V.col() );
1761 //
1762 // //xgemm( "T", "N", c1->data.V.col(), c1->data.V.col(), c1->data.V.row(),
1763 // // 1.0, c1->data.V.data(), c1->data.V.row(),
1764 // // c1->data.V.data(), c1->data.V.row(),
1765 // // 0.0, V11.data(), V11.row() );
1766 //
1767 // //xgemm( "T", "N", c2->data.V.col(), c2->data.V.col(), c2->data.V.row(),
1768 // // 1.0, c2->data.V.data(), c2->data.V.row(),
1769 // // c2->data.V.data(), c2->data.V.row(),
1770 // // 0.0, V22.data(), V22.row() );
1771 //
1772 // //xgemm( "T", "N", c3->data.V.col(), c3->data.V.col(), c3->data.V.row(),
1773 // // 1.0, c3->data.V.data(), c3->data.V.row(),
1774 // // c3->data.V.data(), c3->data.V.row(),
1775 // // 0.0, V33.data(), V33.row() );
1776 //
1777 // //xgemm( "T", "N", c4->data.V.col(), c4->data.V.col(), c4->data.V.row(),
1778 // // 1.0, c4->data.V.data(), c4->data.V.row(),
1779 // // c4->data.V.data(), c4->data.V.row(),
1780 // // 0.0, V44.data(), V44.row() );
1781 //
1782 // //V11.WriteFile( "V11.m" );
1783 // //V22.WriteFile( "V22.m" );
1784 // //V33.WriteFile( "V33.m" );
1785 // //V44.WriteFile( "V44.m" );
1786 // }
1787 // //printf( "end inner forward telescoping\n" ); fflush( stdout );
1788 //
1789 // /** check the offdiagonal block VrCrlVl' accuracy */
1790 // if ( !do_ulv_factorization )
1791 // LowRankError<NODE, T>( node );
1792 // }
1793 
1794 };
1801 template<typename NODE, typename T>
1802 class FactorizeTask : public Task
1803 {
1804  public:
1805 
1806  NODE *arg = NULL;
1807 
1808  void Set( NODE *user_arg )
1809  {
1810  arg = user_arg;
1811  name = string( "fa" );
1812  label = to_string( arg->treelist_id );
1813  // Need an accurate cost model.
1814  cost = 1.0;
1815  };
1816 
1818  {
1819  double flops = 0.0, mops = 0.0;
1820  event.Set( label + name, flops, mops );
1821  };
1822 
1823  void DependencyAnalysis() { arg->DependOnChildren( this ); };
1824 
1825  void Execute( Worker* user_worker ) { Factorize<NODE, T>( arg ); };
1826 
1827 };
1842 template<typename T, typename TREE>
1843 void Factorize( TREE &tree, T lambda )
1844 {
1845  using NODE = typename TREE::NODE;
1846 
1848  tree.DependencyCleanUp();
1849 
1851  tree.setup.lambda = lambda;
1852 
1854  tree.setup.do_ulv_factorization = true;
1855 
1857  SetupFactorTask<NODE, T> setupfactortask;
1858  tree.TraverseUp( setupfactortask );
1859  tree.ExecuteAllTasks();
1860 
1862  FactorizeTask<NODE, T> factorizetask;
1863  tree.TraverseUp( factorizetask );
1864  tree.ExecuteAllTasks();
1865 
1866 };
1874 template<typename TREE, typename T>
1875 void ComputeError( TREE &tree, T lambda, Data<T> weights, Data<T> potentials )
1876 {
1877  using NODE = typename TREE::NODE;
1878 
1879 
1881  assert( weights.row() == potentials.row() );
1882  assert( weights.col() == potentials.col() );
1883 
1884  size_t n = weights.row();
1885  size_t nrhs = weights.col();
1886 
1888  Data<T> rhs( n, nrhs );
1889  for ( size_t j = 0; j < nrhs; j ++ )
1890  for ( size_t i = 0; i < n; i ++ )
1891  rhs( i, j ) = potentials( i, j ) + lambda * weights( i, j );
1892 
1894  Solve( tree, rhs );
1895 
1896 
1898  printf( "========================================================\n" );
1899  printf( "Inverse accuracy report\n" );
1900  printf( "========================================================\n" );
1901  printf( "#rhs, max err, @, min err, @, relative \n" );
1902  printf( "========================================================\n" );
1903  size_t ntest = 10;
1904  T total_err = 0.0;
1905  for ( size_t j = 0; j < std::min( nrhs, ntest ); j ++ )
1906  {
1908  T nrm2 = 0.0, err2 = 0.0;
1909  T max2 = 0.0, min2 = std::numeric_limits<T>::max();
1911  size_t maxi = 0, mini = 0;
1912 
1913  for ( size_t i = 0; i < n; i ++ )
1914  {
1915  T sse = rhs( i, j ) - weights( i, j );
1916  assert( rhs( i, j ) == rhs( i, j ) );
1917  sse = sse * sse;
1918 
1919  nrm2 += weights( i, j ) * weights( i, j );
1920  err2 += sse;
1921 
1922  //printf( "%lu %3.1E\n", i, sse );
1923 
1924 
1925  if ( sse > max2 ) { max2 = sse; maxi = i; }
1926  if ( sse < min2 ) { min2 = sse; mini = i; }
1927  }
1928  total_err += std::sqrt( err2 / nrm2 );
1929 
1930  printf( "%4lu, %3.1E, %7lu, %3.1E, %7lu, %3.1E\n",
1931  j, std::sqrt( max2 ), maxi, std::sqrt( min2 ), mini,
1932  std::sqrt( err2 / nrm2 ) );
1933  }
1934  printf( "========================================================\n" );
1935  printf( " avg over %2lu rhs, %3.1E \n",
1936  std::min( nrhs, ntest ), total_err / std::min( nrhs, ntest ) );
1937  printf( "========================================================\n\n" );
1938 
1939 };
1948 };
1949 };
1951 #endif
void Partition1x2(View< T > &A1, View< T > &A2, size_t nb, SideType side)
Definition: View.hpp:180
void ChangeBasis(SideType side, Data< T > &B)
Definition: igofmm.hpp:903
Downward traversal to create matrix views; at the leaf level, execute the explicit permutation.
Definition: igofmm.hpp:1241
Definition: igofmm.hpp:1471
vector< T > tau
Definition: igofmm.hpp:1073
void Factorize(Data< T > &Kaa)
Definition: igofmm.hpp:145
Definition: igofmm.hpp:1345
void Solve(View< T > &rhs)
Solver for leaf nodes.
Definition: igofmm.hpp:517
void ULVForward()
Definition: igofmm.hpp:978
Definition: igofmm.hpp:70
Definition: igofmm.hpp:1802
void Execute(Worker *user_worker)
Definition: igofmm.hpp:1273
void xgeqrf(int m, int n, double *A, int lda, double *tau, double *work, int lwork)
DGEQRF wrapper.
Definition: blas_lapack.cpp:694
void DependencyCleanUp()
Definition: runtime.cpp:436
size_t row()
Definition: View.hpp:345
void xgetrf(int m, int n, double *A, int lda, int *ipiv)
DGETRF wrapper.
Definition: blas_lapack.cpp:546
void xorgqr(int m, int n, int k, double *A, int lda, double *tau, double *work, int lwork)
DORGQR wrapper.
Definition: blas_lapack.cpp:757
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
Data< T > Crl
Definition: igofmm.hpp:1057
void xgetrs(const char *trans, int m, int nrhs, double *A, int lda, int *ipiv, double *B, int ldb)
DGETRS wrapper.
Definition: blas_lapack.cpp:577
Definition: igofmm.hpp:1134
Data< T > Z
Definition: igofmm.hpp:1046
size_t col()
Definition: View.hpp:348
SideType
Definition: View.hpp:37
T * data()
Definition: View.hpp:354
Data< T > Q
Definition: igofmm.hpp:1069
Data< T > B
Definition: igofmm.hpp:1076
void xlaswp(int n, double *A, int lda, int k1, int k2, int *ipiv, int incx)
DLASWP wrapper.
Definition: blas_lapack.cpp:450
void PartialFactorize(Data< T > &A)
Definition: igofmm.hpp:179
View< T > bview
Definition: igofmm.hpp:1060
void ChangeBasis(Data< T > &A)
Definition: igofmm.hpp:970
void Multiply(View< T > &bl, View< T > &br)
Definition: igofmm.hpp:456
vector< int > ipiv
Definition: igofmm.hpp:1051
size_t col() const noexcept
Definition: Data.hpp:281
void xgecon(const char *norm, int n, double *A, int lda, double anorm, double *rcond, double *work, int *iwork)
DGECON wrapper.
Definition: blas_lapack.cpp:631
void DependencyAnalysis()
Definition: igofmm.hpp:1228
void DependencyAnalysis()
Definition: igofmm.hpp:1261
void Solve(View< T > &bl, View< T > &br)
b - U * inv( Z ) * C * V' * b
Definition: igofmm.hpp:541
size_t row() const noexcept
Definition: Data.hpp:278
Definition: igofmm.hpp:1383
void GetEventRecord()
Definition: igofmm.hpp:1487
Definition: View.hpp:43
size_t ld()
Definition: View.hpp:351
void xpotrf(const char *uplo, int n, double *A, int lda)
DPOTRF wrapper.
Definition: blas_lapack.cpp:480
void Orthogonalization()
Definition: igofmm.hpp:844
void xtrsm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRSM wrapper.
Definition: blas_lapack.cpp:315
void PartialFactorize(View< T > &Zl, View< T > &Zr, Data< T > &Ul, Data< T > &Ur, Data< T > &Vl, Data< T > &Vr)
Definition: igofmm.hpp:413
void DependencyAnalysis()
Definition: igofmm.hpp:1410
void ULVBackward()
Definition: igofmm.hpp:998
Data< T > U
Definition: igofmm.hpp:1054
void xtrmm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRMM wrapper.
Definition: blas_lapack.cpp:385
void GetEventRecord()
Definition: igofmm.hpp:1148
Definition: gofmm.hpp:83
void GetEventRecord()
Definition: igofmm.hpp:1817
Definition: runtime.hpp:174
void Partition2x1(View< T > &A1, View< T > &A2, size_t mb, SideType side)
Definition: View.hpp:155
void DependencyAnalysis()
Definition: igofmm.hpp:1372
void GetEventRecord()
Definition: igofmm.hpp:1221
Creates a hierarchical tree view for a matrix.
Definition: igofmm.hpp:1207
void GetEventRecord()
Definition: igofmm.hpp:1254
Definition: thread.hpp:166