78 bool issymmetric,
bool do_ulv_factorization,
79 bool isleaf,
bool isroot,
81 size_t n,
size_t nl,
size_t nr,
83 size_t s,
size_t sl,
size_t sr
86 this->issymmetric = issymmetric;
87 this->do_ulv_factorization = do_ulv_factorization;
88 this->isleaf = isleaf;
89 this->isroot = isroot;
90 this->n = n; this->nl = nl; this->nr = nr;
91 this->s = s; this->sl = sl; this->sr = sr;
96 bool issymmetric,
bool do_ulv_factorization,
97 bool isleaf,
bool isroot,
98 size_t n,
size_t nl,
size_t nr,
99 size_t s,
size_t sl,
size_t sr,
106 SetupFactor( issymmetric, do_ulv_factorization,
107 isleaf, isroot, n, nl, nr, s, sl, sr );
110 bool DoULVFactorization()
112 return do_ulv_factorization;
115 bool IsSymmetric() {
return issymmetric; };
121 void CheckCondition()
123 assert( do_ulv_factorization && issymmetric );
127 for (
size_t i = 0; i < Z.row(); i ++ )
129 T abs_diag = std::abs( Z( i, i ) );
137 if ( abs_diag > max_diag ) max_diag = abs_diag;
138 if ( abs_diag < min_diag ) min_diag = abs_diag;
141 printf(
"condiditon( Z ): min_diag %3.1E max_diag %3.1E ratio %3.1E\n",
142 min_diag, max_diag, min_diag / max_diag );
148 assert( Kaa.
row() == n ); assert( Kaa.
col() == n );
158 for (
auto &z : Z ) nrm1 += z;
161 xgetrf( n, n, Z.data(), n, ipiv.data() );
166 vector<int> iwork( Z.row() );
167 xgecon(
"1", Z.row(), Z.data(), Z.row(), nrm1,
168 &rcond1, work.data(), iwork.data() );
169 if ( 1.0 / rcond1 > 1E+6 )
170 printf(
"Warning! large 1-norm condition number %3.1E, nrm1( Z ) %3.1E\n",
171 1.0 / rcond1, nrm1 );
187 Zv.Partition2x2( Ztl, Ztr,
188 Zbl, Zbr, s, s, BOTTOMRIGHT );
194 ipiv.resize( Ztl.row(), 0 );
196 xgetrf( Ztl.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
198 xtrsm(
"Right",
"Upper",
"No transpose",
"Non-unit", Zbl.row(), Zbl.col(),
199 1.0, Ztl.data(), Ztl.ld(), Zbl.data(), Zbl.ld() );
201 xgemm(
"No transpose",
"No transpose", Zbr.row(), Zbr.col(), Ztl.col(),
202 -1.0, Zbl.data(), Zbl.ld(),
203 Ztr.data(), Ztr.ld(),
204 1.0, Zbr.data(), Zbr.ld() );
252 assert( Crl.row() == sr ); assert( Crl.col() == sl );
256 assert( Clr.row() == sl ); assert( Clr.col() == sr );
257 assert( Crl.row() == sr ); assert( Crl.col() == sl );
267 Z.resize( sl + sr, sl + sr, 0.0 );
268 for (
size_t i = 0; i < sl + sr; i ++ ) Z[ i * Z.row() + i ] = 1.0;
272 if ( do_ulv_factorization )
290 "Right",
"Upper",
"Transpose",
"Non-unit",
292 1.0, Ul.data(), Ul.
row(),
293 Zbl.data(), Zbl.
row()
301 "Left",
"Upper",
"Non-transpose",
"Non-unit",
303 1.0, Ur.data(), Ur.
row(),
304 Zbl.data(), Zbl.
row()
310 for (
size_t j = 0; j < sl; j ++ )
311 for (
size_t i = 0; i < sr; i ++ )
313 Z( sl + i, j ) = Zbl( i, j );
314 Z( j, sl + i ) = Zbl( i, j );
320 xpotrf(
"Lower", Z.row(), Z.data(), Z.row() );
326 ipiv.resize( Z.row(), 0 );
327 xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
333 ipiv.resize( Z.row(), 0 );
339 ipiv.resize( Z.row(), 0 );
345 std::vector<T> VltUl( sl * sl, 0.0 );
346 std::vector<T> VrtUr( sr * sr, 0.0 );
349 xgemm(
"T",
"N", sl, sl, nl,
352 0.0, VltUl.data(), sl );
355 xgemm(
"T",
"N", sr, sr, nr,
358 0.0, VrtUr.data(), sr );
361 xgemm(
"N",
"N", sr, sl, sl,
364 0.0, Z.data() + sl, sl + sr );
370 xgemm(
"T",
"N", sl, sr, sr,
373 0.0, Z.data() + ( sl + sr ) * sl, sl + sr );
377 printf(
"bug\n" ); exit( 1 );
379 xgemm(
"N",
"N", sl, sr, sr,
382 0.0, Z.data() + ( sl + sr ) * sl, sl + sr );
387 for (
size_t i = 0; i < Z.size(); i ++ )
388 nrm1 += std::abs( Z[ i ] );
391 xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
402 std::vector<int> iwork( Z.row() );
403 xgecon(
"1", Z.row(), Z.data(), Z.row(), nrm1,
404 &rcond1, work.data(), iwork.data() );
405 if ( 1.0 / rcond1 > 1E+6 )
406 printf(
"Warning! large 1-norm condition number %3.1E\n",
407 1.0 / rcond1 ); fflush( stdout );
422 Z.resize( sl + sr, sl + sr, 0.0 );
426 Zv.Partition2x2( Ztl, Ztr,
427 Zbl, Zbr, sl, sl, TOPLEFT );
433 Zbl.CopyValuesFrom( Crl );
435 xtrmm(
"Right",
"Upper",
"Transpose",
"Non-unit", Zbl.row(), Zbl.col(),
436 1.0, Ul.data(), Ul.
row(), Zbl.data(), Zbl.ld() );
438 xtrmm(
"Left",
"Upper",
"Non-transpose",
"Non-unit", Zbl.row(), Zbl.col(),
439 1.0, Ur.data(), Ur.
row(), Zbl.data(), Zbl.ld() );
441 Ztl.CopyValuesFrom( Zl );
442 Zbr.CopyValuesFrom( Zr );
444 for (
size_t j = 0; j < sl; j ++ )
445 for (
size_t i = 0; i < sr; i ++ )
446 Ztr( j, i ) = Zbl( i, j );
448 PartialFactorize( Z );
458 assert( !isleaf && bl.
col() == br.
col() );
460 size_t nrhs = bl.
col();
462 std::vector<T> ta( ( sl + sr ) * nrhs );
463 std::vector<T> tl( sl * nrhs );
464 std::vector<T> tr( sr * nrhs );
467 xgemm(
"T",
"N", sl, nrhs, nl,
470 0.0, tl.data(), sl );
472 xgemm(
"T",
"N", sr, nrhs, nr,
475 0.0, tr.data(), sr );
478 xgemm(
"N",
"N", sr, nrhs, sl,
481 0.0, ta.data() + sl, sl + sr );
486 xgemm(
"T",
"N", sl, nrhs, sr,
489 0.0, ta.data(), sl + sr );
493 printf(
"bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
495 xgemm(
"N",
"N", sl, nrhs, sr,
498 0.0, ta.data(), sl + sr );
502 xgemm(
"N",
"N", nl, nrhs, sl,
503 -1.0, Ul->data(), nl,
505 1.0, bl.
data(), bl.
ld() );
508 xgemm(
"N",
"N", nr, nrhs, sr,
509 -1.0, Ur->data(), nr,
510 ta.data() + sl, sl + sr,
511 1.0, br.
data(), br.
ld() );
521 assert( !do_ulv_factorization );
522 assert( rhs.
data() && Z.data() );
523 assert( ipiv.data() );
527 size_t nrhs = rhs.
col();
530 xgetrs(
"Non-transpose", rhs.
row(), nrhs,
531 Z.data(), Z.row(), ipiv.data(),
543 size_t nrhs = bl.
col();
549 assert( !do_ulv_factorization );
550 assert( bl.
col() == br.
col() );
551 assert( bl.
row() == nl );
552 assert( br.
row() == nr );
553 assert( Ul && Ur && Vl && Vr );
560 vector<T> ta( ( sl + sr ) * nrhs );
561 vector<T> tl( sl * nrhs );
562 vector<T> tr( sr * nrhs );
577 xgemm(
"T",
"N", sl, nrhs, nl,
580 0.0, tl.data(), sl );
582 xgemm(
"T",
"N", sr, nrhs, nr,
585 0.0, tr.data(), sr );
589 xgemm(
"N",
"N", sr, nrhs, sl,
592 0.0, ta.data() + sl, sl + sr );
597 xgemm(
"T",
"N", sl, nrhs, sr,
600 0.0, ta.data(), sl + sr );
604 printf(
"bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
606 xgemm(
"N",
"N", sl, nrhs, sr,
609 0.0, ta.data(), sl + sr );
613 xgetrs(
"N", sl + sr, nrhs,
614 Z.data(), Z.row(), ipiv.data(),
615 ta.data(), sl + sr );
618 xgemm(
"N",
"N", nl, nrhs, sl,
619 -1.0, Ul->data(), nl,
621 1.0, bl.
data(), bl.
ld() );
624 xgemm(
"N",
"N", nr, nrhs, sr,
625 -1.0, Ur->data(), nr,
626 ta.data() + sl, sl + sr,
627 1.0, br.
data(), br.
ld() );
645 Pa.resize( n, s, 0.0 );
652 assert( Palr.
row() == s ); assert( Palr.
col() == n );
655 for (
size_t j = 0; j < Pa.
col(); j ++ )
656 for (
size_t i = 0; i < Pa.
row(); i ++ )
657 Pa( i, j ) = Palr( j, i );
661 if ( do_ulv_factorization )
663 xtrsm(
"Left",
"Lower",
"No transpose",
"Non-unit",
665 1.0, Z.data(), Z.row(), Pa.data(), Pa.
row() );
669 assert( ipiv.size() );
672 n, s, Z.data(), n, ipiv.data(), Pa.data(), n );
700 assert( n == nl + nr );
701 assert( Pl.
col() == sl );
702 assert( Pr.
col() == sr );
703 assert( Palr.
row() == s ); assert( Palr.
col() == ( sl + sr ) );
715 if ( do_ulv_factorization )
717 Pa.resize( sl + sr, s, 0.0 );
720 for (
size_t j = 0; j < Pa.
col(); j ++ )
721 for (
size_t i = 0; i < Pa.
row(); i ++ )
722 Pa[ j * Pa.
row() + i ] = Palr[ i * Palr.
row() + j ];
725 xtrmm(
"Left",
"Upper",
"No Transpose",
"Non-unit", sl, s,
726 1.0, Pl.data(), Pl.
row(),
727 Pa.data(), Pa.
row() );
732 xtrmm(
"Left",
"Upper",
"No Transpose",
"Non-unit", sr, s,
733 1.0, Pr.data() , Pr.
row(),
734 Pa.data() + sl, Pa.
row() );
743 xtrsm(
"Left",
"Lower",
"No transpose",
"Non-unit",
745 1.0, Z.data(), Z.row(),
746 Pa.data(), Pa.
row() );
751 1, Pa.
row(), ipiv.data(), 1 );
752 xtrsm(
"Left",
"Lower",
"No transpose",
"Unit", Pa.
row(), Pa.
col(),
753 1.0, Z.data(), Z.row(),
754 Pa.data(), Pa.
row() );
762 Pa.resize( nl + nr, s, 0.0 );
781 xgemm(
"N",
"T", nl, s, sl,
786 xgemm(
"N",
"T", nr, s, sr,
788 Palr.data() + s * sl, s,
789 0.0, Pa.data() + nl, n );
803 xgemm(
"T",
"N", sl, s, nl,
806 0.0, xl.data(), sl );
808 xgemm(
"T",
"N", sr, s, nr,
811 0.0, xr.data(), sr );
815 xgemm(
"T",
"N", sl, s, sr,
818 0.0, x.data(), sl + sr );
819 xgemm(
"N",
"N", sr, s, sl,
822 0.0, x.data() + sl, sl + sr );
825 xgetrs(
"N", x.
row(), x.
col(),
826 Z.data(), Z.row(), ipiv.data(),
830 xgemm(
"N",
"N", nl, s, sl,
831 -1.0, Ul->data(), nl,
835 xgemm(
"N",
"N", nr, s, sr,
836 -1.0, Ur->data(), nr,
837 x.data() + sl, sl + sr,
838 1.0, Pa.data() + nl, n );
847 tau.resize( std::min( U.
row(), U.
col() ) );
852 tau.data(), work.data(), work.size() );
856 Q.resize( U.
row(), U.
row() );
859 work.data(), work.size() );
865 Qv.Partition1x2( Q1, Q2, tau.size(), LEFT );
870 xgemm(
"Transpose",
"No Transpose", C.
row(), C.
col(), Q.
row(),
871 1.0, Q.data(), Q.
row(),
873 0.0, C.data(), C.
row() );
875 xgemm(
"No Transpose",
"Transpose", D.
row(), D.
col(), Q.
row(),
876 1.0, Q.data(), Q.
row(),
878 0.0, D.data(), D.
row() );
880 for (
size_t j = 0; j < Q.
col(); j ++ )
882 for (
size_t i = 0; i < Q.
row(); i ++ )
884 if ( i == j ) assert( std::fabs( C( i, j ) - 1 ) < 1E-5 );
885 else assert( std::fabs( C( i, j ) - 0 ) < 1E-5 );
888 for (
size_t j = 0; j < Q.
col(); j ++ )
890 for (
size_t i = 0; i < Q.
row(); i ++ )
892 if ( i == j ) assert( std::fabs( D( i, j ) - 1 ) < 1E-5 );
893 else assert( std::fabs( D( i, j ) - 0 ) < 1E-5 );
906 if ( !Q.size() )
return;
928 xgemm(
"Transpose",
"No Transpose", Bt.
row(), Bt.
col(), Q2.row(),
929 1.0, Q2.data(), Q2.ld(),
931 0.0, Bt.
data(), Bt.
ld() );
933 xgemm(
"Transpose",
"No Transpose", Bb.
row(), Bb.
col(), Q1.row(),
934 1.0, Q1.data(), Q1.ld(),
936 0.0, Bb.
data(), Bb.
ld() );
948 xgemm(
"No Transpose",
"No Transpose", Bl.
row(), Bl.
col(), Q2.row(),
951 0.0, Bl.
data(), Bl.
ld() );
953 xgemm(
"No Transpose",
"No Transpose", Br.
row(), Br.
col(), Q1.row(),
956 0.0, Br.
data(), Br.
ld() );
962 throw "Value of (SideType) side is not recognized.";
972 ChangeBasis( LEFT, A );
973 ChangeBasis( RIGHT, A );
981 if ( isleaf ) B = bview.toData();
983 ChangeBasis( LEFT, B );
985 xlaswp( Bf.col(), Bf.data(), Bf.ld(), 1, Bf.row(), ipiv.data(), 1 );
987 xtrsm(
"Left",
"Lower",
"No transpose",
"Unit", Bf.row(), Bf.col(),
988 1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() );
990 xgemm(
"No Transpose",
"No Transpose", Bc.row(), Bc.col(), Bf.row(),
991 -1.0, Zbl.data(), Zbl.ld(), Bf.data(), Bf.ld(), 1.0, Bc.data(), Bc.ld() );
994 Bp.CopyValuesFrom( Bc );
1001 Bc.CopyValuesFrom( Bp );
1003 xgemm(
"No Transpose",
"No Transpose", Bf.row(), Bf.col(), Bc.row(),
1004 -1.0, Ztr.data(), Ztr.ld(), Bc.data(), Bc.ld(), 1.0, Bf.data(), Bf.ld() );
1006 xtrsm(
"Left",
"Upper",
"No transpose",
"Non-unit", Bf.row(), Bf.col(),
1007 1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() );
1012 xgemm(
"No Transpose",
"No Transpose", A.
row(), A.
col(), Bf.row(),
1013 1.0, Q2.data(), Q2.ld(), Bf.data(), Bf.ld(), 0.0, A.data(), A.
row() );
1014 xgemm(
"No Transpose",
"No Transpose", A.
row(), A.
col(), Bc.row(),
1015 1.0, Q1.data(), Q1.ld(), Bc.data(), Bc.ld(), 1.0, A.data(), A.
row() );
1017 if ( isleaf ) bview.CopyValuesFrom( A );
1018 else Bv.CopyValuesFrom( A );
1027 bool isleaf =
false;
1029 bool isroot =
false;
1077 View<T> Bv, Bp, Bsibling, Bf, Bc;
1081 bool issymmetric =
true;
1083 bool do_ulv_factorization =
false;
1091 template<
typename NODE,
typename T>
1092 void SetupFactor( NODE *node )
1094 size_t n, nl, nr, s, sl, sr;
1095 bool issymmetric, do_ulv_factorization;
1099 printf(
"begin SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
1102 issymmetric = node->setup->IsSymmetric();
1103 do_ulv_factorization = node->setup->do_ulv_factorization;
1107 s = node->data.skels.size();
1111 if ( !node->isleaf )
1113 nl = node->lchild->n;
1114 nr = node->rchild->n;
1115 sl = node->lchild->data.skels.size();
1116 sr = node->rchild->data.skels.size();
1120 node->data.SetupFactor( issymmetric, do_ulv_factorization,
1121 node->isleaf, !node->l, n, nl, nr, s, sl, sr );
1124 printf(
"end SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
1133 template<
typename NODE,
typename T>
1140 void Set( NODE *user_arg )
1143 name = string(
"sf" );
1144 label = to_string( arg->treelist_id );
1150 double flops = 0.0, mops = 0.0;
1151 event.Set( label + name, flops, mops );
1154 void DependencyAnalysis()
1156 arg->DependencyAnalysis( W,
this );
1160 void Execute(
Worker* user_worker )
1162 SetupFactor<NODE, T>( arg );
1170 template<
typename NODE>
1171 void SolverTreeView( NODE *node )
1173 auto &data = node->data;
1174 auto *setup = node->setup;
1175 auto &input = *(setup->input);
1176 auto &output = *(setup->output);
1178 if ( node->isleaf ) data.B.resize( data.n, input.col() );
1179 else data.B.resize( data.sl + data.sr, input.col() );
1182 data.Bv.Set( data.B );
1183 data.Bv.Partition2x1( data.Bf,
1184 data.Bc, data.s, BOTTOM );
1187 if ( !node->parent ) data.bview.Set( output );
1190 if ( !node->isleaf )
1192 auto &ldata = node->lchild->data;
1193 auto &rdata = node->rchild->data;
1195 data.bview.Partition2x1( ldata.bview,
1196 rdata.bview, data.nl, TOP );
1197 data.Bv.Partition2x1( ldata.Bp,
1198 rdata.Bp, data.sl, TOP );
1206 template<
typename NODE>
1213 void Set( NODE *user_arg )
1216 name = string(
"TreeView" );
1217 label = to_string( arg->treelist_id );
1223 double flops = 0.0, mops = 0.0;
1224 event.Set( label + name, flops, mops );
1230 void Execute(
Worker* user_worker ) { SolverTreeView( arg ); };
1240 template<
bool FORWARD,
typename NODE>
1247 void Set( NODE *user_arg )
1249 name = std::string(
"MatrixPermutation" );
1256 double flops = 0.0, mops = 0.0;
1257 event.Set( label + name, flops, mops );
1265 arg->DependencyAnalysis( RW,
this );
1277 auto &gids = node->gids;
1278 auto &input = *(node->setup->input);
1279 auto &output = *(node->setup->output);
1280 auto &A = node->data.bview;
1282 assert( A.row() == gids.size() );
1283 assert( A.col() == input.col() );
1290 for (
size_t j = 0; j < input.col(); j ++ )
1291 for (
size_t i = 0; i < gids.size(); i ++ )
1293 if ( FORWARD ) A( i, j ) = input( gids[ i ], j );
1295 else input( gids[ i ], j ) = A( i, j );
1312 template<
typename NODE,
typename T>
1313 void Apply( NODE *node )
1315 auto &data = node->data;
1316 auto &setup = node->setup;
1317 auto &K = *setup->K;
1321 auto lambda = setup->lambda;
1322 auto &amap = node->gids;
1324 auto Kaa = K( amap, amap );
1326 for (
size_t i = 0; i < Kaa.row(); i ++ )
1327 Kaa[ i * Kaa.row() + i ] += lambda;
1331 auto &bl = node->lchild->data.bview;
1332 auto &br = node->rchild->data.bview;
1333 data.Apply<
true>( bl, br );
1344 template<
typename NODE,
typename T>
1351 void Set( NODE *user_arg )
1354 name = string(
"ulvforward" );
1355 label = to_string( arg->treelist_id );
1375 void Execute(
Worker* user_worker ) { arg->data.ULVForward(); };
1382 template<
typename NODE,
typename T>
1389 void Set( NODE *user_arg )
1392 name = string(
"ulvbackward" );
1393 label = std::to_string( arg->treelist_id );
1412 void Execute(
Worker* user_worker ) { arg->data.ULVBackward(); };
1436 template<
typename NODE,
typename T>
1437 void Solve( NODE *node )
1440 auto &data = node->data;
1441 auto &setup = node->setup;
1442 auto &K = *setup->K;
1450 auto &b = data.bview;
1456 auto &bl = node->lchild->data.bview;
1457 auto &br = node->rchild->data.bview;
1458 data.Solve( bl, br );
1470 template<
typename NODE,
typename T>
1477 void Set( NODE *user_arg )
1480 name = string(
"sl" );
1481 label = to_string( arg->treelist_id );
1489 double flops = 0.0, mops = 0.0;
1490 event.Set( label + name, flops, mops );
1493 void DependencyAnalysis()
1495 arg->DependencyAnalysis( RW,
this );
1498 arg->lchild->DependencyAnalysis( R,
this );
1499 arg->rchild->DependencyAnalysis( R,
this );
1503 void Execute(
Worker* user_worker )
1505 Solve<NODE, T>( arg );
1514 template<
typename T,
typename TREE>
1515 void Solve( TREE &tree,
Data<T> &input )
1517 using NODE =
typename TREE::NODE;
1519 const bool AUTO_DEPENDENCY =
true;
1520 const bool USE_RUNTIME =
true;
1535 tree.setup.input = &input;
1536 tree.setup.output = output;
1538 if ( tree.setup.do_ulv_factorization )
1542 tree.TraverseDown( treeviewtask );
1543 tree.TraverseLeafs( forwardpermutetask );
1544 tree.TraverseUp( ulvforwardsolvetask );
1545 tree.TraverseDown( ulvbackwardsolvetask );
1546 if ( USE_RUNTIME ) hmlp_run();
1549 tree.DependencyCleanUp();
1550 tree.TraverseLeafs( inversepermutetask );
1551 if ( USE_RUNTIME ) hmlp_run();
1556 tree.DependencyCleanUp();
1557 tree.TraverseDown( treeviewtask );
1558 tree.TraverseLeafs( forwardpermutetask );
1559 tree.TraverseUp( solvetask1 );
1560 if ( USE_RUNTIME ) hmlp_run();
1562 tree.DependencyCleanUp();
1563 tree.TraverseLeafs( inversepermutetask );
1564 if ( USE_RUNTIME ) hmlp_run();
1580 template<
typename NODE,
typename T>
1581 void LowRankError( NODE *node )
1583 auto &data = node->data;
1584 auto &setup = node->setup;
1585 auto &K = *setup->K;
1587 if ( !node->isleaf )
1589 auto Krl = K( node->rchild->gids, node->lchild->gids );
1591 auto nrm2 = hmlp_norm( Krl.row(), Krl.col(),
1592 Krl.data(), Krl.row() );
1598 xgemm(
"N",
"N", data.nr, data.sl, data.sr,
1599 1.0, data.Vr->data(), data.nr,
1600 data.Crl.data(), data.sr,
1601 0.0, VrCrl.data(), data.nr );
1604 xgemm(
"N",
"T", data.nr, data.nl, data.sl,
1605 -1.0, VrCrl.data(), data.nr,
1606 data.Vl->data(), data.nl,
1607 1.0, Krl.data(), data.nr );
1609 auto err = hmlp_norm( Krl.row(), Krl.col(),
1610 Krl.data(), Krl.row() );
1612 printf(
"%4lu ||Krl -VrCrlVl|| %3.1E\n",
1613 node->treelist_id, std::sqrt( err / nrm2 ) );
1623 template<
typename NODE,
typename T>
1624 void Factorize( NODE *node )
1626 auto &data = node->data;
1627 auto &setup = node->setup;
1628 auto &K = *setup->K;
1629 auto &proj = data.proj;
1631 auto do_ulv_factorization = setup->do_ulv_factorization;
1635 auto lambda = setup->lambda;
1636 auto &amap = node->gids;
1639 Data<T> Kaa = K( amap, amap );
1642 for (
size_t i = 0; i < Kaa.
row(); i ++ ) Kaa( i, i ) += lambda;
1644 if ( do_ulv_factorization )
1647 data.Telescope(
false, data.U, proj );
1649 data.Orthogonalization();
1651 data.PartialFactorize( Kaa );
1656 data.Factorize( Kaa );
1658 data.Telescope(
true, data.U, proj );
1660 data.Telescope(
false, data.V, proj );
1665 auto &Ul = node->lchild->data.U;
1666 auto &Vl = node->lchild->data.V;
1667 auto &Zl = node->lchild->data.Zbr;
1668 auto &Ur = node->rchild->data.U;
1669 auto &Vr = node->rchild->data.V;
1670 auto &Zr = node->rchild->data.Zbr;
1673 auto &amap = node->lchild->data.skels;
1674 auto &bmap = node->rchild->data.skels;
1677 node->data.Crl = K( bmap, amap );
1679 if ( do_ulv_factorization )
1681 if ( !node->data.isroot )
1683 data.Telescope(
false, data.U, proj, Ul, Ur );
1684 data.Orthogonalization();
1686 data.PartialFactorize( Zl, Zr, Ul, Ur, Vl, Vr );
1691 data.Factorize( Ul, Ur, Vl, Vr );
1693 if ( !node->data.isroot )
1696 data.Telescope(
true, data.U, proj, Ul, Ur );
1698 data.Telescope(
false, data.V, proj, Vl, Vr );
1801 template<
typename NODE,
typename T>
1808 void Set( NODE *user_arg )
1811 name = string(
"fa" );
1812 label = to_string( arg->treelist_id );
1819 double flops = 0.0, mops = 0.0;
1820 event.Set( label + name, flops, mops );
1823 void DependencyAnalysis() { arg->DependOnChildren(
this ); };
1825 void Execute(
Worker* user_worker ) { Factorize<NODE, T>( arg ); };
1842 template<
typename T,
typename TREE>
1843 void Factorize( TREE &tree, T lambda )
1845 using NODE =
typename TREE::NODE;
1848 tree.DependencyCleanUp();
1851 tree.setup.lambda = lambda;
1854 tree.setup.do_ulv_factorization =
true;
1858 tree.TraverseUp( setupfactortask );
1859 tree.ExecuteAllTasks();
1863 tree.TraverseUp( factorizetask );
1864 tree.ExecuteAllTasks();
1874 template<
typename TREE,
typename T>
1875 void ComputeError( TREE &tree, T lambda,
Data<T> weights,
Data<T> potentials )
1877 using NODE =
typename TREE::NODE;
1881 assert( weights.
row() == potentials.
row() );
1882 assert( weights.
col() == potentials.
col() );
1884 size_t n = weights.
row();
1885 size_t nrhs = weights.
col();
1889 for (
size_t j = 0; j < nrhs; j ++ )
1890 for (
size_t i = 0; i < n; i ++ )
1891 rhs( i, j ) = potentials( i, j ) + lambda * weights( i, j );
1898 printf(
"========================================================\n" );
1899 printf(
"Inverse accuracy report\n" );
1900 printf(
"========================================================\n" );
1901 printf(
"#rhs, max err, @, min err, @, relative \n" );
1902 printf(
"========================================================\n" );
1905 for (
size_t j = 0; j < std::min( nrhs, ntest ); j ++ )
1908 T nrm2 = 0.0, err2 = 0.0;
1909 T max2 = 0.0, min2 = std::numeric_limits<T>::max();
1911 size_t maxi = 0, mini = 0;
1913 for (
size_t i = 0; i < n; i ++ )
1915 T sse = rhs( i, j ) - weights( i, j );
1916 assert( rhs( i, j ) == rhs( i, j ) );
1919 nrm2 += weights( i, j ) * weights( i, j );
1925 if ( sse > max2 ) { max2 = sse; maxi = i; }
1926 if ( sse < min2 ) { min2 = sse; mini = i; }
1928 total_err += std::sqrt( err2 / nrm2 );
1930 printf(
"%4lu, %3.1E, %7lu, %3.1E, %7lu, %3.1E\n",
1931 j, std::sqrt( max2 ), maxi, std::sqrt( min2 ), mini,
1932 std::sqrt( err2 / nrm2 ) );
1934 printf(
"========================================================\n" );
1935 printf(
" avg over %2lu rhs, %3.1E \n",
1936 std::min( nrhs, ntest ), total_err / std::min( nrhs, ntest ) );
1937 printf(
"========================================================\n\n" );
void Partition1x2(View< T > &A1, View< T > &A2, size_t nb, SideType side)
Definition: View.hpp:180
void ChangeBasis(SideType side, Data< T > &B)
Definition: igofmm.hpp:903
downward traversal to create matrix views, at the leaf level execute explicit permutation.
Definition: igofmm.hpp:1241
Definition: igofmm.hpp:1471
vector< T > tau
Definition: igofmm.hpp:1073
void Factorize(Data< T > &Kaa)
Definition: igofmm.hpp:145
Definition: igofmm.hpp:1345
void Solve(View< T > &rhs)
Solver for leaf nodes.
Definition: igofmm.hpp:517
void ULVForward()
Definition: igofmm.hpp:978
Definition: igofmm.hpp:70
Definition: igofmm.hpp:1802
void Execute(Worker *user_worker)
Definition: igofmm.hpp:1273
void xgeqrf(int m, int n, double *A, int lda, double *tau, double *work, int lwork)
DGEQRF wrapper.
Definition: blas_lapack.cpp:694
void DependencyCleanUp()
Definition: runtime.cpp:436
size_t row()
Definition: View.hpp:345
void xgetrf(int m, int n, double *A, int lda, int *ipiv)
DGETRF wrapper.
Definition: blas_lapack.cpp:546
void xorgqr(int m, int n, int k, double *A, int lda, double *tau, double *work, int lwork)
DORGQR wrapper.
Definition: blas_lapack.cpp:757
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
Data< T > Crl
Definition: igofmm.hpp:1057
void xgetrs(const char *trans, int m, int nrhs, double *A, int lda, int *ipiv, double *B, int ldb)
DGETRS wrapper.
Definition: blas_lapack.cpp:577
Definition: igofmm.hpp:1134
Data< T > Z
Definition: igofmm.hpp:1046
size_t col()
Definition: View.hpp:348
SideType
Definition: View.hpp:37
T * data()
Definition: View.hpp:354
Data< T > Q
Definition: igofmm.hpp:1069
Data< T > B
Definition: igofmm.hpp:1076
void xlaswp(int n, double *A, int lda, int k1, int k2, int *ipiv, int incx)
DLASWP wrapper.
Definition: blas_lapack.cpp:450
void PartialFactorize(Data< T > &A)
Definition: igofmm.hpp:179
View< T > bview
Definition: igofmm.hpp:1060
void ChangeBasis(Data< T > &A)
Definition: igofmm.hpp:970
void Multiply(View< T > &bl, View< T > &br)
Definition: igofmm.hpp:456
vector< int > ipiv
Definition: igofmm.hpp:1051
size_t col() const noexcept
Definition: Data.hpp:281
void xgecon(const char *norm, int n, double *A, int lda, double anorm, double *rcond, double *work, int *iwork)
DGECON wrapper.
Definition: blas_lapack.cpp:631
void DependencyAnalysis()
Definition: igofmm.hpp:1228
void DependencyAnalysis()
Definition: igofmm.hpp:1261
void Solve(View< T > &bl, View< T > &br)
b - U * inv( Z ) * C * V' * b
Definition: igofmm.hpp:541
size_t row() const noexcept
Definition: Data.hpp:278
Definition: igofmm.hpp:1383
void GetEventRecord()
Definition: igofmm.hpp:1487
size_t ld()
Definition: View.hpp:351
void xpotrf(const char *uplo, int n, double *A, int lda)
DPOTRF wrapper.
Definition: blas_lapack.cpp:480
void Orthogonalization()
Definition: igofmm.hpp:844
void xtrsm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRSM wrapper.
Definition: blas_lapack.cpp:315
void PartialFactorize(View< T > &Zl, View< T > &Zr, Data< T > &Ul, Data< T > &Ur, Data< T > &Vl, Data< T > &Vr)
Definition: igofmm.hpp:413
void DependencyAnalysis()
Definition: igofmm.hpp:1410
void ULVBackward()
Definition: igofmm.hpp:998
Data< T > U
Definition: igofmm.hpp:1054
void xtrmm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRMM wrapper.
Definition: blas_lapack.cpp:385
void GetEventRecord()
Definition: igofmm.hpp:1148
void GetEventRecord()
Definition: igofmm.hpp:1817
Definition: runtime.hpp:174
void Partition2x1(View< T > &A1, View< T > &A2, size_t mb, SideType side)
Definition: View.hpp:155
void DependencyAnalysis()
Definition: igofmm.hpp:1372
void GetEventRecord()
Definition: igofmm.hpp:1221
Creates an hierarchical tree view for a matrix.
Definition: igofmm.hpp:1207
void GetEventRecord()
Definition: igofmm.hpp:1254
Definition: thread.hpp:166