5 #include <hmlp_mpi.hpp> 18 template<
typename NODE,
typename T>
25 void Set( NODE *user_arg )
28 name = string(
"PSF" );
29 label = to_string( arg->treelist_id );
32 void DependencyAnalysis() { arg->DependOnChildren(
this ); };
37 auto &data = node->data;
38 auto *setup = node->setup;
41 auto comm = node->GetComm();
42 int size = node->GetCommSize();
43 int rank = node->GetCommRank();
45 if ( size == 1 )
return gofmm::SetupFactor<NODE, T>( arg );
47 size_t n, nl, nr, s, sl, sr;
48 bool issymmetric, do_ulv_factorization;
50 issymmetric = setup->IsSymmetric();
51 do_ulv_factorization = setup->do_ulv_factorization;
55 s = data.skels.size();
59 mpi::Bcast( &n, 1, 0, comm );
60 mpi::Bcast( &s, 1, 0, comm );
62 if ( rank < size / 2 )
65 sl = node->child->data.s;
70 sr = node->child->data.s;
73 mpi::Bcast( &nl, 1, 0, comm );
74 mpi::Bcast( &nr, 1, size / 2, comm );
75 mpi::Bcast( &sl, 1, 0, comm );
76 mpi::Bcast( &sr, 1, size / 2, comm );
78 data.SetupFactor( issymmetric, do_ulv_factorization,
79 node->isleaf, !node->l, n, nl, nr, s, sl, sr );
88 template<
typename NODE,
typename T>
95 void Set( NODE *user_arg )
98 name = string(
"PULVF" );
99 label = to_string( arg->treelist_id );
104 void DependencyAnalysis() { arg->DependOnChildren(
this ); };
109 auto &data = node->data;
110 auto *setup = node->setup;
114 auto comm = node->GetComm();
115 int size = node->GetCommSize();
116 int rank = node->GetCommRank();
119 if ( size == 1 )
return gofmm::Factorize<NODE, T>( node );
124 Data<T> Ur, &Ul = node->child->data.U;
125 mpi::RecvData( Ur, size / 2, comm );
132 mpi::RecvData( Zr, size / 2, comm );
135 auto &amap = node->child->data.skels;
136 vector<size_t> bmap( data.sr );
137 mpi::Recv( bmap.data(), bmap.size(), size / 2, 20, comm, &status );
140 data.Crl = K( bmap, amap );
144 data.Telescope(
false, data.U, data.proj, Ul, Ur );
145 data.Orthogonalization();
147 data.PartialFactorize( Zl, Zrv, Ul, Ur, Vl, Vr );
150 if ( rank == size / 2 )
153 auto &Ur = node->child->data.U;
154 mpi::SendData( Ur, 0, comm );
156 auto Zr = node->child->data.Zbr.toData();
157 mpi::SendData( Zr, 0, comm );
159 auto &bmap = node->child->data.skels;
160 mpi::Send( bmap.data(), bmap.size(), 0, 20, comm );
167 template<
typename NODE,
typename T>
174 void Set( NODE *user_arg )
177 name = string(
"PTV" );
178 label = to_string( arg->treelist_id );
183 void DependencyAnalysis() { arg->DependOnParent(
this ); };
188 auto &data = node->data;
189 auto *setup = node->setup;
190 auto &input = *(setup->input);
191 auto &output = *(setup->output);
194 int size = node->GetCommSize();
195 int rank = node->GetCommRank();
198 data.bview.Set( output );
201 if ( size == 1 )
return gofmm::SolverTreeView( arg );
206 data.B.resize( data.sl + data.sr, input.col() );
208 data.Bv.Set( data.B );
209 data.Bv.Partition2x1( data.Bf,
210 data.Bc, data.s, BOTTOM );
212 auto &cdata = node->child->data;
213 data.Bv.Partition2x1( cdata.Bp,
214 data.Bsibling, data.sl, TOP );
217 if ( rank == size / 2 )
220 data.B.resize( data.sr, input.col() );
222 auto &cdata = node->child->data;
223 cdata.Bp.Set( data.B );
232 template<
typename NODE,
typename T>
239 void Set( NODE *user_arg )
242 name = string(
"PULVS1" );
243 label = to_string( arg->treelist_id );
250 void DependencyAnalysis() { arg->DependOnChildren(
this ); };
256 auto &data = node->data;
258 auto comm = node->GetComm();
259 int size = node->GetCommSize();
260 int rank = node->GetCommRank();
264 if ( size == 1 )
return data.ULVForward();
268 auto Br = data.Bsibling.toData();
270 mpi::Recv( Br.data(), Br.size(), size / 2, 0, comm, &status );
272 data.Bsibling.CopyValuesFrom( Br );
277 if ( rank == size / 2 )
281 mpi::Send( Br.data(), Br.size(), 0, 0, comm );
288 template<
typename NODE,
typename T>
295 void Set( NODE *user_arg )
298 name = string(
"PULVS2" );
299 label = to_string( arg->treelist_id );
306 void DependencyAnalysis() { arg->DependOnParent(
this ); };
312 auto &data = node->data;
314 auto comm = node->GetComm();
315 int size = node->GetCommSize();
316 int rank = node->GetCommRank();
320 if ( size == 1 )
return data.ULVBackward();
327 auto Br = data.Bsibling.toData();
329 mpi::Send( Br.data(), Br.size(), size / 2, 0, comm );
332 if ( rank == size / 2 )
336 mpi::Recv( Br.data(), Br.size(), 0, 0, comm, &status );
350 template<
typename T,
typename TREE>
351 void DistFactorize( TREE &tree, T lambda )
353 using NODE =
typename TREE::NODE;
354 using MPINODE =
typename TREE::MPINODE;
357 tree.setup.lambda = lambda;
359 tree.setup.do_ulv_factorization =
true;
365 mpi::PrintProgress(
"[BEG] DistFactorize setup ...\n", tree.GetComm() );
366 tree.DependencyCleanUp();
367 tree.LocaTraverseUp( seqSETUPFACTORtask );
368 tree.DistTraverseUp( parSETUPFACTORtask );
369 tree.ExecuteAllTasks();
370 mpi::PrintProgress(
"[END] DistFactorize setup ...\n", tree.GetComm() );
372 mpi::PrintProgress(
"[BEG] DistFactorize ...\n", tree.GetComm() );
375 tree.LocaTraverseUp( seqFACTORIZEtask );
376 tree.DistTraverseUp( parFACTORIZEtask );
377 tree.ExecuteAllTasks();
378 mpi::PrintProgress(
"[END] DistFactorize ...\n", tree.GetComm() );
383 template<
typename T,
typename TREE>
384 void DistSolve( TREE &tree,
Data<T> &input )
386 using NODE =
typename TREE::NODE;
387 using MPINODE =
typename TREE::MPINODE;
390 tree.setup.input = &input;
391 tree.setup.output = &input;
401 tree.DistTraverseDown( parTREEVIEWtask );
402 tree.LocaTraverseDown( seqTREEVIEWtask );
403 tree.LocaTraverseUp( seqFORWARDtask );
404 tree.DistTraverseUp( parFORWARDtask );
405 tree.DistTraverseDown( parBACKWARDtask );
406 tree.LocaTraverseDown( seqBACKWARDtask );
407 mpi::PrintProgress(
"[PREP] DistSolve ...\n", tree.GetComm() );
408 tree.ExecuteAllTasks();
409 mpi::PrintProgress(
"[DONE] DistSolve ...\n", tree.GetComm() );
418 template<
typename TREE,
typename T>
419 void ComputeError( TREE &tree, T lambda,
Data<T> weights,
Data<T> potentials )
421 using NODE =
typename TREE::NODE;
422 using MPINODE =
typename TREE::MPINODE;
424 auto comm = tree.GetComm();
425 auto size = tree.GetCommSize();
426 auto rank = tree.GetCommRank();
429 size_t n = weights.
row();
430 size_t nrhs = weights.
col();
434 for (
size_t j = 0; j < nrhs; j ++ )
435 for (
size_t i = 0; i < n; i ++ )
436 rhs( i, j ) = potentials( i, j ) + lambda * weights( i, j );
439 DistSolve( tree, rhs );
445 printf(
"========================================================\n" );
446 printf(
"Inverse accuracy report\n" );
447 printf(
"========================================================\n" );
448 printf(
"#rhs, max err, @, min err, @, relative \n" );
449 printf(
"========================================================\n" );
457 for (
size_t j = 0; j < std::min( nrhs, ntest ); j ++ )
460 T nrm2 = 0.0, err2 = 0.0;
461 T max2 = 0.0, min2 = std::numeric_limits<T>::max();
463 for (
size_t i = 0; i < n; i ++ )
465 T sse = rhs( i, j ) - weights( i, j );
466 assert( rhs( i, j ) == rhs( i, j ) );
468 nrm2 += weights( i, j ) * weights( i, j );
470 max2 = std::max( max2, sse );
471 min2 = std::min( min2, sse );
474 mpi::Allreduce( &err2, &total_err, 1, MPI_SUM, comm );
475 mpi::Allreduce( &nrm2, &total_nrm, 1, MPI_SUM, comm );
476 mpi::Allreduce( &max2, &total_max, 1, MPI_MAX, comm );
477 mpi::Allreduce( &min2, &total_min, 1, MPI_MIN, comm );
479 total_err += std::sqrt( total_err / total_nrm );
480 total_max = std::sqrt( total_max );
481 total_min = std::sqrt( total_min );
485 printf(
"%4lu, %3.1E, %7lu, %3.1E, %7lu, %3.1E\n",
486 j, total_max, (
size_t)0, total_min, (
size_t)0, total_err );
491 printf(
"========================================================\n" );
Definition: mpi_prototypes.h:81
Definition: igofmm.hpp:1345
void Execute(Worker *user_worker)
Definition: igofmm_mpi.hpp:106
Definition: igofmm.hpp:1802
void Set(NODE *user_arg)
Definition: igofmm_mpi.hpp:239
void Execute(Worker *user_worker)
Definition: igofmm_mpi.hpp:34
void Set(NODE *user_arg)
Definition: igofmm_mpi.hpp:295
Definition: igofmm.hpp:1134
T * data()
Definition: View.hpp:354
void Execute(Worker *user_worker)
Definition: igofmm_mpi.hpp:185
Definition: igofmm_mpi.hpp:168
size_t col() const noexcept
Definition: Data.hpp:281
Definition: igofmm_mpi.hpp:289
Definition: igofmm_mpi.hpp:19
size_t row() const noexcept
Definition: Data.hpp:278
Definition: igofmm.hpp:1383
void Execute(Worker *user_worker)
Definition: igofmm_mpi.hpp:252
Definition: igofmm_mpi.hpp:233
Definition: igofmm_mpi.hpp:89
void Execute(Worker *user_worker)
Definition: igofmm_mpi.hpp:308
void Set(NODE *user_arg)
Definition: igofmm_mpi.hpp:174
void Set(NODE *user_arg)
Definition: igofmm_mpi.hpp:95
Definition: runtime.hpp:174
Creates an hierarchical tree view for a matrix.
Definition: igofmm.hpp:1207
Definition: thread.hpp:166