HMLP: High-performance Machine Learning Primitives
igofmm_mpi.hpp
#ifndef IGOFMM_MPI_HPP
#define IGOFMM_MPI_HPP

#include <hmlp_mpi.hpp>
#include <igofmm.hpp>

using namespace std;
using namespace hmlp;


namespace hmlp
{
namespace mpigofmm
{

/** Set up the factorization dimensions of a distributed tree node. */
template<typename NODE, typename T>
class DistSetupFactorTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PSF" );
      label = to_string( arg->treelist_id );
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };

    void Execute( Worker *user_worker )
    {
      auto *node = arg;
      auto &data = node->data;
      auto *setup = node->setup;

      auto comm = node->GetComm();
      int size = node->GetCommSize();
      int rank = node->GetCommRank();

      /** A communicator of size one reduces to the shared-memory case. */
      if ( size == 1 ) return gofmm::SetupFactor<NODE, T>( arg );

      size_t n, nl, nr, s, sl, sr;
      bool issymmetric, do_ulv_factorization;

      issymmetric = setup->IsSymmetric();
      do_ulv_factorization = setup->do_ulv_factorization;
      n  = node->n;
      nl = 0;
      nr = 0;
      s  = data.skels.size();
      sl = 0;
      sr = 0;

      mpi::Bcast( &n, 1, 0, comm );
      mpi::Bcast( &s, 1, 0, comm );

      /** The lower half of the communicator owns the left child;
       *  the upper half owns the right child. */
      if ( rank < size / 2 )
      {
        nl = node->child->n;
        sl = node->child->data.s;
      }
      else
      {
        nr = node->child->n;
        sr = node->child->data.s;
      }

      /** Broadcast each child's dimensions from its owner rank. */
      mpi::Bcast( &nl, 1, 0, comm );
      mpi::Bcast( &nr, 1, size / 2, comm );
      mpi::Bcast( &sl, 1, 0, comm );
      mpi::Bcast( &sr, 1, size / 2, comm );

      data.SetupFactor( issymmetric, do_ulv_factorization,
          node->isleaf, !node->l, n, nl, nr, s, sl, sr );

      //printf( "n %lu nl %lu nr %lu s %lu sl %lu sr %lu\n",
      //    n, nl, nr, s, sl, sr ); fflush( stdout );
    };
};
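
/**
 *  Sketch (illustrative, not part of the library): the broadcast pattern used
 *  by DistSetupFactorTask above. Each half of the node communicator fills in
 *  the quantity it owns locally; two broadcasts rooted at the two owner ranks
 *  ( 0 and size / 2 ) then make both quantities known on every rank.
 *  BcastFromBothOwners is a hypothetical helper name, and COMM stands in for
 *  the hmlp::mpi communicator type returned by GetComm().
 */
template<typename T, typename COMM>
void BcastFromBothOwners( T &left_value, T &right_value, int size, COMM comm )
{
  mpi::Bcast( &left_value,  1, 0,        comm );
  mpi::Bcast( &right_value, 1, size / 2, comm );
};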


/** Distributed ULV factorization of a node owned by more than one rank. */
template<typename NODE, typename T>
class DistFactorizeTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PULVF" );
      label = to_string( arg->treelist_id );
      cost = 5.0;
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };

    void Execute( Worker *user_worker )
    {
      auto *node = arg;
      auto &data = node->data;
      auto *setup = node->setup;
      auto &K = *setup->K;

      auto comm = node->GetComm();
      int size = node->GetCommSize();
      int rank = node->GetCommRank();
      mpi::Status status;

      /** A communicator of size one reduces to the shared-memory case. */
      if ( size == 1 ) return gofmm::Factorize<NODE, T>( node );

      if ( rank == 0 )
      {
        /** Ul views the left child's basis; Ur is received from rank size / 2. */
        Data<T> Ur, &Ul = node->child->data.U;
        mpi::RecvData( Ur, size / 2, comm );
        //printf( "Ur %lux%lu\n", Ur.row(), Ur.col() ); fflush( stdout );
        Data<T> Vl, Vr;
        /** Zl views the left child's reduced system; Zr is received from the right. */
        View<T> &Zl = node->child->data.Zbr;
        Data<T> Zr;
        mpi::RecvData( Zr, size / 2, comm );
        View<T> Zrv( Zr );
        /** Receive the right child's skeleton indices and evaluate Crl. */
        auto &amap = node->child->data.skels;
        vector<size_t> bmap( data.sr );
        mpi::Recv( bmap.data(), bmap.size(), size / 2, 20, comm, &status );

        data.Crl = K( bmap, amap );

        /** Non-root nodes telescope and orthogonalize their skeleton basis. */
        if ( node->l )
        {
          data.Telescope( false, data.U, data.proj, Ul, Ur );
          data.Orthogonalization();
        }
        data.PartialFactorize( Zl, Zrv, Ul, Ur, Vl, Vr );
      }

      if ( rank == size / 2 )
      {
        /** Ship the right child's basis, reduced system, and skeletons to rank 0. */
        auto &Ur = node->child->data.U;
        mpi::SendData( Ur, 0, comm );
        auto Zr = node->child->data.Zbr.toData();
        mpi::SendData( Zr, 0, comm );
        auto &bmap = node->child->data.skels;
        mpi::Send( bmap.data(), bmap.size(), 0, 20, comm );
      }
    };
};
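
/**
 *  Sketch (illustrative, not part of the library): the pairwise exchange
 *  pattern used by DistFactorizeTask above. Within a node communicator of
 *  size p, rank 0 owns the left child and rank p / 2 owns the right child;
 *  the right owner ships its Data<T> objects to rank 0, which then performs
 *  the partial factorization. ShipToLeftOwner is a hypothetical helper name,
 *  and COMM stands in for the hmlp::mpi communicator type returned by
 *  GetComm().
 */
template<typename T, typename COMM>
void ShipToLeftOwner( Data<T> &buffer, int rank, int size, COMM comm )
{
  if ( rank == size / 2 ) mpi::SendData( buffer, 0, comm );
  if ( rank == 0 )        mpi::RecvData( buffer, size / 2, comm );
};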


/** Create the hierarchical right-hand-side views used by the distributed solve. */
template<typename NODE, typename T>
class DistFactorTreeViewTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PTV" );
      label = to_string( arg->treelist_id );
      cost = 5.0;
    };

    void DependencyAnalysis() { arg->DependOnParent( this ); };

    void Execute( Worker *user_worker )
    {
      auto *node = arg;
      auto &data = node->data;
      auto *setup = node->setup;
      auto &input = *(setup->input);
      auto &output = *(setup->output);

      int size = node->GetCommSize();
      int rank = node->GetCommRank();

      data.bview.Set( output );

      /** A communicator of size one reduces to the shared-memory case. */
      if ( size == 1 ) return gofmm::SolverTreeView( arg );

      if ( rank == 0 )
      {
        /** Allocate the merged ( sl + sr )-by-nrhs block and partition it into views. */
        data.B.resize( data.sl + data.sr, input.col() );
        data.Bv.Set( data.B );
        data.Bv.Partition2x1( data.Bf,
                              data.Bc, data.s, BOTTOM );
        auto &cdata = node->child->data;
        data.Bv.Partition2x1( cdata.Bp,
                              data.Bsibling, data.sl, TOP );
      }

      if ( rank == size / 2 )
      {
        /** The right owner only allocates its own sr-by-nrhs block. */
        data.B.resize( data.sr, input.col() );
        auto &cdata = node->child->data;
        cdata.Bp.Set( data.B );
      }
    };
};


/** Distributed forward substitution of the ULV factorization. */
template<typename NODE, typename T>
class DistULVForwardSolveTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PULVS1" );
      label = to_string( arg->treelist_id );
      cost = 5.0;
      priority = true;
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };

    void Execute( Worker *user_worker )
    {
      //printf( "[BEG] level-%lu\n", arg->l ); fflush( stdout );
      auto *node = arg;
      auto &data = node->data;
      auto comm = node->GetComm();
      int size = node->GetCommSize();
      int rank = node->GetCommRank();
      mpi::Status status;

      if ( size == 1 ) return data.ULVForward();

      if ( rank == 0 )
      {
        /** Receive the sibling block from rank size / 2, then eliminate. */
        auto Br = data.Bsibling.toData();
        mpi::Recv( Br.data(), Br.size(), size / 2, 0, comm, &status );
        data.Bsibling.CopyValuesFrom( Br );
        data.ULVForward();
      }

      if ( rank == size / 2 )
      {
        auto &Br = data.B;
        mpi::Send( Br.data(), Br.size(), 0, 0, comm );
      }
      //printf( "[END] level-%lu\n", arg->l ); fflush( stdout );
    };
};


/** Distributed backward substitution of the ULV factorization. */
template<typename NODE, typename T>
class DistULVBackwardSolveTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PULVS2" );
      label = to_string( arg->treelist_id );
      cost = 5.0;
      priority = true;
    };

    void DependencyAnalysis() { arg->DependOnParent( this ); };

    void Execute( Worker *user_worker )
    {
      //printf( "[BEG] level-%lu\n", arg->l ); fflush( stdout );
      auto *node = arg;
      auto &data = node->data;
      auto comm = node->GetComm();
      int size = node->GetCommSize();
      int rank = node->GetCommRank();
      mpi::Status status;

      if ( size == 1 ) return data.ULVBackward();

      if ( rank == 0 )
      {
        /** Back substitute, then return the sibling block to rank size / 2. */
        data.ULVBackward();
        auto Br = data.Bsibling.toData();
        mpi::Send( Br.data(), Br.size(), size / 2, 0, comm );
      }

      if ( rank == size / 2 )
      {
        auto &Br = data.B;
        mpi::Recv( Br.data(), Br.size(), 0, 0, comm, &status );
      }
      //printf( "[END] level-%lu\n", arg->l ); fflush( stdout );
    };
};


/** Compute the distributed ULV factorization of K + lambda * I. */
template<typename T, typename TREE>
void DistFactorize( TREE &tree, T lambda )
{
  using NODE = typename TREE::NODE;
  using MPINODE = typename TREE::MPINODE;

  /** Regularization parameter and factorization variant. */
  tree.setup.lambda = lambda;
  tree.setup.do_ulv_factorization = true;

  gofmm::SetupFactorTask<NODE, T> seqSETUPFACTORtask;
  DistSetupFactorTask<MPINODE, T> parSETUPFACTORtask;

  mpi::PrintProgress( "[BEG] DistFactorize setup ...\n", tree.GetComm() );
  tree.DependencyCleanUp();
  tree.LocaTraverseUp( seqSETUPFACTORtask );
  tree.DistTraverseUp( parSETUPFACTORtask );
  tree.ExecuteAllTasks();
  mpi::PrintProgress( "[END] DistFactorize setup ...\n", tree.GetComm() );

  mpi::PrintProgress( "[BEG] DistFactorize ...\n", tree.GetComm() );
  gofmm::FactorizeTask<NODE, T> seqFACTORIZEtask;
  DistFactorizeTask<MPINODE, T> parFACTORIZEtask;
  tree.LocaTraverseUp( seqFACTORIZEtask );
  tree.DistTraverseUp( parFACTORIZEtask );
  tree.ExecuteAllTasks();
  mpi::PrintProgress( "[END] DistFactorize ...\n", tree.GetComm() );
};


/** Solve ( K + lambda * I ) x = input; the solution overwrites input. */
template<typename T, typename TREE>
void DistSolve( TREE &tree, Data<T> &input )
{
  using NODE = typename TREE::NODE;
  using MPINODE = typename TREE::MPINODE;

  /** The solve is performed in place. */
  tree.setup.input = &input;
  tree.setup.output = &input;

  gofmm::SolverTreeViewTask<NODE> seqTREEVIEWtask;
  DistFactorTreeViewTask<MPINODE, T> parTREEVIEWtask;
  gofmm::ULVForwardSolveTask<NODE, T> seqFORWARDtask;
  DistULVForwardSolveTask<MPINODE, T> parFORWARDtask;
  gofmm::ULVBackwardSolveTask<NODE, T> seqBACKWARDtask;
  DistULVBackwardSolveTask<MPINODE, T> parBACKWARDtask;

  tree.DistTraverseDown( parTREEVIEWtask );
  tree.LocaTraverseDown( seqTREEVIEWtask );
  tree.LocaTraverseUp( seqFORWARDtask );
  tree.DistTraverseUp( parFORWARDtask );
  tree.DistTraverseDown( parBACKWARDtask );
  tree.LocaTraverseDown( seqBACKWARDtask );
  mpi::PrintProgress( "[PREP] DistSolve ...\n", tree.GetComm() );
  tree.ExecuteAllTasks();
  mpi::PrintProgress( "[DONE] DistSolve ...\n", tree.GetComm() );
};
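
/**
 *  Sketch (illustrative, not part of the library): a typical call sequence
 *  once a distributed GOFMM tree has been compressed elsewhere. The
 *  factorization of K + lambda * I is computed once; DistSolve then
 *  overwrites rhs in place with ( K + lambda * I )^{-1} * rhs.
 *  ExampleFactorizeAndSolve is a hypothetical name.
 */
template<typename T, typename TREE>
void ExampleFactorizeAndSolve( TREE &tree, T lambda, Data<T> &rhs )
{
  DistFactorize( tree, lambda );
  DistSolve( tree, rhs );
};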


/**
 *  Accuracy check: assuming potentials = K * weights, solving the regularized
 *  system against potentials + lambda * weights should recover weights.
 */
template<typename TREE, typename T>
void ComputeError( TREE &tree, T lambda, Data<T> weights, Data<T> potentials )
{
  using NODE = typename TREE::NODE;
  using MPINODE = typename TREE::MPINODE;

  auto comm = tree.GetComm();
  auto size = tree.GetCommSize();
  auto rank = tree.GetCommRank();

  size_t n = weights.row();
  size_t nrhs = weights.col();

  /** Right-hand side rhs = potentials + lambda * weights. */
  Data<T> rhs( n, nrhs );
  for ( size_t j = 0; j < nrhs; j ++ )
    for ( size_t i = 0; i < n; i ++ )
      rhs( i, j ) = potentials( i, j ) + lambda * weights( i, j );

  /** Solve in place; rhs should now approximate weights. */
  DistSolve( tree, rhs );

  if ( rank == 0 )
  {
    printf( "========================================================\n" );
    printf( "Inverse accuracy report\n" );
    printf( "========================================================\n" );
    printf( "#rhs, max err, @, min err, @, relative \n" );
    printf( "========================================================\n" );
  }

  /** Report at most ntest right-hand sides. */
  size_t ntest = 10;
  T total_err = 0.0;
  T total_nrm = 0.0;
  T total_max = 0.0;
  T total_min = 0.0;
  for ( size_t j = 0; j < std::min( nrhs, ntest ); j ++ )
  {
    T nrm2 = 0.0, err2 = 0.0;
    T max2 = 0.0, min2 = std::numeric_limits<T>::max();

    for ( size_t i = 0; i < n; i ++ )
    {
      T sse = rhs( i, j ) - weights( i, j );
      /** Guard against NaN in the solution. */
      assert( rhs( i, j ) == rhs( i, j ) );
      sse = sse * sse;
      nrm2 += weights( i, j ) * weights( i, j );
      err2 += sse;
      max2 = std::max( max2, sse );
      min2 = std::min( min2, sse );
    }

    mpi::Allreduce( &err2, &total_err, 1, MPI_SUM, comm );
    mpi::Allreduce( &nrm2, &total_nrm, 1, MPI_SUM, comm );
    mpi::Allreduce( &max2, &total_max, 1, MPI_MAX, comm );
    mpi::Allreduce( &min2, &total_min, 1, MPI_MIN, comm );

    total_err = std::sqrt( total_err / total_nrm );
    total_max = std::sqrt( total_max );
    total_min = std::sqrt( total_min );

    if ( rank == 0 )
    {
      printf( "%4lu, %3.1E, %7lu, %3.1E, %7lu, %3.1E\n",
          j, total_max, (size_t)0, total_min, (size_t)0, total_err );
    }
  }
  if ( rank == 0 )
  {
    printf( "========================================================\n" );
  }
};
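
/**
 *  For each tested right-hand side j, the report above prints the global
 *  relative error
 *
 *    sqrt( sum_i ( rhs( i, j ) - weights( i, j ) )^2 / sum_i weights( i, j )^2 ),
 *
 *  together with the square roots of the largest and smallest per-entry
 *  squared errors, all reduced across the communicator.
 */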


};
};

#endif