HMLP: High-performance Machine Learning Primitives
gofmm_mpi.hpp
22 #ifndef GOFMM_MPI_HPP
23 #define GOFMM_MPI_HPP
24 
26 #include <gofmm.hpp>
28 #include <tree_mpi.hpp>
29 #include <igofmm_mpi.hpp>
31 //#include <DistData.hpp>
33 using namespace std;
34 using namespace hmlp;
35 
36 
37 namespace hmlp
38 {
39 namespace mpigofmm
40 {
41 
42 
44 // * @brief This class does not have to inherit DistData, but it has to
45 // * inherit DistVirtualMatrix<T>.
46 // *
47 // */
48 //template<typename T>
49 //class DistSPDMatrix : public DistData<STAR, CBLK, T>
50 //{
51 // public:
52 //
53 // DistSPDMatrix( size_t m, size_t n, mpi::Comm comm ) :
54 // DistData<STAR, CBLK, T>( m, n, comm )
55 // {
56 // };
57 //
58 //
59 // /** ESSENTIAL: this is an abstract function */
60 // virtual T operator()( size_t i, size_t j, mpi::Comm comm )
61 // {
62 // T Kij = 0;
63 //
64 // /** MPI */
65 // int size, rank;
66 // hmlp::mpi::Comm_size( comm, &size );
67 // hmlp::mpi::Comm_rank( comm, &rank );
68 //
69 // std::vector<std::vector<size_t>> sendrids( size );
70 // std::vector<std::vector<size_t>> recvrids( size );
71 // std::vector<std::vector<size_t>> sendcids( size );
72 // std::vector<std::vector<size_t>> recvcids( size );
73 //
74 // /** request Kij from rank ( j % size ) */
75 // sendrids[ i % size ].push_back( i );
76 // sendcids[ j % size ].push_back( j );
77 //
78 // /** exchange ids */
79 // mpi::AlltoallVector( sendrids, recvrids, comm );
80 // mpi::AlltoallVector( sendcids, recvcids, comm );
81 //
82 // /** allocate buffer for data */
83 // std::vector<std::vector<T>> senddata( size );
84 // std::vector<std::vector<T>> recvdata( size );
85 //
86 // /** fetch subrows */
87 // for ( size_t p = 0; p < size; p ++ )
88 // {
89 // assert( recvrids[ p ].size() == recvcids[ p ].size() );
90 // for ( size_t j = 0; j < recvcids[ p ].size(); j ++ )
91 // {
92 // size_t rid = recvrids[ p ][ j ];
93 // size_t cid = recvcids[ p ][ j ];
94 // senddata[ p ].push_back( (*this)( rid, cid ) );
95 // }
96 // }
97 //
98 // /** exchange data */
99 // mpi::AlltoallVector( senddata, recvdata, comm );
100 //
101 // for ( size_t p = 0; p < size; p ++ )
102 // {
103 // assert( recvdata[ p ].size() <= 1 );
104 // if ( recvdata[ p ].size() ) Kij = recvdata[ p ][ 0 ];
105 // }
106 //
107 // return Kij;
108 // };
109 //
110 //
111 // /** ESSENTIAL: return a submatrix */
112 // virtual hmlp::Data<T> operator()
113 // ( std::vector<size_t> &imap, std::vector<size_t> &jmap, hmlp::mpi::Comm comm )
114 // {
115 // hmlp::Data<T> KIJ( imap.size(), jmap.size() );
116 //
117 // /** MPI */
118 // int size, rank;
119 // hmlp::mpi::Comm_size( comm, &size );
120 // hmlp::mpi::Comm_rank( comm, &rank );
121 //
122 //
123 //
124 // std::vector<std::vector<size_t>> jmapcids( size );
125 //
126 // std::vector<std::vector<size_t>> sendrids( size );
127 // std::vector<std::vector<size_t>> recvrids( size );
128 // std::vector<std::vector<size_t>> sendcids( size );
129 // std::vector<std::vector<size_t>> recvcids( size );
130 //
131 // /** request KIJ from rank ( j % size ) */
132 // for ( size_t j = 0; j < jmap.size(); j ++ )
133 // {
134 // size_t cid = jmap[ j ];
135 // sendcids[ cid % size ].push_back( cid );
136 // jmapcids[ cid % size ].push_back( j );
137 // }
138 //
139 // for ( size_t p = 0; p < size; p ++ )
140 // {
141 // if ( sendcids[ p ].size() ) sendrids[ p ] = imap;
142 // }
143 //
144 // /** exchange ids */
145 // mpi::AlltoallVector( sendrids, recvrids, comm );
146 // mpi::AlltoallVector( sendcids, recvcids, comm );
147 //
148 // /** allocate buffer for data */
149 // std::vector<hmlp::Data<T>> senddata( size );
150 // std::vector<hmlp::Data<T>> recvdata( size );
151 //
152 // /** fetch submatrix */
153 // for ( size_t p = 0; p < size; p ++ )
154 // {
155 // if ( recvcids[ p ].size() && recvrids[ p ].size() )
156 // {
157 // senddata[ p ] = (*this)( recvrids[ p ], recvcids[ p ] );
158 // }
159 // }
160 //
161 // /** exchange data */
162 // mpi::AlltoallVector( senddata, recvdata, comm );
163 //
164 // /** merging data */
165 // for ( size_t p = 0; p < size; p ++ )
166 // {
167 // assert( recvdata[ p ].size() == imap.size() * recvcids[ p ].size() );
168 // recvdata[ p ].resize( imap.size(), recvcids[ p ].size() );
169 // for ( size_t j = 0; j < recvcids[ p ].size(); j ++ )
170 // {
171 // for ( size_t i = 0; i < imap.size(); i ++ )
172 // {
173 // KIJ( i, jmapcids[ p ][ j ] ) = recvdata[ p ]( i, j );
174 // }
175 // }
176 // };
177 //
178 // return KIJ;
179 // };
180 //
181 //
182 //
183 //
184 //
185 // virtual hmlp::Data<T> operator()
186 // ( std::vector<int> &imap, std::vector<int> &jmap, hmlp::mpi::Comm comm )
187 // {
188 // printf( "operator() not implemented yet\n" );
189 // exit( 1 );
190 // };
191 //
192 //
193 //
194 // /** overload operator */
195 //
196 //
197 // private:
198 //
199 //}; /** end class DistSPDMatrix */
200 //
201 //
202 
203 
208 template<typename SPDMATRIX, typename SPLITTER, typename T>
209 class Setup : public mpitree::Setup<SPLITTER, T>,
210  public gofmm::Configuration<T>
211 {
212  public:
213 
215  void FromConfiguration( gofmm::Configuration<T> &config,
216  SPDMATRIX &K, SPLITTER &splitter,
217  DistData<STAR, CBLK, pair<T, size_t>>* NN_cblk )
218  {
219  this->CopyFrom( config );
220  this->K = &K;
221  this->splitter = splitter;
222  this->NN_cblk = NN_cblk;
223  };
224 
226  SPDMATRIX *K = NULL;
227 
229  Data<T> *w = NULL;
230  Data<T> *u = NULL;
231 
233  Data<T> *input = NULL;
234  Data<T> *output = NULL;
235 
237  T lambda = 0.0;
238 
240  //bool issymmetric = true;
241 
243  bool do_ulv_factorization = true;
244 
245 
246  private:
247 
248 };
258 template<typename NODE>
259 class DistTreeViewTask : public Task
260 {
261  public:
262 
263  NODE *arg = NULL;
264 
265  void Set( NODE *user_arg )
266  {
267  arg = user_arg;
268  name = string( "TreeView" );
269  label = to_string( arg->treelist_id );
270  cost = 1.0;
271  };
272 
274  void DependencyAnalysis() { arg->DependOnParent( this ); };
275 
276  void Execute( Worker* user_worker )
277  {
278  auto *node = arg;
279 
281  auto &w = *(node->setup->w);
282  auto &u = *(node->setup->u);
283 
285  auto &U = node->data.u_view;
286  auto &W = node->data.w_view;
287 
289  U.Set( u );
290  W.Set( w );
291 
293  if ( !node->isleaf && !node->child )
294  {
295  assert( node->lchild && node->rchild );
296  auto &UL = node->lchild->data.u_view;
297  auto &UR = node->rchild->data.u_view;
298  auto &WL = node->lchild->data.w_view;
299  auto &WR = node->rchild->data.w_view;
304  U.Partition2x1( UL,
305  UR, node->lchild->n, TOP );
306  W.Partition2x1( WL,
307  WR, node->lchild->n, TOP );
308  }
309  };
310 
311 };
322 template<typename T>
323 vector<vector<size_t>> DistMedianSplit( vector<T> &values, mpi::Comm comm )
324 {
325  int n = 0;
326  int num_points_owned = values.size();
328  mpi::Allreduce( &num_points_owned, &n, 1, MPI_SUM, comm );
329  T median = combinatorics::Select( n / 2, values, comm );
330 
331  vector<vector<size_t>> split( 2 );
332  vector<size_t> middle;
333 
334  if ( n == 0 ) return split;
335 
336  for ( size_t i = 0; i < values.size(); i ++ )
337  {
338  auto v = values[ i ];
339  if ( std::fabs( v - median ) < 1E-6 ) middle.push_back( i );
340  else if ( v < median ) split[ 0 ].push_back( i );
341  else split[ 1 ].push_back( i );
342  }
343 
344  int nmid = 0;
345  int nlhs = 0;
346  int nrhs = 0;
347  int num_mid_owned = middle.size();
348  int num_lhs_owned = split[ 0 ].size();
349  int num_rhs_owned = split[ 1 ].size();
350 
352  mpi::Allreduce( &num_mid_owned, &nmid, 1, MPI_SUM, comm );
353  mpi::Allreduce( &num_lhs_owned, &nlhs, 1, MPI_SUM, comm );
354  mpi::Allreduce( &num_rhs_owned, &nrhs, 1, MPI_SUM, comm );
355 
357  if ( nmid )
358  {
359  int nlhs_required, nrhs_required;
360 
361  if ( nlhs > nrhs )
362  {
363  nlhs_required = ( n - 1 ) / 2 + 1 - nlhs;
364  nrhs_required = nmid - nlhs_required;
365  }
366  else
367  {
368  nrhs_required = ( n - 1 ) / 2 + 1 - nrhs;
369  nlhs_required = nmid - nrhs_required;
370  }
371 
372  assert( nlhs_required >= 0 && nrhs_required >= 0 );
373 
375  double lhs_ratio = ( (double)nlhs_required ) / nmid;
376  int nlhs_required_owned = num_mid_owned * lhs_ratio;
377  int nrhs_required_owned = num_mid_owned - nlhs_required_owned;
378 
379  //printf( "rank %d [ %d %d ] [ %d %d ]\n",
380  // global_rank,
381  // nlhs_required_owned, nlhs_required,
382  // nrhs_required_owned, nrhs_required ); fflush( stdout );
383 
384  assert( nlhs_required_owned >= 0 && nrhs_required_owned >= 0 );
385 
386  for ( size_t i = 0; i < middle.size(); i ++ )
387  {
388  if ( i < nlhs_required_owned )
389  split[ 0 ].push_back( middle[ i ] );
390  else
391  split[ 1 ].push_back( middle[ i ] );
392  }
393  }
394 
395  return split;
396 };
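/**
 *  A minimal usage sketch for DistMedianSplit (the values and communicator
 *  below are illustrative, not from this file): every rank passes the values
 *  it owns locally, and the returned index sets partition those LOCAL
 *  indices according to the GLOBAL median.
 *
 *  @code
 *    std::vector<double> values = { 0.3, 1.7, 0.9 };        // locally owned
 *    auto split = DistMedianSplit( values, MPI_COMM_WORLD );
 *    // split[ 0 ]: local indices whose value falls below the global median
 *    // split[ 1 ]: local indices whose value falls above it; ties on the
 *    //             median are distributed across both sides to balance them.
 *  @endcode
 */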
407 template<typename SPDMATRIX, int N_SPLIT, typename T>
408 struct centersplit : public gofmm::centersplit<SPDMATRIX, N_SPLIT, T>
409 {
410 
412 
413  centersplit( SPDMATRIX& K ) : gofmm::centersplit<SPDMATRIX, N_SPLIT, T>( K ) {};
414 
416  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
417  {
418  return gofmm::centersplit<SPDMATRIX, N_SPLIT, T>::operator() ( gids );
419  };
420 
422  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
423  {
425  assert( N_SPLIT == 2 );
426  assert( this->Kptr );
427 
429  int size; mpi::Comm_size( comm, &size );
430  int rank; mpi::Comm_rank( comm, &rank );
431  auto &K = *(this->Kptr);
432 
434  vector<T> temp( gids.size(), 0.0 );
435 
437  auto column_samples = combinatorics::SampleWithoutReplacement(
438  this->n_centroid_samples, gids );
439 
441  mpi::Bcast( column_samples.data(), column_samples.size(), 0, comm );
442  K.BcastIndices( column_samples, 0, comm );
443 
445  auto DIC = K.Distances( this->metric, gids, column_samples );
446 
448  for ( auto & it : temp ) it = 0;
449 
451  for ( size_t j = 0; j < DIC.col(); j ++ )
452  for ( size_t i = 0; i < DIC.row(); i ++ )
453  temp[ i ] += DIC( i, j );
454 
456  auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );
457 
459  mpi::NumberIntPair<T> local_max_pair, max_pair;
460  local_max_pair.val = temp[ idf2c ];
461  local_max_pair.key = rank;
462 
464  mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
465 
467  int gidf2c = gids[ idf2c ];
468  mpi::Bcast( &gidf2c, 1, MPI_INT, max_pair.key, comm );
469 
470 
471  //printf( "rank %d val %E key %d; global val %E key %d\n",
472  // rank, local_max_pair.val, local_max_pair.key,
473  // max_pair.val, max_pair.key ); fflush( stdout );
474  //printf( "rank %d gidf2c %d\n", rank, gidf2c ); fflush( stdout );
475 
477  vector<size_t> P( 1, gidf2c );
478  K.BcastIndices( P, max_pair.key, comm );
479 
481  auto DIP = K.Distances( this->metric, gids, P );
482 
484  auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );
485 
487  local_max_pair.val = DIP[ idf2f ];
488  local_max_pair.key = rank;
489 
491  mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
492 
494  int gidf2f = gids[ idf2f ];
495  mpi::Bcast( &gidf2f, 1, MPI_INT, max_pair.key, comm );
496 
497  //printf( "rank %d val %E key %d; global val %E key %d\n",
498  // rank, local_max_pair.val, local_max_pair.key,
499  // max_pair.val, max_pair.key ); fflush( stdout );
500  //printf( "rank %d gidf2f %d\n", rank, gidf2f ); fflush( stdout );
501 
503  vector<size_t> Q( 1, gidf2f );
504  K.BcastIndices( Q, max_pair.key, comm );
505 
507  auto DIQ = K.Distances( this->metric, gids, Q );
508 
510  for ( size_t i = 0; i < temp.size(); i ++ )
511  temp[ i ] = DIP[ i ] - DIQ[ i ];
512 
514  auto split = DistMedianSplit( temp, comm );
515 
517  mpi::Status status;
518  vector<size_t> sent_gids;
519  int partner = ( rank + size / 2 ) % size;
520  if ( rank < size / 2 )
521  {
522  for ( auto it : split[ 1 ] )
523  sent_gids.push_back( gids[ it ] );
524  K.SendIndices( sent_gids, partner, comm );
525  K.RecvIndices( partner, comm, &status );
526  }
527  else
528  {
529  for ( auto it : split[ 0 ] )
530  sent_gids.push_back( gids[ it ] );
531  K.RecvIndices( partner, comm, &status );
532  K.SendIndices( sent_gids, partner, comm );
533  }
534 
535  return split;
536  };
537 
538 
539 };
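/**
 *  A sketch of how this splitter is typically attached (MYMATRIX is a
 *  placeholder name, not from this file): the matrix type must expose the
 *  Distances(), BcastIndices(), SendIndices(), and RecvIndices() interface
 *  used above, i.e. the DistVirtualMatrix-style API.
 *
 *  @code
 *    MYMATRIX K( n, n, MPI_COMM_WORLD );
 *    mpigofmm::centersplit<MYMATRIX, 2, double> splitter( K );
 *    auto split = splitter( gids, MPI_COMM_WORLD );   // distributed split
 *  @endcode
 */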
545 template<typename SPDMATRIX, int N_SPLIT, typename T>
546 struct randomsplit : public gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>
547 {
548 
550 
551  randomsplit( SPDMATRIX& K ) : gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>( K ) {};
552 
554  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
555  {
556  return gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>::operator() ( gids );
557  };
558 
560  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
561  {
563  assert( N_SPLIT == 2 );
564  assert( this->Kptr );
565 
567  int size, rank, global_rank, global_size;
568  mpi::Comm_size( comm, &size );
569  mpi::Comm_rank( comm, &rank );
570  mpi::Comm_rank( MPI_COMM_WORLD, &global_rank );
571  mpi::Comm_size( MPI_COMM_WORLD, &global_size );
572  SPDMATRIX &K = *(this->Kptr);
573  //vector<vector<size_t>> split( N_SPLIT );
574 
575  if ( size == global_size )
576  {
577  for ( size_t i = 0; i < gids.size(); i ++ )
578  assert( gids[ i ] == i * size + rank );
579  }
580 
581 
582 
583 
585  int n = 0;
586  int num_points_owned = gids.size();
587  vector<T> temp( gids.size(), 0.0 );
588 
590  mpi::Allreduce( &num_points_owned, &n, 1, MPI_INT, MPI_SUM, comm );
591 
593  //if ( n == 0 ) return split;
594 
596  size_t gidf2c = 0, gidf2f = 0;
597  if ( gids.size() )
598  {
599  gidf2c = gids[ std::rand() % gids.size() ];
600  gidf2f = gids[ std::rand() % gids.size() ];
601  }
602 
604  mpi::NumberIntPair<T> local_max_pair, max_pair;
605  local_max_pair.val = gids.size();
606  local_max_pair.key = rank;
607 
609  mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
610 
612  mpi::Bcast( &gidf2c, 1, max_pair.key, comm );
613  vector<size_t> P( 1, gidf2c );
614  K.BcastIndices( P, max_pair.key, comm );
615 
617  if ( rank == max_pair.key ) local_max_pair.val = 0;
618 
620  mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
621 
623  mpi::Bcast( &gidf2f, 1, max_pair.key, comm );
624  vector<size_t> Q( 1, gidf2f );
625  K.BcastIndices( Q, max_pair.key, comm );
626 
627 
628  auto DIP = K.Distances( this->metric, gids, P );
629  auto DIQ = K.Distances( this->metric, gids, Q );
630 
632  for ( size_t i = 0; i < temp.size(); i ++ )
633  temp[ i ] = DIP[ i ] - DIQ[ i ];
634 
636  auto split = DistMedianSplit( temp, comm );
637 
639  mpi::Status status;
640  vector<size_t> sent_gids;
641  int partner = ( rank + size / 2 ) % size;
642  if ( rank < size / 2 )
643  {
644  for ( auto it : split[ 1 ] )
645  sent_gids.push_back( gids[ it ] );
646  K.SendIndices( sent_gids, partner, comm );
647  K.RecvIndices( partner, comm, &status );
648  }
649  else
650  {
651  for ( auto it : split[ 0 ] )
652  sent_gids.push_back( gids[ it ] );
653  K.RecvIndices( partner, comm, &status );
654  K.SendIndices( sent_gids, partner, comm );
655  }
656 
657  return split;
658  };
659 
660 
661 };
688 template<typename NODE>
689 void DistUpdateWeights( NODE *node )
690 {
692  using T = typename NODE::T;
694  mpi::Status status;
695  auto comm = node->GetComm();
696  int size = node->GetCommSize();
697  int rank = node->GetCommRank();
698 
700  if ( !node->parent || !node->data.isskel ) return;
701 
702  if ( size < 2 )
703  {
705  gofmm::UpdateWeights( node );
706  }
707  else
708  {
710  auto &w = *node->setup->w;
711  size_t nrhs = w.col();
712 
714  auto &data = node->data;
715  auto &proj = data.proj;
716  auto &w_skel = data.w_skel;
717 
719  if ( rank == 0 )
720  {
721  size_t s = proj.row();
722  size_t sl = node->child->data.skels.size();
723  size_t sr = proj.col() - sl;
725  w_skel.resize( s, nrhs );
727  View<T> P( false, proj ), PL, PR;
728  View<T> W( false, w_skel ), WL( false, node->child->data.w_skel );
730  P.Partition1x2( PL, PR, sl, LEFT );
732  gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );
733 
734  Data<T> w_skel_sib;
735  mpi::ExchangeVector( w_skel, size / 2, 0, w_skel_sib, size / 2, 0, comm, &status );
737  #pragma omp parallel for
738  for ( size_t i = 0; i < w_skel.size(); i ++ )
739  w_skel[ i ] += w_skel_sib[ i ];
740  }
741 
743  if ( rank == size / 2 )
744  {
745  size_t s = proj.row();
746  size_t sr = node->child->data.skels.size();
747  size_t sl = proj.col() - sr;
749  w_skel.resize( s, nrhs );
751  View<T> P( false, proj ), PL, PR;
752  View<T> W( false, w_skel ), WR( false, node->child->data.w_skel );
754  P.Partition1x2( PL, PR, sl, LEFT );
756  gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)0.0, W );
757 
758 
759  Data<T> w_skel_sib;
760  mpi::ExchangeVector( w_skel, 0, 0, w_skel_sib, 0, 0, comm, &status );
761  w_skel.clear();
762  }
763  }
764 };
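/**
 *  The update computed above is, in matrix form (restating what the two
 *  ranks do cooperatively):
 *
 *  @f[
 *    \widetilde{w} = P \begin{bmatrix} \widetilde{w}_l \\ \widetilde{w}_r \end{bmatrix}
 *                  = P_l \widetilde{w}_l + P_r \widetilde{w}_r,
 *  @f]
 *
 *  where rank 0 computes the left partial product, rank size/2 computes the
 *  right one, and the two partials are exchanged and summed so that rank 0
 *  ends up holding the full skeleton weights of the distributed node.
 */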
772 template<typename NODE, typename T>
773 class DistUpdateWeightsTask : public Task
774 {
775  public:
776 
777  NODE *arg = NULL;
778 
779  void Set( NODE *user_arg )
780  {
781  arg = user_arg;
782  name = string( "DistN2S" );
783  label = to_string( arg->treelist_id );
784 
786  double flops = 0.0, mops = 0.0;
787  auto &gids = arg->gids;
788  auto &skels = arg->data.skels;
789  auto &w = *arg->setup->w;
790 
791  if ( !arg->child )
792  {
793  if ( arg->isleaf )
794  {
795  auto m = skels.size();
796  auto n = w.col();
797  auto k = gids.size();
798  flops = 2.0 * m * n * k;
799  mops = 2.0 * ( m * n + m * k + k * n );
800  }
801  else
802  {
803  auto &lskels = arg->lchild->data.skels;
804  auto &rskels = arg->rchild->data.skels;
805  auto m = skels.size();
806  auto n = w.col();
807  auto k = lskels.size() + rskels.size();
808  flops = 2.0 * m * n * k;
809  mops = 2.0 * ( m * n + m * k + k * n );
810  }
811  }
812  else
813  {
814  if ( arg->GetCommRank() == 0 )
815  {
816  auto &lskels = arg->child->data.skels;
817  auto m = skels.size();
818  auto n = w.col();
819  auto k = lskels.size();
820  flops = 2.0 * m * n * k;
821  mops = 2.0 * ( m * n + m * k + k * n );
822  }
823  if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
824  {
825  auto &rskels = arg->child->data.skels;
826  auto m = skels.size();
827  auto n = w.col();
828  auto k = rskels.size();
829  flops = 2.0 * m * n * k;
830  mops = 2.0 * ( m * n + m * k + k * n );
831  }
832  }
833 
835  event.Set( label + name, flops, mops );
837  cost = flops / 1E+9;
839  priority = true;
840  };
841 
842  void DependencyAnalysis() { arg->DependOnChildren( this ); };
843 
844  void Execute( Worker* user_worker ) { DistUpdateWeights( arg ); };
845 
846 };
854 //template<bool NNPRUNE, typename NODE, typename T>
855 //class DistSkeletonsToSkeletonsTask : public Task
856 //{
857 // public:
858 //
859 // NODE *arg = NULL;
860 //
861 // void Set( NODE *user_arg )
862 // {
863 // arg = user_arg;
864 // name = string( "DistS2S" );
865 // label = to_string( arg->treelist_id );
866 // /** compute flops and mops */
867 // double flops = 0.0, mops = 0.0;
868 // auto &w = *arg->setup->w;
869 // size_t m = arg->data.skels.size();
870 // size_t n = w.col();
871 //
872 // auto *FarNodes = &arg->FarNodes;
873 // if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
874 //
875 // for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
876 // {
877 // size_t k = (*it)->data.skels.size();
878 // flops += 2.0 * m * n * k;
879 // mops += m * k; // cost of Kab
880 // mops += 2.0 * ( m * n + n * k + k * n );
881 // }
882 //
883 // /** setup the event */
884 // event.Set( label + name, flops, mops );
885 //
886 // /** assume computation bound */
887 // cost = flops / 1E+9;
888 //
889 // /** "LOW" priority */
890 // priority = false;
891 // };
892 //
893 //
894 //
895 // void DependencyAnalysis()
896 // {
897 // for ( auto p : arg->data.FarDependents )
898 // hmlp_msg_dependency_analysis( 306, p, R, this );
899 //
900 // auto *FarNodes = &arg->FarNodes;
901 // if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
902 // for ( auto it : *FarNodes ) it->DependencyAnalysis( R, this );
903 //
904 // arg->DependencyAnalysis( RW, this );
905 // this->TryEnqueue();
906 // };
907 //
908 // /**
909 // * @brief Notice that S2S depends on all Far interactions, which
910 // * may include local tree nodes or let nodes.
911 // * For HSS case, the only Far interaction is the sibling.
912 // * Skeleton weight of the sibling will always be exchanged
913 // * by default in N2S. Thus, currently we do not need
914 // * a distributed S2S, because the skeleton weight is already
915 // * in place.
916 // *
917 // */
918 // void Execute( Worker* user_worker )
919 // {
920 // auto *node = arg;
921 // /** MPI Support. */
922 // auto comm = node->GetComm();
923 // auto size = node->GetCommSize();
924 // auto rank = node->GetCommRank();
925 //
926 // if ( size < 2 )
927 // {
928 // gofmm::SkeletonsToSkeletons<NNPRUNE, NODE, T>( node );
929 // }
930 // else
931 // {
932 // /** Only 0th rank (owner) will execute this task. */
933 // if ( rank == 0 ) gofmm::SkeletonsToSkeletons<NNPRUNE, NODE, T>( node );
934 // }
935 // };
936 //
937 //}; /** end class DistSkeletonsToSkeletonsTask */
938 //
939 
940 template<typename NODE, typename LETNODE, typename T>
941 class S2STask2 : public Task
942 {
943  public:
944 
945  NODE *arg = NULL;
946 
947  vector<LETNODE*> Sources;
948 
949  int p = 0;
950 
951  Lock *lock = NULL;
952 
953  int *num_arrived_subtasks = NULL;
954 
955  void Set( NODE *user_arg, vector<LETNODE*> user_src, int user_p, Lock *user_lock,
956  int *user_num_arrived_subtasks )
957  {
958  arg = user_arg;
959  Sources = user_src;
960  p = user_p;
961  lock = user_lock;
962  num_arrived_subtasks = user_num_arrived_subtasks;
963  name = string( "S2S" );
964  label = to_string( arg->treelist_id );
965 
967  double flops = 0.0, mops = 0.0;
968  size_t nrhs = arg->setup->w->col();
969  size_t m = arg->data.skels.size();
970  for ( auto src : Sources )
971  {
972  size_t k = src->data.skels.size();
973  flops += 2 * m * k * nrhs;
974  mops += 2 * ( m * k + ( m + k ) * nrhs );
975  flops += 2 * m * nrhs;
976  flops += m * k * ( 2 * 18 + 100 );
977  }
979  event.Set( label + name, flops, mops );
981  cost = flops / 1E+9;
983  if ( arg->treelist_id == 0 ) priority = true;
984  };
985 
986  void DependencyAnalysis()
987  {
988  if ( p == hmlp_get_mpi_rank() )
989  {
990  for ( auto src : Sources ) src->DependencyAnalysis( R, this );
991  }
992  else hmlp_msg_dependency_analysis( 306, p, R, this );
993  this->TryEnqueue();
994  };
995 
996  void Execute( Worker* user_worker )
997  {
998  auto *node = arg;
999  if ( !node->parent || !node->data.isskel ) return;
1000  size_t nrhs = node->setup->w->col();
1001  auto &K = *node->setup->K;
1002  auto &I = node->data.skels;
1003 
1005  Data<T> u( I.size(), nrhs, 0.0 );
1006 
1007  for ( auto src : Sources )
1008  {
1009  auto &J = src->data.skels;
1010  auto &w = src->data.w_skel;
1011  bool is_cached = true;
1012 
1013  auto &KIJ = node->DistFar[ p ][ src->morton ];
1014  if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
1015  {
1016  //printf( "KIJ %lu %lu I %lu J %lu\n", KIJ.row(), KIJ.col(), I.size(), J.size() );
1017  KIJ = K( I, J );
1018  is_cached = false;
1019  }
1020 
1021  assert( w.col() == nrhs );
1022  assert( w.row() == J.size() );
1023  //xgemm
1024  //(
1025  // "N", "N", u.row(), u.col(), w.row(),
1026  // 1.0, KIJ.data(), KIJ.row(),
1027  // w.data(), w.row(),
1028  // 1.0, u.data(), u.row()
1029  //);
1030  gemm::xgemm( (T)1.0, KIJ, w, (T)1.0, u );
1031 
1033  if ( !is_cached )
1034  {
1035  KIJ.resize( 0, 0 );
1036  KIJ.shrink_to_fit();
1037  }
1038  }
1039 
1040  lock->Acquire();
1041  {
1042  auto &u_skel = node->data.u_skel;
1043  for ( int i = 0; i < u.size(); i ++ )
1044  u_skel[ i ] += u[ i ];
1045  }
1046  lock->Release();
1047  #pragma omp atomic update
1048  *num_arrived_subtasks += 1;
1049  };
1050 };
1051 
1052 template<typename NODE, typename LETNODE, typename T>
1053 class S2SReduceTask2 : public Task
1054 {
1055  public:
1056 
1057  NODE *arg = NULL;
1058 
1059  vector<S2STask2<NODE, LETNODE, T>*> subtasks;
1060 
1061  Lock lock;
1062 
1063  int num_arrived_subtasks = 0;
1064 
1065  const size_t batch_size = 2;
1066 
1067  void Set( NODE *user_arg )
1068  {
1069  arg = user_arg;
1070  name = string( "S2SR" );
1071  label = to_string( arg->treelist_id );
1072 
1074  if ( arg )
1075  {
1076  size_t nrhs = arg->setup->w->col();
1077  auto &I = arg->data.skels;
1078  arg->data.u_skel.resize( 0, 0 );
1079  arg->data.u_skel.resize( I.size(), nrhs, 0 );
1080  }
1081 
1083  for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
1084  {
1085  vector<LETNODE*> Sources;
1086  for ( auto &it : arg->DistFar[ p ] )
1087  {
1088  Sources.push_back( (*arg->morton2node)[ it.first ] );
1089  if ( Sources.size() == batch_size )
1090  {
1091  subtasks.push_back( new S2STask2<NODE, LETNODE, T>() );
1092  subtasks.back()->Submit();
1093  subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1094  subtasks.back()->DependencyAnalysis();
1095  Sources.clear();
1096  }
1097  }
1098  if ( Sources.size() )
1099  {
1100  subtasks.push_back( new S2STask2<NODE, LETNODE, T>() );
1101  subtasks.back()->Submit();
1102  subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1103  subtasks.back()->DependencyAnalysis();
1104  Sources.clear();
1105  }
1106  }
1108  double flops = 0, mops = 0;
1110  event.Set( label + name, flops, mops );
1112  priority = true;
1113  };
1114 
1115  void DependencyAnalysis()
1116  {
1117  for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
1118  arg->DependencyAnalysis( RW, this );
1119  this->TryEnqueue();
1120  };
1121 
1122  void Execute( Worker* user_worker )
1123  {
1125  assert( num_arrived_subtasks == subtasks.size() );
1126  };
1127 };
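/**
 *  Design note: S2SReduceTask2 performs no GEMM itself.  In Set() it groups
 *  the far interactions owned by each rank into batches of batch_size,
 *  spawns one S2STask2 per batch, and in Execute() only asserts that every
 *  subtask has checked in.  Batching keeps the scheduler from drowning in
 *  tiny tasks, while the shared lock still serializes the accumulation of
 *  each batch's partial result into u_skel.
 */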
1128 
1129 
1130 
1131 
1132 
1133 
1134 
1135 
1136 
1137 
1138 
1139 
1140 
1141 
1142 
1143 
1144 
1145 
1146 
1147 
1148 template<bool NNPRUNE, typename NODE, typename T>
1149 void DistSkeletonsToNodes( NODE *node )
1150 {
1152  auto comm = node->GetComm();
1153  auto size = node->GetCommSize();
1154  auto rank = node->GetCommRank();
1155  mpi::Status status;
1156 
1158  auto &K = *node->setup->K;
1159  auto &w = *node->setup->w;
1160 
1161 
1162  size_t nrhs = w.col();
1163 
1164 
1166  if ( !node->parent || !node->data.isskel ) return;
1167 
1168  if ( size < 2 )
1169  {
1171  gofmm::SkeletonsToNodes( node );
1172  }
1173  else
1174  {
1175  auto &data = node->data;
1176  auto &proj = data.proj;
1177  auto &u_skel = data.u_skel;
1178 
1179  if ( rank == 0 )
1180  {
1181  size_t sl = node->child->data.skels.size();
1182  size_t sr = proj.col() - sl;
1184  mpi::SendVector( u_skel, size / 2, 0, comm );
1186  View<T> P( true, proj ), PL, PR;
1187  View<T> U( false, u_skel ), UL( false, node->child->data.u_skel );
1189  P.Partition2x1( PL,
1190  PR, sl, TOP );
1192  gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
1193  }
1194 
1196  if ( rank == size / 2 )
1197  {
1198  size_t s = proj.row();
1199  size_t sr = node->child->data.skels.size();
1200  size_t sl = proj.col() - sr;
1202  mpi::RecvVector( u_skel, 0, 0, comm, &status );
1203  u_skel.resize( s, nrhs );
1205  View<T> P( true, proj ), PL, PR;
1206  View<T> U( false, u_skel ), UR( false, node->child->data.u_skel );
1208  P.Partition2x1( PL,
1209  PR, sl, TOP );
1211  gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );
1212  }
1213  }
1214 };
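/**
 *  This is the transpose of DistUpdateWeights(): the skeleton potentials of
 *  the distributed node are scattered back to its two children as
 *
 *  @f[
 *    \widetilde{u}_l \mathrel{+}= P_l^{T} \widetilde{u}, \qquad
 *    \widetilde{u}_r \mathrel{+}= P_r^{T} \widetilde{u},
 *  @f]
 *
 *  with rank 0 applying the left block and rank size/2, after receiving
 *  u_skel from rank 0, applying the right block.
 */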
1220 template<bool NNPRUNE, typename NODE, typename T>
1221 class DistSkeletonsToNodesTask : public Task
1222 {
1223  public:
1224 
1225  NODE *arg;
1226 
1227  void Set( NODE *user_arg )
1228  {
1229  arg = user_arg;
1230  name = string( "PS2N" );
1231  label = to_string( arg->l );
1232 
1233  double flops = 0.0, mops = 0.0;
1234  auto &gids = arg->gids;
1235  auto &skels = arg->data.skels;
1236  auto &w = *arg->setup->w;
1237 
1238  if ( !arg->child )
1239  {
1240  if ( arg->isleaf )
1241  {
1242  auto m = skels.size();
1243  auto n = w.col();
1244  auto k = gids.size();
1245  flops = 2.0 * m * n * k;
1246  mops = 2.0 * ( m * n + m * k + k * n );
1247  }
1248  else
1249  {
1250  auto &lskels = arg->lchild->data.skels;
1251  auto &rskels = arg->rchild->data.skels;
1252  auto m = skels.size();
1253  auto n = w.col();
1254  auto k = lskels.size() + rskels.size();
1255  flops = 2.0 * m * n * k;
1256  mops = 2.0 * ( m * n + m * k + k * n );
1257  }
1258  }
1259  else
1260  {
1261  if ( arg->GetCommRank() == 0 )
1262  {
1263  auto &lskels = arg->child->data.skels;
1264  auto m = skels.size();
1265  auto n = w.col();
1266  auto k = lskels.size();
1267  flops = 2.0 * m * n * k;
1268  mops = 2.0 * ( m * n + m * k + k * n );
1269  }
1270  if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
1271  {
1272  auto &rskels = arg->child->data.skels;
1273  auto m = skels.size();
1274  auto n = w.col();
1275  auto k = rskels.size();
1276  flops = 2.0 * m * n * k;
1277  mops = 2.0 * ( m * n + m * k + k * n );
1278  }
1279  }
1280 
1282  event.Set( label + name, flops, mops );
1284  cost = flops / 1E+9;
1286  priority = true;
1287  };
1288 
1289  void DependencyAnalysis() { arg->DependOnParent( this ); };
1290 
1291  void Execute( Worker* user_worker ) { DistSkeletonsToNodes<NNPRUNE, NODE, T>( arg ); };
1292 
1293 };
1297 template<typename NODE, typename T>
1298 class L2LTask2 : public Task
1299 {
1300  public:
1301 
1302  NODE *arg = NULL;
1303 
1305  vector<NODE*> Sources;
1306 
1307  int p = 0;
1308 
1310  Lock *lock = NULL;
1311 
1312  int *num_arrived_subtasks = NULL;
1313 
1314  void Set( NODE *user_arg, vector<NODE*> user_src, int user_p, Lock *user_lock,
1315  int* user_num_arrived_subtasks )
1316  {
1317  arg = user_arg;
1318  Sources = user_src;
1319  p = user_p;
1320  lock = user_lock;
1321  num_arrived_subtasks = user_num_arrived_subtasks;
1322  name = string( "L2L" );
1323  label = to_string( arg->treelist_id );
1324 
1326  double flops = 0.0, mops = 0.0;
1327  size_t nrhs = arg->setup->w->col();
1328  size_t m = arg->gids.size();
1329  for ( auto src : Sources )
1330  {
1331  size_t k = src->gids.size();
1332  flops += 2 * m * k * nrhs;
1333  mops += 2 * ( m * k + ( m + k ) * nrhs );
1334  flops += 2 * m * nrhs;
1335  flops += m * k * ( 2 * 18 + 100 );
1336  }
1338  event.Set( label + name, flops, mops );
1340  cost = flops / 1E+9;
1342  priority = false;
1343  };
1344 
1345  void DependencyAnalysis()
1346  {
1348  if ( p != hmlp_get_mpi_rank() )
1349  hmlp_msg_dependency_analysis( 300, p, R, this );
1350  this->TryEnqueue();
1351  };
1352 
1353  void Execute( Worker* user_worker )
1354  {
1355  auto *node = arg;
1356  size_t nrhs = node->setup->w->col();
1357  auto &K = *node->setup->K;
1358  auto &I = node->gids;
1359 
1360  double beg = omp_get_wtime();
1362  Data<T> u( I.size(), nrhs, 0.0 );
1363  size_t k = 0;
1364 
1365  for ( auto src : Sources )
1366  {
1368  View<T> &W = src->data.w_view;
1369  Data<T> &w = src->data.w_leaf;
1370 
1371  bool is_cached = true;
1372  auto &J = src->gids;
1373  auto &KIJ = node->DistNear[ p ][ src->morton ];
1374  if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
1375  {
1376  KIJ = K( I, J );
1377  is_cached = false;
1378  }
1379 
1380  if ( W.col() == nrhs && W.row() == J.size() )
1381  {
1382  k += W.row();
1383  xgemm
1384  (
1385  "N", "N", u.row(), u.col(), W.row(),
1386  1.0, KIJ.data(), KIJ.row(),
1387  W.data(), W.ld(),
1388  1.0, u.data(), u.row()
1389  );
1390  }
1391  else
1392  {
1393  k += w.row();
1394  xgemm
1395  (
1396  "N", "N", u.row(), u.col(), w.row(),
1397  1.0, KIJ.data(), KIJ.row(),
1398  w.data(), w.row(),
1399  1.0, u.data(), u.row()
1400  );
1401  }
1402 
1404  if ( !is_cached )
1405  {
1406  KIJ.resize( 0, 0 );
1407  KIJ.shrink_to_fit();
1408  }
1409  }
1410 
1411  double lock_beg = omp_get_wtime();
1412  lock->Acquire();
1413  {
1415  View<T> &U = node->data.u_view;
1416  for ( int j = 0; j < u.col(); j ++ )
1417  for ( int i = 0; i < u.row(); i ++ )
1418  U( i, j ) += u( i, j );
1419  }
1420  lock->Release();
1421  double lock_time = omp_get_wtime() - lock_beg;
1422 
1423  double gemm_time = omp_get_wtime() - beg;
1424  double GFLOPS = 2.0 * u.row() * u.col() * k / ( 1E+9 * gemm_time );
1425  //printf( "GEMM %4lu %4lu %4lu %lf GFLOPS, lock(%lf/%lf)\n",
1426  // u.row(), u.col(), k, GFLOPS, lock_time, gemm_time ); fflush( stdout );
1427  #pragma omp atomic update
1428  *num_arrived_subtasks += 1;
1429  };
1430 };
1431 
1432 
1433 
1434 
1435 template<typename NODE, typename T>
1436 class L2LReduceTask2 : public Task
1437 {
1438  public:
1439 
1440  NODE *arg = NULL;
1441 
1442  vector<L2LTask2<NODE, T>*> subtasks;
1443 
1444  Lock lock;
1445 
1446  int num_arrived_subtasks = 0;
1447 
1448  const size_t batch_size = 2;
1449 
1450  void Set( NODE *user_arg )
1451  {
1452  arg = user_arg;
1453  name = string( "L2LR" );
1454  label = to_string( arg->treelist_id );
1456  for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
1457  {
1458  vector<NODE*> Sources;
1459  for ( auto &it : arg->DistNear[ p ] )
1460  {
1461  Sources.push_back( (*arg->morton2node)[ it.first ] );
1462  if ( Sources.size() == batch_size )
1463  {
1464  subtasks.push_back( new L2LTask2<NODE, T>() );
1465  subtasks.back()->Submit();
1466  subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1467  subtasks.back()->DependencyAnalysis();
1468  Sources.clear();
1469  }
1470  }
1471  if ( Sources.size() )
1472  {
1473  subtasks.push_back( new L2LTask2<NODE, T>() );
1474  subtasks.back()->Submit();
1475  subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1476  subtasks.back()->DependencyAnalysis();
1477  Sources.clear();
1478  }
1479  }
1480 
1481 
1482 
1483 
1485  double flops = 0, mops = 0;
1487  event.Set( label + name, flops, mops );
1489  priority = false;
1490  };
1491 
1492  void DependencyAnalysis()
1493  {
1494  for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
1495  arg->DependencyAnalysis( RW, this );
1496  this->TryEnqueue();
1497  };
1498 
1499  void Execute( Worker* user_worker )
1500  {
1501  assert( num_arrived_subtasks == subtasks.size() );
1502  };
1503 };
1504 
1505 
1506 
1507 
1508 
1509 
1510 
1511 
1512 
1513 
1514 
1515 
1516 
1517 
1518 
1519 
1520 
1537 template<typename TREE>
1538 void FindNearInteractions( TREE &tree )
1539 {
1540  mpi::PrintProgress( "[BEG] FindNearInteractions ...", tree.GetComm() );
1542  using NODE = typename TREE::NODE;
1543  auto &setup = tree.setup;
1544  auto &NN = *setup.NN;
1545  double budget = setup.Budget();
1546  size_t n_leafs = ( 1 << tree.depth );
1555  auto level_beg = tree.treelist.begin() + n_leafs - 1;
1556 
1558  #pragma omp parallel for
1559  for ( size_t node_ind = 0; node_ind < n_leafs; node_ind ++ )
1560  {
1561  auto *node = *(level_beg + node_ind);
1562  auto &data = node->data;
1563  size_t n_nodes = ( 1 << node->l );
1564 
1566  node->NNNearNodes.insert( node );
1567  node->NNNearNodeMortonIDs.insert( node->morton );
1568 
1570  multimap<size_t, size_t> sorted_ballot = gofmm::NearNodeBallots( node );
1571 
1573  for ( auto it = sorted_ballot.rbegin();
1574  it != sorted_ballot.rend(); it ++ )
1575  {
1577  if ( node->NNNearNodes.size() >= n_nodes * budget ) break;
1578 
1585  #pragma omp critical
1586  {
1587  if ( !(*node->morton2node).count( (*it).second ) )
1588  {
1590  (*node->morton2node)[ (*it).second ] = new NODE( (*it).second );
1591  }
1593  auto *target = (*node->morton2node)[ (*it).second ];
1594  node->NNNearNodeMortonIDs.insert( (*it).second );
1595  node->NNNearNodes.insert( target );
1596  }
1597  }
1598  }
1599  mpi::PrintProgress( "[END] Finish FindNearInteractions ...", tree.GetComm() );
1600 };
1605 template<typename NODE>
1606 void FindFarNodes( MortonHelper::Recursor r, NODE *target )
1607 {
1609  if ( r.second > target->l ) return;
1611  size_t node_morton = MortonHelper::MortonID( r );
1612 
1613  //bool prunable = true;
1614  auto & NearMortonIDs = target->NNNearNodeMortonIDs;
1615 
1617  if ( MortonHelper::ContainAny( node_morton, NearMortonIDs ) )
1618  {
1619  FindFarNodes( MortonHelper::RecurLeft( r ), target );
1620  FindFarNodes( MortonHelper::RecurRight( r ), target );
1621  }
1622  else
1623  {
1624  if ( node_morton >= target->morton )
1625  target->NNFarNodeMortonIDs.insert( node_morton );
1626  }
1627 };
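/**
 *  A minimal sketch of how the recursion is usually started (assuming the
 *  root recursor is the pair { morton = 0, depth = 0 }, which matches how
 *  MortonHelper::Recursor is consumed above):
 *
 *  @code
 *    for ( auto *leaf : leaf_nodes )
 *      FindFarNodes( MortonHelper::Recursor( 0, 0 ), leaf );
 *  @endcode
 *
 *  A candidate node whose Morton ID contains none of the target's near
 *  interactions is recorded as a far node; otherwise the recursion descends
 *  into both of its children.
 */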
1634 template<typename TREE>
1635 void SymmetrizeNearInteractions( TREE & tree )
1636 {
1637  mpi::PrintProgress( "[BEG] SymmetrizeNearInteractions ...", tree.GetComm() );
1638 
1640  using NODE = typename TREE::NODE;
1642  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
1643  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
1644 
1645  vector<vector<pair<size_t, size_t>>> sendlist( comm_size );
1646  vector<vector<pair<size_t, size_t>>> recvlist( comm_size );
1647 
1648 
1655  int n_nodes = 1 << tree.depth;
1656  auto level_beg = tree.treelist.begin() + n_nodes - 1;
1657 
1658  #pragma omp parallel
1659  {
1661  vector<vector<pair<size_t, size_t>>> list( comm_size );
1662 
1663  #pragma omp for
1664  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1665  {
1666  auto *node = *(level_beg + node_ind);
1667  //auto & NearMortonIDs = node->NNNearNodeMortonIDs;
1668  for ( auto it : node->NNNearNodeMortonIDs )
1669  {
1670  int dest = tree.Morton2Rank( it );
1671  if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest );
1672  list[ dest ].push_back( make_pair( it, node->morton ) );
1673  }
1674  }
1676  #pragma omp critical
1677  {
1678  for ( int p = 0; p < comm_size; p ++ )
1679  {
1680  sendlist[ p ].insert( sendlist[ p ].end(),
1681  list[ p ].begin(), list[ p ].end() );
1682  }
1683  }
1684  };
1688  mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() );
1689 
1690 
1692  for ( int p = 0; p < comm_size; p ++ )
1693  {
1694  for ( auto & query : recvlist[ p ] )
1695  {
1697  #pragma omp critical
1698  {
1699  auto* node = tree.morton2node[ query.first ];
1700  if ( !tree.morton2node.count( query.second ) )
1701  {
1702  tree.morton2node[ query.second ] = new NODE( query.second );
1703  }
1704  node->data.lock.Acquire();
1705  {
1706  node->NNNearNodes.insert( tree.morton2node[ query.second ] );
1707  node->NNNearNodeMortonIDs.insert( query.second );
1708  }
1709  node->data.lock.Release();
1710  }
1711  };
1712  }
1713  mpi::Barrier( tree.GetComm() );
1714  mpi::PrintProgress( "[END] SymmetrizeNearInteractions ...", tree.GetComm() );
1715 };
1718 template<typename TREE>
1719 void SymmetrizeFarInteractions( TREE & tree )
1720 {
1721  mpi::PrintProgress( "[BEG] SymmetrizeFarInteractions ...", tree.GetComm() );
1722 
1724  using NODE = typename TREE::NODE;
1726  //int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
1727  //int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
1728 
1729  vector<vector<pair<size_t, size_t>>> sendlist( tree.GetCommSize() );
1730  vector<vector<pair<size_t, size_t>>> recvlist( tree.GetCommSize() );
1731 
1733  #pragma omp parallel
1734  {
1736  vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() );
1737 
1738  #pragma omp for
1739  for ( size_t i = 1; i < tree.treelist.size(); i ++ )
1740  {
1741  auto *node = tree.treelist[ i ];
1742  for ( auto it = node->NNFarNodeMortonIDs.begin();
1743  it != node->NNFarNodeMortonIDs.end(); it ++ )
1744  {
1746  #pragma omp critical
1747  {
1748  if ( !tree.morton2node.count( *it ) )
1749  {
1750  tree.morton2node[ *it ] = new NODE( *it );
1751  }
1752  node->NNFarNodes.insert( tree.morton2node[ *it ] );
1753  }
1754  int dest = tree.Morton2Rank( *it );
1755  if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest );
1756  list[ dest ].push_back( make_pair( *it, node->morton ) );
1757  }
1758  }
1759 
1760  #pragma omp critical
1761  {
1762  for ( int p = 0; p < tree.GetCommSize(); p ++ )
1763  {
1764  sendlist[ p ].insert( sendlist[ p ].end(),
1765  list[ p ].begin(), list[ p ].end() );
1766  }
1767  }
1768  }
1769 
1770 
1772  #pragma omp parallel
1773  {
1775  vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() );
1776 
1777  #pragma omp for
1778  for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ )
1779  {
1780  auto *node = tree.mpitreelists[ i ];
1781  for ( auto it = node->NNFarNodeMortonIDs.begin();
1782  it != node->NNFarNodeMortonIDs.end(); it ++ )
1783  {
1785  #pragma omp critical
1786  {
1787  if ( !tree.morton2node.count( *it ) )
1788  {
1789  tree.morton2node[ *it ] = new NODE( *it );
1790  }
1791  node->NNFarNodes.insert( tree.morton2node[ *it ] );
1792  }
1793  int dest = tree.Morton2Rank( *it );
1794  if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest ); fflush( stdout );
1795  list[ dest ].push_back( make_pair( *it, node->morton ) );
1796  }
1797  }
1798 
1799  #pragma omp critical
1800  {
1801  for ( int p = 0; p < tree.GetCommSize(); p ++ )
1802  {
1803  sendlist[ p ].insert( sendlist[ p ].end(),
1804  list[ p ].begin(), list[ p ].end() );
1805  }
1806  }
1807  }
1808 
1810  mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() );
1811 
1813  for ( int p = 0; p < tree.GetCommSize(); p ++ )
1814  {
1815  //#pragma omp parallel for
1816  for ( auto & query : recvlist[ p ] )
1817  {
1819  #pragma omp critical
1820  {
1821  if ( !tree.morton2node.count( query.second ) )
1822  {
1823  tree.morton2node[ query.second ] = new NODE( query.second );
1824  //printf( "rank %d, %8lu level %lu creates far LET %8lu (symmetrize)\n",
1825  // comm_rank, node->morton, node->l, query.second );
1826  }
1827  auto* node = tree.morton2node[ query.first ];
1828  node->data.lock.Acquire();
1829  {
1830  node->NNFarNodes.insert( tree.morton2node[ query.second ] );
1831  node->NNFarNodeMortonIDs.insert( query.second );
1832  }
1833  node->data.lock.Release();
1834  assert( tree.Morton2Rank( node->morton ) == tree.GetCommRank() );
1835  }
1836  }
1837  }
1838 
1839  mpi::Barrier( tree.GetComm() );
1840  mpi::PrintProgress( "[END] SymmetrizeFarInteractions ...", tree.GetComm() );
1841 };
1860 template<typename TREE>
1861 void BuildInteractionListPerRank( TREE &tree, bool is_near )
1862 {
1864  using T = typename TREE::T;
1866  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
1867  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
1868 
1870  vector<set<size_t>> lists( comm_size );
1871 
1872  if ( is_near )
1873  {
1875  int n_nodes = 1 << tree.depth;
1876  auto level_beg = tree.treelist.begin() + n_nodes - 1;
1877 
1878  #pragma omp parallel
1879  {
1881  vector<set<size_t>> list( comm_size );
1882 
1883  #pragma omp for
1884  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1885  {
1886  auto *node = *(level_beg + node_ind);
1887  auto & NearMortonIDs = node->NNNearNodeMortonIDs;
1888  node->DistNear.resize( comm_size );
1889  for ( auto it : NearMortonIDs )
1890  {
1891  int dest = tree.Morton2Rank( it );
1892  if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest );
1893  if ( dest != comm_rank ) list[ dest ].insert( node->morton );
1894  node->DistNear[ dest ][ it ] = Data<T>();
1895  }
1896  }
1898  #pragma omp critical
1899  {
1900  for ( int p = 0; p < comm_size; p ++ )
1901  lists[ p ].insert( list[ p ].begin(), list[ p ].end() );
1902  }
1903  };
1907  vector<vector<size_t>> recvlist( comm_size );
1908  if ( !tree.NearSentToRank.size() ) tree.NearSentToRank.resize( comm_size );
1909  if ( !tree.NearRecvFromRank.size() ) tree.NearRecvFromRank.resize( comm_size );
1910  #pragma omp parallel for
1911  for ( int p = 0; p < comm_size; p ++ )
1912  {
1913  tree.NearSentToRank[ p ].insert( tree.NearSentToRank[ p ].end(),
1914  lists[ p ].begin(), lists[ p ].end() );
1915  }
1916 
1918  mpi::AlltoallVector( tree.NearSentToRank, recvlist, tree.GetComm() );
1919 
1921  #pragma omp parallel for
1922  for ( int p = 0; p < comm_size; p ++ )
1923  for ( int i = 0; i < recvlist[ p ].size(); i ++ )
1924  tree.NearRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i;
1925  }
1926  else
1927  {
1928  #pragma omp parallel
1929  {
1931  vector<set<size_t>> list( comm_size );
1932 
1934  #pragma omp for
1935  for ( size_t i = 1; i < tree.treelist.size(); i ++ )
1936  {
1937  auto *node = tree.treelist[ i ];
1938  node->DistFar.resize( comm_size );
1939  for ( auto it = node->NNFarNodeMortonIDs.begin();
1940  it != node->NNFarNodeMortonIDs.end(); it ++ )
1941  {
1942  int dest = tree.Morton2Rank( *it );
1943  if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest );
1944  if ( dest != comm_rank )
1945  {
1946  list[ dest ].insert( node->morton );
1947  //node->data.FarDependents.insert( dest );
1948  }
1949  node->DistFar[ dest ][ *it ] = Data<T>();
1950  }
1951  }
1952 
1954  #pragma omp for
1955  for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ )
1956  {
1957  auto *node = tree.mpitreelists[ i ];
1958  node->DistFar.resize( comm_size );
1960  if ( tree.Morton2Rank( node->morton ) == comm_rank )
1961  {
1962  for ( auto it = node->NNFarNodeMortonIDs.begin();
1963  it != node->NNFarNodeMortonIDs.end(); it ++ )
1964  {
1965  int dest = tree.Morton2Rank( *it );
1966  if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest );
1967  if ( dest != comm_rank )
1968  {
1969  list[ dest ].insert( node->morton );
1970  //node->data.FarDependents.insert( dest );
1971  }
1972  node->DistFar[ dest ][ *it ] = Data<T>();
1973  }
1974  }
1975  }
1977  #pragma omp critical
1978  {
1979  for ( int p = 0; p < comm_size; p ++ )
1980  lists[ p ].insert( list[ p ].begin(), list[ p ].end() );
1981  }
1983  };
1986  vector<vector<size_t>> recvlist( comm_size );
1987  if ( !tree.FarSentToRank.size() ) tree.FarSentToRank.resize( comm_size );
1988  if ( !tree.FarRecvFromRank.size() ) tree.FarRecvFromRank.resize( comm_size );
1989  #pragma omp parallel for
1990  for ( int p = 0; p < comm_size; p ++ )
1991  {
1992  tree.FarSentToRank[ p ].insert( tree.FarSentToRank[ p ].end(),
1993  lists[ p ].begin(), lists[ p ].end() );
1994  }
1995 
1996 
1998  mpi::AlltoallVector( tree.FarSentToRank, recvlist, tree.GetComm() );
1999 
2001  #pragma omp parallel for
2002  for ( int p = 0; p < comm_size; p ++ )
2003  for ( int i = 0; i < recvlist[ p ].size(); i ++ )
2004  tree.FarRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i;
2005  }
2006 
2007  mpi::Barrier( tree.GetComm() );
2008 };
2011 template<typename TREE>
2012 pair<double, double> NonCompressedRatio( TREE &tree )
2013 {
2015  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2016  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2017 
2019  double ratio_n = 0.0;
2020  double ratio_f = 0.0;
2021 
2022 
2024  for ( auto &tar : tree.treelist )
2025  {
2026  if ( tar->isleaf )
2027  {
2028  for ( auto nearID : tar->NNNearNodeMortonIDs )
2029  {
2030  auto *src = tree.morton2node[ nearID ];
2031  assert( src );
2032  double m = tar->gids.size();
2033  double n = src->gids.size();
2034  double N = tree.n;
2035  ratio_n += ( m / N ) * ( n / N );
2036  }
2037  }
2038 
2039  for ( auto farID : tar->NNFarNodeMortonIDs )
2040  {
2041  auto *src = tree.morton2node[ farID ];
2042  assert( src );
2043  double m = tar->data.skels.size();
2044  double n = src->data.skels.size();
2045  double N = tree.n;
2046  ratio_f += ( m / N ) * ( n / N );
2047  }
2048  }
2049 
2051  for ( auto &tar : tree.mpitreelists )
2052  {
2053  if ( !tar->child || tar->GetCommRank() ) continue;
2054  for ( auto farID : tar->NNFarNodeMortonIDs )
2055  {
2056  auto *src = tree.morton2node[ farID ];
2057  assert( src );
2058  double m = tar->data.skels.size();
2059  double n = src->data.skels.size();
2060  double N = tree.n;
2061  ratio_f += ( m / N ) * ( n / N );
2062  }
2063  }
2064 
2066  pair<double, double> ret( 0, 0 );
2067  mpi::Allreduce( &ratio_n, &(ret.first), 1, MPI_SUM, tree.GetComm() );
2068  mpi::Allreduce( &ratio_f, &(ret.second), 1, MPI_SUM, tree.GetComm() );
2069 
2070  return ret;
2071 };
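/**
 *  Interpretation of the return value: ret.first is the fraction of the full
 *  N-by-N matrix evaluated directly through near (leaf-to-leaf) interactions,
 *  and ret.second is the analogous fraction touched through low-rank far
 *  interactions, measured in skeleton sizes; both are reduced over all ranks.
 */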
2072 
2073 
2074 
2075 template<typename T, typename TREE>
2076 void PackNear( TREE &tree, string option, int p,
2077  vector<size_t> &sendsizes,
2078  vector<size_t> &sendskels,
2079  vector<T> &sendbuffs )
2080 {
2081  vector<size_t> offsets( 1, 0 );
2082 
2083  for ( auto it : tree.NearSentToRank[ p ] )
2084  {
2085  auto *node = tree.morton2node[ it ];
2086  auto &gids = node->gids;
2087  if ( !option.compare( string( "leafgids" ) ) )
2088  {
2089  sendsizes.push_back( gids.size() );
2090  sendskels.insert( sendskels.end(), gids.begin(), gids.end() );
2091  }
2092  else
2093  {
2094  auto &w_view = node->data.w_view;
2095  sendsizes.push_back( gids.size() * w_view.col() );
2096  offsets.push_back( sendsizes.back() + offsets.back() );
2097  }
2098  }
2099 
2100  if ( offsets.size() ) sendbuffs.resize( offsets.back() );
2101 
2102  if ( !option.compare( string( "leafweights" ) ) )
2103  {
2104  #pragma omp parallel for
2105  for ( size_t i = 0; i < tree.NearSentToRank[ p ].size(); i ++ )
2106  {
2107  auto *node = tree.morton2node[ tree.NearSentToRank[ p ][ i ] ];
2108  auto &gids = node->gids;
2109  auto &w_view = node->data.w_view;
2110  auto w_leaf = w_view.toData();
2111  size_t offset = offsets[ i ];
2112  for ( size_t j = 0; j < w_leaf.size(); j ++ )
2113  sendbuffs[ offset + j ] = w_leaf[ j ];
2114  }
2115  }
2116 };
2117 
2118 
2119 template<typename T, typename TREE>
2120 void UnpackLeaf( TREE &tree, string option, int p,
2121  const vector<size_t> &recvsizes,
2122  const vector<size_t> &recvskels,
2123  const vector<T> &recvbuffs )
2124 {
2125  vector<size_t> offsets( 1, 0 );
2126  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2127 
2128  for ( auto it : tree.NearRecvFromRank[ p ] )
2129  {
2130  auto *node = tree.morton2node[ it.first ];
2131  if ( !option.compare( string( "leafgids" ) ) )
2132  {
2133  auto &gids = node->gids;
2134  size_t i = it.second;
2135  gids.reserve( recvsizes[ i ] );
2136  for ( uint64_t j = offsets[ i + 0 ];
2137  j < offsets[ i + 1 ];
2138  j ++ )
2139  {
2140  gids.push_back( recvskels[ j ] );
2141  }
2142  }
2143  else
2144  {
2146  size_t nrhs = tree.setup.w->col();
2147  auto &w_leaf = node->data.w_leaf;
2148  size_t i = it.second;
2149  w_leaf.resize( recvsizes[ i ] / nrhs, nrhs );
2150  //printf( "%d recv w_leaf from %d [%lu %lu]\n",
2151  // comm_rank, p, w_leaf.row(), w_leaf.col() ); fflush( stdout );
2152  for ( uint64_t j = offsets[ i + 0 ], jj = 0;
2153  j < offsets[ i + 1 ];
2154  j ++, jj ++ )
2155  {
2156  w_leaf[ jj ] = recvbuffs[ j ];
2157  }
2158  }
2159  }
2160 };
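/**
 *  The packing and unpacking helpers in this file share one layout
 *  convention: per-node payloads are concatenated into a flat buffer, and
 *  the size list is turned into exclusive prefix sums so that node i
 *  occupies [ offsets[ i ], offsets[ i + 1 ] ).  For example (illustrative
 *  numbers only):
 *
 *  @code
 *    recvsizes = { 6, 4, 8 };         // payload sizes of three nodes
 *    offsets   = { 0, 6, 10, 18 };    // exclusive prefix sum
 *    // node 1's entries are recvbuffs[ 6 ] .. recvbuffs[ 9 ]
 *  @endcode
 */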
2161 
2162 
2163 template<typename T, typename TREE>
2164 void PackFar( TREE &tree, string option, int p,
2165  vector<size_t> &sendsizes,
2166  vector<size_t> &sendskels,
2167  vector<T> &sendbuffs )
2168 {
2169  for ( auto it : tree.FarSentToRank[ p ] )
2170  {
2171  auto *node = tree.morton2node[ it ];
2172  auto &skels = node->data.skels;
2173  if ( !option.compare( string( "skelgids" ) ) )
2174  {
2175  sendsizes.push_back( skels.size() );
2176  sendskels.insert( sendskels.end(), skels.begin(), skels.end() );
2177  }
2178  else
2179  {
2180  auto &w_skel = node->data.w_skel;
2181  sendsizes.push_back( w_skel.size() );
2182  sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() );
2183  }
2184  }
2185 };
2199 template<typename TREE, typename T>
2200 void PackWeights( TREE &tree, int p,
2201  vector<T> &sendbuffs, vector<size_t> &sendsizes )
2202 {
2203  for ( auto it : tree.NearSentToRank[ p ] )
2204  {
2205  auto *node = tree.morton2node[ it ];
2206  auto w_leaf = node->data.w_view.toData();
2207  sendbuffs.insert( sendbuffs.end(), w_leaf.begin(), w_leaf.end() );
2208  sendsizes.push_back( w_leaf.size() );
2209  }
2210 };
2214 template<typename TREE, typename T>
2215 void UnpackWeights( TREE &tree, int p,
2216  const vector<T> &recvbuffs, const vector<size_t> &recvsizes )
2217 {
2218  vector<size_t> offsets( 1, 0 );
2219  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2220 
2221  for ( auto it : tree.NearRecvFromRank[ p ] )
2222  {
2224  auto *node = tree.morton2node[ it.first ];
2226  size_t nrhs = tree.setup.w->col();
2227  auto &w_leaf = node->data.w_leaf;
2228  size_t i = it.second;
2229  w_leaf.resize( recvsizes[ i ] / nrhs, nrhs );
2230  for ( uint64_t j = offsets[ i + 0 ], jj = 0;
2231  j < offsets[ i + 1 ];
2232  j ++, jj ++ )
2233  {
2234  w_leaf[ jj ] = recvbuffs[ j ];
2235  }
2236  }
2237 };
2242 template<typename TREE>
2243 void PackSkeletons( TREE &tree, int p,
2244  vector<size_t> &sendbuffs, vector<size_t> &sendsizes )
2245 {
2246  for ( auto it : tree.FarSentToRank[ p ] )
2247  {
2249  auto *node = tree.morton2node[ it ];
2250  auto &skels = node->data.skels;
2251  sendbuffs.insert( sendbuffs.end(), skels.begin(), skels.end() );
2252  sendsizes.push_back( skels.size() );
2253  }
2254 };
2258 template<typename TREE>
2259 void UnpackSkeletons( TREE &tree, int p,
2260  const vector<size_t> &recvbuffs, const vector<size_t> &recvsizes )
2261 {
2262  vector<size_t> offsets( 1, 0 );
2263  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2264 
2265  for ( auto it : tree.FarRecvFromRank[ p ] )
2266  {
2268  auto *node = tree.morton2node[ it.first ];
2269  auto &skels = node->data.skels;
2270  size_t i = it.second;
2271  skels.clear();
2272  skels.reserve( recvsizes[ i ] );
2273  for ( uint64_t j = offsets[ i + 0 ];
2274  j < offsets[ i + 1 ];
2275  j ++ )
2276  {
2277  skels.push_back( recvbuffs[ j ] );
2278  }
2279  }
2280 };
2285 template<typename TREE, typename T>
2286 void PackSkeletonWeights( TREE &tree, int p,
2287  vector<T> &sendbuffs, vector<size_t> &sendsizes )
2288 {
2289  for ( auto it : tree.FarSentToRank[ p ] )
2290  {
2291  auto *node = tree.morton2node[ it ];
2292  auto &w_skel = node->data.w_skel;
2293  sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() );
2294  sendsizes.push_back( w_skel.size() );
2295  }
2296 };
2300 template<typename TREE, typename T>
2301 void UnpackSkeletonWeights( TREE &tree, int p,
2302  const vector<T> &recvbuffs, const vector<size_t> &recvsizes )
2303 {
2304  vector<size_t> offsets( 1, 0 );
2305  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2306 
2307  for ( auto it : tree.FarRecvFromRank[ p ] )
2308  {
2310  auto *node = tree.morton2node[ it.first ];
2312  size_t nrhs = tree.setup.w->col();
2313  auto &w_skel = node->data.w_skel;
2314  size_t i = it.second;
2315  w_skel.resize( recvsizes[ i ] / nrhs, nrhs );
2316  for ( uint64_t j = offsets[ i + 0 ], jj = 0;
2317  j < offsets[ i + 1 ];
2318  j ++, jj ++ )
2319  {
2320  w_skel[ jj ] = recvbuffs[ j ];
2321  }
2322  }
2323 };
2330 template<typename T, typename TREE>
2331 void UnpackFar( TREE &tree, string option, int p,
2332  const vector<size_t> &recvsizes,
2333  const vector<size_t> &recvskels,
2334  const vector<T> &recvbuffs )
2335 {
2336  vector<size_t> offsets( 1, 0 );
2337  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2338 
2339  for ( auto it : tree.FarRecvFromRank[ p ] )
2340  {
2342  auto *node = tree.morton2node[ it.first ];
2343  if ( !option.compare( string( "skelgids" ) ) )
2344  {
2345  auto &skels = node->data.skels;
2346  size_t i = it.second;
2347  skels.clear();
2348  skels.reserve( recvsizes[ i ] );
2349  for ( uint64_t j = offsets[ i + 0 ];
2350  j < offsets[ i + 1 ];
2351  j ++ )
2352  {
2353  skels.push_back( recvskels[ j ] );
2354  }
2355  }
2356  else
2357  {
2359  size_t nrhs = tree.setup.w->col();
2360  auto &w_skel = node->data.w_skel;
2361  size_t i = it.second;
2362  w_skel.resize( recvsizes[ i ] / nrhs, nrhs );
2363  //printf( "%d recv w_skel (%8lu) from %d [%lu %lu], i %lu, offset[%lu %lu] \n",
2364  // comm_rank, (*it).first, p, w_skel.row(), w_skel.col(), i,
2365  // offsets[ p ][ i + 0 ], offsets[ p ][ i + 1 ] ); fflush( stdout );
2366  for ( uint64_t j = offsets[ i + 0 ], jj = 0;
2367  j < offsets[ i + 1 ];
2368  j ++, jj ++ )
2369  {
2370  w_skel[ jj ] = recvbuffs[ j ];
2371  //if ( jj < 5 ) printf( "%E ", w_skel[ jj ] ); fflush( stdout );
2372  }
2373  //printf( "\n" ); fflush( stdout );
2374  }
2375  }
2376 };
2377 
2378 
2379 template<typename T, typename TREE>
2380 class PackNearTask : public SendTask<T, TREE>
2381 {
2382  public:
2383 
2384  PackNearTask( TREE *tree, int src, int tar, int key )
2385  : SendTask<T, TREE>( tree, src, tar, key )
2386  {
2388  this->Submit();
2389  this->DependencyAnalysis();
2390  };
2391 
2392  void DependencyAnalysis()
2393  {
2394  TREE &tree = *(this->arg);
2395  tree.DependOnNearInteractions( this->tar, this );
2396  };
2397 
2399  void Pack()
2400  {
2401  PackWeights( *this->arg, this->tar,
2402  this->send_buffs, this->send_sizes );
2403  };
2404 
2405 };
2421 template<typename T, typename TREE>
2422 class UnpackLeafTask : public RecvTask<T, TREE>
2423 {
2424  public:
2425 
2426  UnpackLeafTask( TREE *tree, int src, int tar, int key )
2427  : RecvTask<T, TREE>( tree, src, tar, key )
2428  {
2430  this->Submit();
2431  this->DependencyAnalysis();
2432  };
2433 
2434  void Unpack()
2435  {
2436  UnpackWeights( *this->arg, this->src,
2437  this->recv_buffs, this->recv_sizes );
2438  };
2439 
2440 };
2444 template<typename T, typename TREE>
2445 class PackFarTask : public SendTask<T, TREE>
2446 {
2447  public:
2448 
2449  PackFarTask( TREE *tree, int src, int tar, int key )
2450  : SendTask<T, TREE>( tree, src, tar, key )
2451  {
2453  this->Submit();
2454  this->DependencyAnalysis();
2455  };
2456 
2457  void DependencyAnalysis()
2458  {
2459  TREE &tree = *(this->arg);
2460  tree.DependOnFarInteractions( this->tar, this );
2461  };
2462 
2464  void Pack()
2465  {
2466  PackSkeletonWeights( *this->arg, this->tar,
2467  this->send_buffs, this->send_sizes );
2468  };
2469 
2470 };
2474 template<typename T, typename TREE>
2475 class UnpackFarTask : public RecvTask<T, TREE>
2476 {
2477  public:
2478 
2479  UnpackFarTask( TREE *tree, int src, int tar, int key )
2480  : RecvTask<T, TREE>( tree, src, tar, key )
2481  {
2483  this->Submit();
2484  this->DependencyAnalysis();
2485  };
2486 
2487  void Unpack()
2488  {
2489  UnpackSkeletonWeights( *this->arg, this->src,
2490  this->recv_buffs, this->recv_sizes );
2491  };
2492 
2493 };
2512 template<typename TREE>
2513 void ExchangeLET( TREE &tree, string option )
2514 {
2516  using T = typename TREE::T;
2518  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2519  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2520 
2522  vector<vector<size_t>> sendsizes( comm_size );
2523  vector<vector<size_t>> recvsizes( comm_size );
2524  vector<vector<size_t>> sendskels( comm_size );
2525  vector<vector<size_t>> recvskels( comm_size );
2526  vector<vector<T>> sendbuffs( comm_size );
2527  vector<vector<T>> recvbuffs( comm_size );
2528 
2530  #pragma omp parallel for
2531  for ( int p = 0; p < comm_size; p ++ )
2532  {
2533  if ( !option.compare( 0, 4, "leaf" ) )
2534  {
2535  PackNear( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] );
2536  }
2537  else if ( !option.compare( 0, 4, "skel" ) )
2538  {
2539  PackFar( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] );
2540  }
2541  else
2542  {
2543  printf( "ExchangeLET: option <%s> not available.\n", option.data() );
2544  exit( 1 );
2545  }
2546  }
2547 
2549  mpi::AlltoallVector( sendsizes, recvsizes, tree.GetComm() );
2550  if ( !option.compare( string( "skelgids" ) ) ||
2551  !option.compare( string( "leafgids" ) ) )
2552  {
2553  auto &K = *tree.setup.K;
2554  mpi::AlltoallVector( sendskels, recvskels, tree.GetComm() );
2555  K.RequestIndices( recvskels );
2556  }
2557  else
2558  {
2559  double beg = omp_get_wtime();
2560  mpi::AlltoallVector( sendbuffs, recvbuffs, tree.GetComm() );
2561  double a2av_time = omp_get_wtime() - beg;
2562  if ( comm_rank == 0 ) printf( "a2av_time %lfs\n", a2av_time );
2563  }
2564 
2565 
2567  #pragma omp parallel for
2568  for ( int p = 0; p < comm_size; p ++ )
2569  {
2570  if ( !option.compare( 0, 4, "leaf" ) )
2571  {
2572  UnpackLeaf( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] );
2573  }
2574  else if ( !option.compare( 0, 4, "skel" ) )
2575  {
2576  UnpackFar( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] );
2577  }
2578  else
2579  {
2580  printf( "ExchangeLET: option <%s> not available.\n", option.data() );
2581  exit( 1 );
2582  }
2583  }
2584 
2585 
2586 };
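/**
 *  Usage sketch (editorial, illustrative only): ExchangeLET() accepts the four
 *  option strings handled above. On a tree returned by mpigofmm::Compress(),
 *  the gid exchanges are issued during compression and the weight exchanges
 *  during evaluation, e.g.
 *
 *    mpigofmm::ExchangeLET( tree, string( "leafgids" ) );     // near-interaction gids
 *    mpigofmm::ExchangeLET( tree, string( "skelgids" ) );     // far-interaction skeleton gids
 *    mpigofmm::ExchangeLET( tree, string( "leafweights" ) );  // leaf weights for L2L
 *    mpigofmm::ExchangeLET( tree, string( "skelweights" ) );  // skeleton weights for S2S
 */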
2590 template<typename T, typename TREE>
2591 void AsyncExchangeLET( TREE &tree, string option )
2592 {
2594  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2595  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2596 
2598  for ( int p = 0; p < comm_size; p ++ )
2599  {
2600  if ( !option.compare( 0, 4, "leaf" ) )
2601  {
2602  auto *task = new PackNearTask<T, TREE>( &tree, comm_rank, p, 300 );
2604  //task->Set( &tree, comm_rank, p, 300 );
2605  //task->Submit();
2606  //task->DependencyAnalysis();
2607  }
2608  else if ( !option.compare( 0, 4, "skel" ) )
2609  {
2610  auto *task = new PackFarTask<T, TREE>( &tree, comm_rank, p, 306 );
2612  //task->Set( &tree, comm_rank, p, 306 );
2613  //task->Submit();
2614  //task->DependencyAnalysis();
2615  }
2616  else
2617  {
2618  printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() );
2619  exit( 1 );
2620  }
2621  }
2622 
2624  for ( int p = 0; p < comm_size; p ++ )
2625  {
2626  if ( !option.compare( 0, 4, "leaf" ) )
2627  {
2628  auto *task = new UnpackLeafTask<T, TREE>( &tree, p, comm_rank, 300 );
2630  //task->Set( &tree, p, comm_rank, 300 );
2631  //task->Submit();
2632  //task->DependencyAnalysis();
2633  }
2634  else if ( !option.compare( 0, 4, "skel" ) )
2635  {
2636  auto *task = new UnpackFarTask<T, TREE>( &tree, p, comm_rank, 306 );
2638  //task->Set( &tree, p, comm_rank, 306 );
2639  //task->Submit();
2640  //task->DependencyAnalysis();
2641  }
2642  else
2643  {
2644  printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() );
2645  exit( 1 );
2646  }
2647  }
2648 
2649 };
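/**
 *  Editorial note: unlike the blocking ExchangeLET() above, AsyncExchangeLET()
 *  only enqueues Pack/Unpack message tasks; the actual communication overlaps
 *  with whatever else the runtime executes. A hypothetical caller therefore
 *  pairs it with a later runtime flush, as Evaluate() does below:
 *
 *    AsyncExchangeLET<T>( tree, string( "leafweights" ) );
 *    // ... enqueue the traversals that consume the received weights ...
 *    tree.ExecuteAllTasks();  // drains the dependency graph, completing the exchange
 */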
2654 template<typename T, typename TREE>
2655 void ExchangeNeighbors( TREE &tree )
2656 {
2657  mpi::PrintProgress( "[BEG] ExchangeNeighbors ...", tree.GetComm() );
2658 
2659  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2660  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2661 
2663  vector<vector<size_t>> send_buff( comm_size );
2664  vector<vector<size_t>> recv_buff( comm_size );
2665 
2667  unordered_set<size_t> requested_gids;
2668  auto &NN = *tree.setup.NN;
2669 
2671  for ( auto & it : NN )
2672  {
2673  if ( it.second >= 0 && it.second < tree.n )
2674  requested_gids.insert( it.second );
2675  }
2676 
2678  for ( auto it : tree.treelist[ 0 ]->gids ) requested_gids.erase( it );
2679 
2681  for ( auto it :requested_gids )
2682  {
2683  int p = it % comm_size;
2684  if ( p != comm_rank ) send_buff[ p ].push_back( it );
2685  }
2686 
2688  auto &K = *tree.setup.K;
2689  K.RequestIndices( send_buff );
2690 
2691  mpi::PrintProgress( "[END] ExchangeNeighbors ...", tree.GetComm() );
2692 };
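/**
 *  Illustration (hypothetical numbers): neighbor gids are owned cyclically, so
 *  the owner of a gid is simply ( gid % comm_size ). With comm_size = 4,
 *  gid 10 is requested from rank 2 and gid 7 from rank 3, while gids already
 *  present in the local root ( tree.treelist[ 0 ]->gids ) are erased before
 *  the K.RequestIndices() call above.
 */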
2704 template<bool SYMMETRIC, typename NODE, typename T>
2705 void MergeFarNodes( NODE *node )
2706 {
2708  //if ( !node->data.isskel ) return;
2709 
2713  //if ( node->isleaf )
2714  //{
2715  // auto & NearMortonIDs = node->NNNearNodeMortonIDs;
2716  // #pragma omp critical
2717  // {
2718  // int rank;
2719  // mpi::Comm_rank( MPI_COMM_WORLD, &rank );
2720  // string outfile = to_string( rank );
2721  // FILE * pFile = fopen( outfile.data(), "a+" );
2722  // fprintf( pFile, "(%8lu) ", node->morton );
2723  // for ( auto it = NearMortonIDs.begin(); it != NearMortonIDs.end(); it ++ )
2724  // fprintf( pFile, "%8lu, ", (*it) );
2725  // fprintf( pFile, "\n" ); //fflush( stdout );
2726  // }
2727 
2728  // //auto & NearNodes = node->NNNearNodes;
2729  // //for ( auto it = NearNodes.begin(); it != NearNodes.end(); it ++ )
2730  // //{
2731  // // if ( !(*it)->NNNearNodes.count( node ) )
2732  // // {
2733  // // printf( "(%8lu) misses %lu\n", (*it)->morton, node->morton ); fflush( stdout );
2734  // // }
2735  // //}
2736  //};
2737 
2738 
2740  assert( !node->FarNodeMortonIDs.size() );
2741  assert( !node->FarNodes.size() );
2742  node->FarNodeMortonIDs.insert( node->sibling->morton );
2743  node->FarNodes.insert( node->sibling );
2744 
2746  if ( node->isleaf )
2747  {
2748  FindFarNodes( MortonHelper::Root(), node );
2749  }
2750  else
2751  {
2753  auto *lchild = node->lchild;
2754  auto *rchild = node->rchild;
2755 
2757  auto &pNNFarNodes = node->NNFarNodeMortonIDs;
2758  auto &lNNFarNodes = lchild->NNFarNodeMortonIDs;
2759  auto &rNNFarNodes = rchild->NNFarNodeMortonIDs;
2760 
2762  for ( auto it = lNNFarNodes.begin();
2763  it != lNNFarNodes.end(); it ++ )
2764  {
2765  if ( rNNFarNodes.count( *it ) )
2766  {
2767  pNNFarNodes.insert( *it );
2768  }
2769  }
2771  for ( auto it = pNNFarNodes.begin();
2772  it != pNNFarNodes.end(); it ++ )
2773  {
2774  lNNFarNodes.erase( *it );
2775  rNNFarNodes.erase( *it );
2776  }
2777  }
2778 
2779 };
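/**
 *  Illustration (hypothetical Morton ids): if the left child prunes
 *  { 5, 9, 12 } and the right child prunes { 5, 12, 17 }, the loops above move
 *  the intersection { 5, 12 } into the parent's NNFarNodeMortonIDs and erase
 *  it from both children, so each far interaction is applied once, at the
 *  highest level where both siblings can prune it.
 */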
2783 template<bool SYMMETRIC, typename NODE, typename T>
2784 class MergeFarNodesTask : public Task
2785 {
2786  public:
2787 
2788  NODE *arg;
2789 
2790  void Set( NODE *user_arg )
2791  {
2792  arg = user_arg;
2793  name = string( "merge" );
2794  label = to_string( arg->treelist_id );
2796  cost = 5.0;
2798  priority = true;
2799  };
2800 
2802  void DependencyAnalysis()
2803  {
2804  arg->DependencyAnalysis( RW, this );
2805  if ( !arg->isleaf )
2806  {
2807  arg->lchild->DependencyAnalysis( RW, this );
2808  arg->rchild->DependencyAnalysis( RW, this );
2809  }
2810  this->TryEnqueue();
2811  };
2812 
2813  void Execute( Worker* user_worker )
2814  {
2815  MergeFarNodes<SYMMETRIC, NODE, T>( arg );
2816  };
2817 
2818 };
2831 template<bool SYMMETRIC, typename NODE, typename T>
2832 void DistMergeFarNodes( NODE *node )
2833 {
2835  mpi::Status status;
2836  mpi::Comm comm = node->GetComm();
2837  int comm_size = node->GetCommSize();
2838  int comm_rank = node->GetCommRank();
2839 
2841  //if ( !node->data.isskel ) return;
2842 
2843 
2845  if ( !node->parent ) return;
2846 
2848  if ( node->GetCommSize() < 2 )
2849  {
2850  MergeFarNodes<SYMMETRIC, NODE, T>( node );
2851  }
2852  else
2853  {
2855  auto *child = node->child;
2856 
2857  if ( comm_rank == 0 )
2858  {
2859  auto &pNNFarNodes = node->NNFarNodeMortonIDs;
2860  auto &lNNFarNodes = child->NNFarNodeMortonIDs;
2861  vector<size_t> recvFarNodes;
2862 
2864  mpi::RecvVector( recvFarNodes, comm_size / 2, 0, comm, &status );
2865 
2867  for ( auto it : recvFarNodes )
2868  {
2869  if ( lNNFarNodes.count( it ) )
2870  {
2871  pNNFarNodes.insert( it );
2872  }
2873  }
2874 
2876  recvFarNodes.clear();
2877  recvFarNodes.reserve( pNNFarNodes.size() );
2878 
2880  for ( auto it : pNNFarNodes )
2881  {
2882  lNNFarNodes.erase( it );
2883  recvFarNodes.push_back( it );
2884  }
2885 
2887  mpi::SendVector( recvFarNodes, comm_size / 2, 0, comm );
2888  }
2889 
2890 
2891  if ( comm_rank == comm_size / 2 )
2892  {
2893  auto &rNNFarNodes = child->NNFarNodeMortonIDs;
2894  vector<size_t> sendFarNodes( rNNFarNodes.begin(), rNNFarNodes.end() );
2895 
2897  mpi::SendVector( sendFarNodes, 0, 0, comm );
2899  mpi::RecvVector( sendFarNodes, 0, 0, comm, &status );
2901  for ( auto it : sendFarNodes ) rNNFarNodes.erase( it );
2902  }
2903  }
2904 
2905 };
2909 template<bool SYMMETRIC, typename NODE, typename T>
2910 class DistMergeFarNodesTask : public Task
2911 {
2912  public:
2913 
2914  NODE *arg = NULL;
2915 
2916  void Set( NODE *user_arg )
2917  {
2918  arg = user_arg;
2919  name = string( "dist-merge" );
2920  label = to_string( arg->treelist_id );
2922  cost = 5.0;
2924  priority = true;
2925  };
2926 
2928  void DependencyAnalysis()
2929  {
2930  arg->DependencyAnalysis( RW, this );
2931  if ( !arg->isleaf )
2932  {
2933  if ( arg->GetCommSize() > 1 )
2934  {
2935  arg->child->DependencyAnalysis( RW, this );
2936  }
2937  else
2938  {
2939  arg->lchild->DependencyAnalysis( RW, this );
2940  arg->rchild->DependencyAnalysis( RW, this );
2941  }
2942  }
2943  this->TryEnqueue();
2944  };
2945 
2946  void Execute( Worker* user_worker )
2947  {
2948  DistMergeFarNodes<SYMMETRIC, NODE, T>( arg );
2949  };
2950 
2951 };
2961 template<bool NNPRUNE, typename NODE>
2962 class CacheFarNodesTask : public Task
2963 {
2964  public:
2965 
2966  NODE *arg = NULL;
2967 
2968  void Set( NODE *user_arg )
2969  {
2970  arg = user_arg;
2971  name = string( "FKIJ" );
2972  label = to_string( arg->treelist_id );
2974  double flops = 0, mops = 0;
2976  cost = 5.0;
2977  };
2978 
2979  void DependencyAnalysis()
2980  {
2981  arg->DependencyAnalysis( RW, this );
2982  this->TryEnqueue();
2983  };
2984 
2985  void Execute( Worker* user_worker )
2986  {
2987  auto *node = arg;
2988  auto &K = *node->setup->K;
2989 
2990  for ( int p = 0; p < node->DistFar.size(); p ++ )
2991  {
2992  for ( auto &it : node->DistFar[ p ] )
2993  {
2994  auto *src = (*node->morton2node)[ it.first ];
2995  auto &I = node->data.skels;
2996  auto &J = src->data.skels;
2997  it.second = K( I, J );
2998  //printf( "Cache I %lu J %lu\n", I.size(), J.size() ); fflush( stdout );
2999  }
3000  }
3001  };
3002 
3003 };
3013 template<bool NNPRUNE, typename NODE>
3014 class CacheNearNodesTask : public Task
3015 {
3016  public:
3017 
3018  NODE *arg = NULL;
3019 
3020  void Set( NODE *user_arg )
3021  {
3022  arg = user_arg;
3023  name = string( "NKIJ" );
3024  label = to_string( arg->treelist_id );
3026  cost = 5.0;
3027  };
3028 
3029  void DependencyAnalysis()
3030  {
3031  arg->DependencyAnalysis( RW, this );
3032  this->TryEnqueue();
3033  };
3034 
3035  void Execute( Worker* user_worker )
3036  {
3037  auto *node = arg;
3038  auto &K = *node->setup->K;
3039 
3040  for ( int p = 0; p < node->DistNear.size(); p ++ )
3041  {
3042  for ( auto &it : node->DistNear[ p ] )
3043  {
3044  auto *src = (*node->morton2node)[ it.first ];
3045  auto &I = node->gids;
3046  auto &J = src->gids;
3047  it.second = K( I, J );
3048  //printf( "Cache I %lu J %lu\n", I.size(), J.size() ); fflush( stdout );
3049  }
3050  }
3051  };
3052 
3053 };
3066 template<typename NODE, typename T>
3067 void DistRowSamples( NODE *node, size_t nsamples )
3068 {
3070  mpi::Comm comm = node->GetComm();
3071  int size = node->GetCommSize();
3072  int rank = node->GetCommRank();
3073 
3075  auto &K = *node->setup->K;
3076 
3078  vector<size_t> &I = node->data.candidate_rows;
3079 
3081  I.clear();
3082 
3084  if ( rank == 0 )
3085  {
3087  I.reserve( nsamples );
3088 
3089  auto &snids = node->data.snids;
3090  multimap<T, size_t> ordered_snids = gofmm::flip_map( snids );
3091 
3092  for ( auto it = ordered_snids.begin();
3093  it != ordered_snids.end(); it++ )
3094  {
3096  I.push_back( (*it).second );
3097  if ( I.size() >= nsamples ) break;
3098  }
3099  }
3100 
3102  vector<size_t> candidates( nsamples );
3103 
3104  size_t n_required = nsamples - I.size();
3105 
3107  mpi::Bcast( &n_required, 1, 0, comm );
3108 
3109  while ( n_required )
3110  {
3111  if ( rank == 0 )
3112  {
3113  for ( size_t i = 0; i < nsamples; i ++ )
3114  {
3115  auto important_sample = K.ImportantSample( 0 );
3116  candidates[ i ] = important_sample.second;
3117  }
3118  }
3119 
3121  mpi::Bcast( candidates.data(), candidates.size(), 0, comm );
3122 
3124  vector<size_t> vconsensus( nsamples, 0 );
3125  vector<size_t> validation = node->setup->ContainAny( candidates, node->morton );
3126 
3128  mpi::Reduce( validation.data(), vconsensus.data(), nsamples, MPI_SUM, 0, comm );
3129 
3130  if ( rank == 0 )
3131  {
3132  for ( size_t i = 0; i < nsamples; i ++ )
3133  {
3135  if ( I.size() >= nsamples )
3136  {
3137  I.resize( nsamples );
3138  break;
3139  }
3141  if ( !vconsensus[ i ] )
3142  {
3143  if ( find( I.begin(), I.end(), candidates[ i ] ) == I.end() )
3144  I.push_back( candidates[ i ] );
3145  }
3146  };
3147 
3149  n_required = nsamples - I.size();
3150  }
3151 
3153  mpi::Bcast( &n_required, 1, 0, comm );
3154  }
3155 
3156 };
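/**
 *  Illustration of the consensus loop above (hypothetical values): rank 0
 *  broadcasts candidate gids, every rank marks the candidates that fall inside
 *  its own subtree via ContainAny(), and the marks are summed onto rank 0.
 *  Only candidates with a zero vote count, i.e. owned by none of the
 *  participating subtrees, are accepted as off-diagonal row samples:
 *
 *    candidates = { 3, 42, 17 }
 *    rank 0 votes { 0, 1, 0 }, rank 1 votes { 0, 0, 1 }  =>  vconsensus = { 0, 1, 1 }
 *    accepted   = { 3 }
 */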
3164 template<bool NNPRUNE, typename NODE>
3165 void DistSkeletonKIJ( NODE *node )
3166 {
3168  using T = typename NODE::T;
3170  if ( !node->parent ) return;
3172  auto &K = *(node->setup->K);
3174  auto &data = node->data;
3175  auto &candidate_rows = data.candidate_rows;
3176  auto &candidate_cols = data.candidate_cols;
3177  auto &KIJ = data.KIJ;
3178 
3180  auto comm = node->GetComm();
3181  auto size = node->GetCommSize();
3182  auto rank = node->GetCommRank();
3183  mpi::Status status;
3184 
3185  if ( size < 2 )
3186  {
3188  gofmm::SkeletonKIJ<NNPRUNE>( node );
3189  }
3190  else
3191  {
3199  NODE *child = node->child;
3200  size_t nsamples = 0;
3201 
3203  int child_isskel = child->data.isskel;
3204  mpi::Bcast( &child_isskel, 1, 0, child->GetComm() );
3205  child->data.isskel = child_isskel;
3206 
3207 
3209  if ( rank == 0 )
3210  {
3211  candidate_cols = child->data.skels;
3212  vector<size_t> rskel;
3214  mpi::RecvVector( rskel, size / 2, 10, comm, &status );
3216  K.RecvIndices( size / 2, comm, &status );
3218  candidate_cols.insert( candidate_cols.end(), rskel.begin(), rskel.end() );
3220  nsamples = 2 * candidate_cols.size();
3222  if ( nsamples < 2 * node->setup->LeafNodeSize() )
3223  nsamples = 2 * node->setup->LeafNodeSize();
3224 
3226  auto &lsnids = node->child->data.snids;
3227  vector<T> recv_rsdist;
3228  vector<size_t> recv_rsnids;
3229 
3231  mpi::RecvVector( recv_rsdist, size / 2, 20, comm, &status );
3232  mpi::RecvVector( recv_rsnids, size / 2, 30, comm, &status );
3234  K.RecvIndices( size / 2, comm, &status );
3235 
3236 
3238  auto &snids = node->data.snids;
3239  snids = lsnids;
3240 
3241  for ( size_t i = 0; i < recv_rsdist.size(); i ++ )
3242  {
3243  pair<size_t, T> query( recv_rsnids[ i ], recv_rsdist[ i ] );
3244  auto ret = snids.insert( query );
3245  if ( !ret.second )
3246  {
3247  if ( ret.first->second > recv_rsdist[ i ] )
3248  ret.first->second = recv_rsdist[ i ];
3249  }
3250  }
3251 
3253  for ( auto gid : node->gids ) snids.erase( gid );
3254  }
3255 
3256  if ( rank == size / 2 )
3257  {
3259  mpi::SendVector( child->data.skels, 0, 10, comm );
3261  K.SendIndices( child->data.skels, 0, comm );
3262 
3264  auto &rsnids = node->child->data.snids;
3265  vector<T> send_rsdist;
3266  vector<size_t> send_rsnids;
3267 
3269  send_rsdist.reserve( rsnids.size() );
3270  send_rsnids.reserve( rsnids.size() );
3271 
3272  for ( auto it = rsnids.begin(); it != rsnids.end(); it ++ )
3273  {
3275  send_rsnids.push_back( (*it).first );
3276  send_rsdist.push_back( (*it).second );
3277  }
3278 
3280  mpi::SendVector( send_rsdist, 0, 20, comm );
3281  mpi::SendVector( send_rsnids, 0, 30, comm );
3282 
3284  K.SendIndices( send_rsnids, 0, comm );
3285  }
3286 
3288  mpi::Bcast( &nsamples, 1, 0, comm );
3290  DistRowSamples<NODE, T>( node, nsamples );
3292  if ( rank != 0 )
3293  {
3294  assert( !candidate_rows.size() );
3295  assert( !candidate_cols.size() );
3296  }
3302  KIJ = K( candidate_rows, candidate_cols );
3303  }
3304 };
3310 template<bool NNPRUNE, typename NODE, typename T>
3311 class DistSkeletonKIJTask : public Task
3312 {
3313  public:
3314 
3315  NODE *arg = NULL;
3316 
3317  void Set( NODE *user_arg )
3318  {
3319  arg = user_arg;
3320  name = string( "par-gskm" );
3321  label = to_string( arg->treelist_id );
3323  cost = 5.0;
3325  priority = true;
3326  };
3327 
3328  void DependencyAnalysis() { arg->DependOnChildren( this ); };
3329 
3330  void Execute( Worker* user_worker ) { DistSkeletonKIJ<NNPRUNE>( arg ); };
3331 
3332 };
3348 template<typename NODE, typename T>
3349 void DistSkeletonize( NODE *node )
3350 {
3352  if ( !node->parent ) return;
3353 
3355  auto &K = *(node->setup->K);
3356  auto &NN = *(node->setup->NN);
3357  auto maxs = node->setup->MaximumRank();
3358  auto stol = node->setup->Tolerance();
3359  bool secure_accuracy = node->setup->SecureAccuracy();
3360  bool use_adaptive_ranks = node->setup->UseAdaptiveRanks();
3361 
3363  auto &data = node->data;
3364  auto &skels = data.skels;
3365  auto &proj = data.proj;
3366  auto &jpvt = data.jpvt;
3367  auto &KIJ = data.KIJ;
3368  auto &candidate_cols = data.candidate_cols;
3369 
3371  size_t N = K.col();
3372  size_t m = KIJ.row();
3373  size_t n = KIJ.col();
3374  size_t q = node->n;
3375 
3376  if ( secure_accuracy )
3377  {
3379  }
3380 
3381 
3383  T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol;
3384 
3386  scaled_stol *= std::sqrt( (T)q / N );
3387 
3388  lowrank::id
3389  (
3390  use_adaptive_ranks, secure_accuracy,
3391  KIJ.row(), KIJ.col(), maxs, scaled_stol,
3392  KIJ, skels, proj, jpvt
3393  );
3394 
3396  KIJ.resize( 0, 0 );
3397  KIJ.shrink_to_fit();
3398 
3400  if ( secure_accuracy )
3401  {
3403  data.isskel = (skels.size() != 0);
3404  }
3405  else
3406  {
3407  assert( skels.size() );
3408  assert( proj.size() );
3409  assert( jpvt.size() );
3410  data.isskel = true;
3411  }
3412 
3414  for ( size_t i = 0; i < skels.size(); i ++ )
3415  {
3416  skels[ i ] = candidate_cols[ skels[ i ] ];
3417  }
3418 
3419 
3420 };
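/**
 *  Editorial note on the tolerance scaling above: with N = K.col(),
 *  q = node->n, m = KIJ.row(), and n = KIJ.col(), the two statements combine to
 *
 *    scaled_stol = stol * sqrt( n / q ) * sqrt( m / ( N - q ) ) * sqrt( q / N )
 *                = stol * sqrt( ( n * m ) / ( N * ( N - q ) ) ),
 *
 *  i.e. the user tolerance is shrunk by the sampling ratio of KIJ relative to
 *  the full off-diagonal block before it is handed to lowrank::id().
 */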
3425 template<typename NODE, typename T>
3426 class SkeletonizeTask : public Task
3427 {
3428  public:
3429 
3430  NODE *arg;
3431 
3432  void Set( NODE *user_arg )
3433  {
3434  arg = user_arg;
3435  name = string( "SK" );
3436  label = to_string( arg->treelist_id );
3438  cost = 5.0;
3440  priority = true;
3441  };
3442 
3443  void GetEventRecord()
3444  {
3445  double flops = 0.0, mops = 0.0;
3446 
3447  auto &K = *arg->setup->K;
3448  size_t n = arg->data.proj.col();
3449  size_t m = 2 * n;
3450  size_t k = arg->data.proj.row();
3451 
3453  flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
3454  mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
3455 
3456  /* TRSM */
3457  flops += k * ( k - 1 ) * ( n + 1 );
3458  mops += 2.0 * ( k * k + k * n );
3459 
3460  event.Set( label + name, flops, mops );
3461  arg->data.skeletonize = event;
3462  };
3463 
3464  void DependencyAnalysis()
3465  {
3466  arg->DependencyAnalysis( RW, this );
3467  this->TryEnqueue();
3468  };
3469 
3470  void Execute( Worker* user_worker )
3471  {
3472  //printf( "%d Par-Skel beg\n", global_rank );
3473 
3474  DistSkeletonize<NODE, T>( arg );
3475 
3476  //printf( "%d Par-Skel end\n", global_rank );
3477  };
3478 
3479 };
3487 template<typename NODE, typename T>
3488 class DistSkeletonizeTask : public Task
3489 {
3490  public:
3491 
3492  NODE *arg;
3493 
3494  void Set( NODE *user_arg )
3495  {
3496  arg = user_arg;
3497  name = string( "PSK" );
3498  label = to_string( arg->treelist_id );
3499 
3501  cost = 5.0;
3503  priority = true;
3504  };
3505 
3506  void GetEventRecord()
3507  {
3508  double flops = 0.0, mops = 0.0;
3509 
3510  auto &K = *arg->setup->K;
3511  size_t n = arg->data.proj.col();
3512  size_t m = 2 * n;
3513  size_t k = arg->data.proj.row();
3514 
3515  if ( arg->GetCommRank() == 0 )
3516  {
3518  flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
3519  mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
3520 
3521  /* TRSM */
3522  flops += k * ( k - 1 ) * ( n + 1 );
3523  mops += 2.0 * ( k * k + k * n );
3524  }
3525 
3526  event.Set( label + name, flops, mops );
3527  arg->data.skeletonize = event;
3528  };
3529 
3530  void DependencyAnalysis()
3531  {
3532  arg->DependencyAnalysis( RW, this );
3533  this->TryEnqueue();
3534  };
3535 
3536  void Execute( Worker* user_worker )
3537  {
3538  mpi::Comm comm = arg->GetComm();
3539 
3540  double beg = omp_get_wtime();
3541  if ( arg->GetCommRank() == 0 )
3542  {
3543  DistSkeletonize<NODE, T>( arg );
3544  }
3545  double skel_t = omp_get_wtime() - beg;
3546 
3548  int isskel = arg->data.isskel;
3549  mpi::Bcast( &isskel, 1, 0, comm );
3550  arg->data.isskel = isskel;
3551 
3553  auto &skels = arg->data.skels;
3554  size_t nskels = skels.size();
3555  mpi::Bcast( &nskels, 1, 0, comm );
3556  if ( skels.size() != nskels ) skels.resize( nskels );
3557  mpi::Bcast( skels.data(), skels.size(), 0, comm );
3558 
3559  };
3560 
3561 };
3569 template<typename NODE>
3570 class InterpolateTask : public Task
3571 {
3572  public:
3573 
3574  NODE *arg = NULL;
3575 
3576  void Set( NODE *user_arg )
3577  {
3578  arg = user_arg;
3579  name = string( "PROJ" );
3580  label = to_string( arg->treelist_id );
3581  // Need an accurate cost model.
3582  cost = 1.0;
3583  };
3584 
3585  void DependencyAnalysis() { arg->DependOnNoOne( this ); };
3586 
3587  void Execute( Worker* user_worker )
3588  {
3590  auto comm = arg->GetComm();
3592  if ( arg->GetCommRank() == 0 ) gofmm::Interpolate( arg );
3593 
3594  auto &proj = arg->data.proj;
3595  size_t nrow = proj.row();
3596  size_t ncol = proj.col();
3597  mpi::Bcast( &nrow, 1, 0, comm );
3598  mpi::Bcast( &ncol, 1, 0, comm );
3599  if ( proj.row() != nrow || proj.col() != ncol ) proj.resize( nrow, ncol );
3600  mpi::Bcast( proj.data(), proj.size(), 0, comm );
3601  };
3602 
3603 };
3637 template<bool NNPRUNE = true, typename TREE, typename T>
3638 DistData<RIDS, STAR, T> Evaluate( TREE &tree, DistData<RIDS, STAR, T> &weights )
3639 {
3640  try
3641  {
3643  int size; mpi::Comm_size( tree.GetComm(), &size );
3644  int rank; mpi::Comm_rank( tree.GetComm(), &rank );
3646  using NODE = typename TREE::NODE;
3647  using MPINODE = typename TREE::MPINODE;
3648 
3650  double beg, time_ratio, evaluation_time = 0.0;
3651  double direct_evaluation_time = 0.0, computeall_time = 0.0, telescope_time = 0.0, let_exchange_time = 0.0, async_time = 0.0;
3652  double overhead_time;
3653  double forward_permute_time, backward_permute_time;
3654 
3656  tree.DependencyCleanUp();
3657 
3659  size_t n = weights.row();
3660  size_t nrhs = weights.col();
3661 
3663  auto &gids_owned = tree.treelist[ 0 ]->gids;
3664  DistData<RIDS, STAR, T> potentials( n, nrhs, gids_owned, tree.GetComm() );
3665  potentials.setvalue( 0.0 );
3666 
3668  tree.setup.w = &weights;
3669  tree.setup.u = &potentials;
3670 
3672  gofmm::TreeViewTask<NODE> seqVIEWtask;
3673  mpigofmm::DistTreeViewTask<MPINODE> mpiVIEWtask;
3675  gofmm::UpdateWeightsTask<NODE, T> seqN2Stask;
3676  mpigofmm::DistUpdateWeightsTask<MPINODE, T> mpiN2Stask;
3678  //mpigofmm::DistLeavesToLeavesTask<NNPRUNE, NODE, T> seqL2Ltask;
3679  //mpigofmm::L2LReduceTask<NODE, T> seqL2LReducetask;
3680  mpigofmm::L2LReduceTask2<NODE, T> seqL2LReducetask2;
3682  //gofmm::SkeletonsToSkeletonsTask<NNPRUNE, NODE, T> seqS2Stask;
3683  //mpigofmm::DistSkeletonsToSkeletonsTask<NNPRUNE, MPINODE, T> mpiS2Stask;
3684  //mpigofmm::S2SReduceTask<NODE, T> seqS2SReducetask;
3685  //mpigofmm::S2SReduceTask<MPINODE, T> mpiS2SReducetask;
3686  mpigofmm::S2SReduceTask2<NODE, NODE, T> seqS2SReducetask2;
3687  mpigofmm::S2SReduceTask2<MPINODE, NODE, T> mpiS2SReducetask2;
3689  gofmm::SkeletonsToNodesTask<NNPRUNE, NODE, T> seqS2Ntask;
3690  mpigofmm::DistSkeletonsToNodesTask<NNPRUNE, MPINODE, T> mpiS2Ntask;
3691 
3693  mpi::Barrier( tree.GetComm() );
3694 
3695  //{
3696  // /** Stage 1: TreeView and upward telescoping */
3697  // beg = omp_get_wtime();
3698  // tree.DependencyCleanUp();
3699  // tree.DistTraverseDown( mpiVIEWtask );
3700  // tree.LocaTraverseDown( seqVIEWtask );
3701  // tree.LocaTraverseUp( seqN2Stask );
3702  // tree.DistTraverseUp( mpiN2Stask );
3703  // hmlp_run();
3704  // mpi::Barrier( tree.GetComm() );
3705  // telescope_time = omp_get_wtime() - beg;
3706 
3707  // /** Stage 2: LET exchange */
3708  // beg = omp_get_wtime();
3709  // ExchangeLET<T>( tree, string( "skelweights" ) );
3710  // mpi::Barrier( tree.GetComm() );
3711  // ExchangeLET<T>( tree, string( "leafweights" ) );
3712  // mpi::Barrier( tree.GetComm() );
3713  // let_exchange_time = omp_get_wtime() - beg;
3714 
3715  // /** Stage 3: L2L */
3716  // beg = omp_get_wtime();
3717  // tree.DependencyCleanUp();
3718  // tree.LocaTraverseLeafs( seqL2LReducetask2 );
3719  // hmlp_run();
3720  // mpi::Barrier( tree.GetComm() );
3721  // direct_evaluation_time = omp_get_wtime() - beg;
3722 
3723  // /** Stage 4: S2S and downward telescoping */
3724  // beg = omp_get_wtime();
3725  // tree.DependencyCleanUp();
3726  // tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
3727  // tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
3728  // tree.DistTraverseDown( mpiS2Ntask );
3729  // tree.LocaTraverseDown( seqS2Ntask );
3730  // hmlp_run();
3731  // mpi::Barrier( tree.GetComm() );
3732  // computeall_time = omp_get_wtime() - beg;
3733  //}
3734 
3735 
3737  potentials.setvalue( 0.0 );
3738  mpi::Barrier( tree.GetComm() );
3739 
3741  beg = omp_get_wtime();
3742  tree.DependencyCleanUp();
3743  tree.DistTraverseDown( mpiVIEWtask );
3744  tree.LocaTraverseDown( seqVIEWtask );
3745  tree.ExecuteAllTasks();
3747  AsyncExchangeLET<T>( tree, string( "leafweights" ) );
3749  tree.LocaTraverseUp( seqN2Stask );
3750  tree.DistTraverseUp( mpiN2Stask );
3752  AsyncExchangeLET<T>( tree, string( "skelweights" ) );
3754  tree.LocaTraverseLeafs( seqL2LReducetask2 );
3756  tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
3757  tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
3759  tree.DistTraverseDown( mpiS2Ntask );
3760  tree.LocaTraverseDown( seqS2Ntask );
3761  overhead_time = omp_get_wtime() - beg;
3762  tree.ExecuteAllTasks();
3763  async_time = omp_get_wtime() - beg;
3764 
3765 
3766 
3768  evaluation_time += direct_evaluation_time;
3769  evaluation_time += telescope_time;
3770  evaluation_time += let_exchange_time;
3771  evaluation_time += computeall_time;
3772  time_ratio = ( evaluation_time > 0.0 ) ? 100.0 / evaluation_time : 0.0;
3773 
3774  if ( rank == 0 && REPORT_EVALUATE_STATUS )
3775  {
3776  printf( "========================================================\n");
3777  printf( "GOFMM evaluation phase\n" );
3778  printf( "========================================================\n");
3779  //printf( "Allocate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3780  // allocate_time, allocate_time * time_ratio );
3781  //printf( "Forward permute ----------------------- %5.2lfs (%5.1lf%%)\n",
3782  // forward_permute_time, forward_permute_time * time_ratio );
3783  printf( "Upward telescope ---------------------- %5.2lfs (%5.1lf%%)\n",
3784  telescope_time, telescope_time * time_ratio );
3785  printf( "LET exchange -------------------------- %5.2lfs (%5.1lf%%)\n",
3786  let_exchange_time, let_exchange_time * time_ratio );
3787  printf( "L2L ----------------------------------- %5.2lfs (%5.1lf%%)\n",
3788  direct_evaluation_time, direct_evaluation_time * time_ratio );
3789  printf( "S2S, S2N ------------------------------ %5.2lfs (%5.1lf%%)\n",
3790  computeall_time, computeall_time * time_ratio );
3791  //printf( "Backward permute ---------------------- %5.2lfs (%5.1lf%%)\n",
3792  // backward_permute_time, backward_permute_time * time_ratio );
3793  printf( "========================================================\n");
3794  printf( "Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3795  evaluation_time, evaluation_time * time_ratio );
3796  printf( "Evaluate (Async) ---------------------- %5.2lfs (%5.2lfs)\n",
3797  async_time, overhead_time );
3798  printf( "========================================================\n\n");
3799  }
3800 
3801  return potentials;
3802  }
3803  catch ( const exception & e )
3804  {
3805  cout << e.what() << endl;
3806  exit( 1 );
3807  }
3808 };
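/**
 *  Usage sketch (editorial, mirrors SelfTesting() further below): the weights
 *  must be distributed by the gids owned by the local tree before calling
 *  Evaluate().
 *
 *    DistData<RIDS, STAR, T> w( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
 *    w.randn();
 *    auto u = mpigofmm::Evaluate<true>( tree, w );  // u is RIDS-distributed like w
 */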
3813 template<bool NNPRUNE = true, typename TREE, typename T>
3814 DistData<RBLK, STAR, T> Evaluate( TREE &tree, DistData<RBLK, STAR, T> &w_rblk )
3815 {
3816  size_t n = w_rblk.row();
3817  size_t nrhs = w_rblk.col();
3819  DistData<RIDS, STAR, T> w_rids( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
3820  w_rids = w_rblk;
3822  auto u_rids = Evaluate<NNPRUNE>( tree, w_rids );
3823  mpi::Barrier( tree.GetComm() );
3825  DistData<RBLK, STAR, T> u_rblk( n, nrhs, tree.GetComm() );
3826  u_rblk = u_rids;
3828  return u_rblk;
3829 };
3833 template<typename SPLITTER, typename T, typename SPDMATRIX>
3834 DistData<STAR, CBLK, pair<T, size_t>> FindNeighbors
3835 (
3836  SPDMATRIX &K,
3837  SPLITTER splitter,
3838  gofmm::Configuration<T> &config,
3839  mpi::Comm CommGOFMM,
3840  size_t n_iter = 10
3841 )
3842 {
3844  using DATA = gofmm::NodeData<T>;
3845  using SETUP = mpigofmm::Setup<SPDMATRIX, SPLITTER, T>;
3846  using TREE = mpitree::Tree<SETUP, DATA>;
3848  using NODE = typename TREE::NODE;
3850  DistanceMetric metric = config.MetricType();
3851  size_t n = config.ProblemSize();
3852  size_t k = config.NeighborSize();
3854  pair<T, size_t> init( numeric_limits<T>::max(), n );
3855  gofmm::NeighborsTask<NODE, T> NEIGHBORStask;
3856  TREE rkdt( CommGOFMM );
3857  rkdt.setup.FromConfiguration( config, K, splitter, NULL );
3858  return rkdt.AllNearestNeighbor( n_iter, n, k, init, NEIGHBORStask );
3859 };
3872 template<typename SPLITTER, typename RKDTSPLITTER, typename T, typename SPDMATRIX>
3873 mpitree::Tree<mpigofmm::Setup<SPDMATRIX, SPLITTER, T>, gofmm::NodeData<T>>
3874 *Compress
3875 (
3876  SPDMATRIX &K,
3877  DistData<STAR, CBLK, pair<T, size_t>> &NN_cblk,
3878  SPLITTER splitter,
3879  RKDTSPLITTER rkdtsplitter,
3880  gofmm::Configuration<T> &config,
3881  mpi::Comm CommGOFMM
3882 )
3883 {
3884  try
3885  {
3887  int size; mpi::Comm_size( CommGOFMM, &size );
3888  int rank; mpi::Comm_rank( CommGOFMM, &rank );
3889 
3891  DistanceMetric metric = config.MetricType();
3892  size_t n = config.ProblemSize();
3893  size_t m = config.LeafNodeSize();
3894  size_t k = config.NeighborSize();
3895  size_t s = config.MaximumRank();
3896 
3898  const bool SYMMETRIC = true;
3899  const bool NNPRUNE = true;
3900  const bool CACHE = true;
3901 
3903  using SETUP = mpigofmm::Setup<SPDMATRIX, SPLITTER, T>;
3904  using DATA = gofmm::NodeData<T>;
3905  using TREE = mpitree::Tree<SETUP, DATA>;
3907  using NODE = typename TREE::NODE;
3908  using MPINODE = typename TREE::MPINODE;
3909 
3911  double beg, omptask45_time, omptask_time, ref_time;
3912  double time_ratio, compress_time = 0.0, other_time = 0.0;
3913  double ann_time, tree_time, skel_time, mpi_skel_time, mergefarnodes_time = 0.0, cachefarnodes_time = 0.0;
3914  double local_skel_time, dist_skel_time, let_time;
3915  double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
3916  double exchange_neighbor_time, symmetrize_time;
3917 
3919  beg = omp_get_wtime();
3920  if ( k && NN_cblk.row() * NN_cblk.col() != k * n )
3921  {
3922  NN_cblk = mpigofmm::FindNeighbors( K, rkdtsplitter,
3923  config, CommGOFMM );
3924  }
3925  ann_time = omp_get_wtime() - beg;
3926 
3928  auto *tree_ptr = new TREE( CommGOFMM );
3929  auto &tree = *tree_ptr;
3930 
3932  tree.setup.FromConfiguration( config, K, splitter, &NN_cblk );
3933 
3935  beg = omp_get_wtime();
3936  tree.TreePartition();
3937  tree_time = omp_get_wtime() - beg;
3938 
3940  vector<size_t> perm = tree.GetPermutation();
3941  if ( rank == 0 )
3942  {
3943  ofstream perm_file( "perm.txt" );
3944  for ( auto &id : perm ) perm_file << id << " ";
3945  perm_file.close();
3946  }
3947 
3948 
3950  DistData<STAR, CIDS, pair<T, size_t>> NN( k, n, tree.treelist[ 0 ]->gids, tree.GetComm() );
3951  NN = NN_cblk;
3952  tree.setup.NN = &NN;
3953  beg = omp_get_wtime();
3954  ExchangeNeighbors<T>( tree );
3955  exchange_neighbor_time = omp_get_wtime() - beg;
3956 
3957 
3958  beg = omp_get_wtime();
3960  FindNearInteractions( tree );
3962  mpigofmm::SymmetrizeNearInteractions( tree );
3964  BuildInteractionListPerRank( tree, true );
3966  ExchangeLET( tree, string( "leafgids" ) );
3967  symmetrize_time = omp_get_wtime() - beg;
3968 
3969 
3971  mpi::PrintProgress( "[BEG] MergeFarNodes ...", tree.GetComm() );
3972  beg = omp_get_wtime();
3973  tree.DependencyCleanUp();
3974  MergeFarNodesTask<true, NODE, T> seqMERGEtask;
3975  DistMergeFarNodesTask<true, MPINODE, T> mpiMERGEtask;
3976  tree.LocaTraverseUp( seqMERGEtask );
3977  tree.DistTraverseUp( mpiMERGEtask );
3978  tree.ExecuteAllTasks();
3979  mergefarnodes_time += omp_get_wtime() - beg;
3980  mpi::PrintProgress( "[END] MergeFarNodes ...", tree.GetComm() );
3981 
3983  beg = omp_get_wtime();
3984  mpigofmm::SymmetrizeFarInteractions( tree );
3986  BuildInteractionListPerRank( tree, false );
3987  symmetrize_time += omp_get_wtime() - beg;
3988 
3989  mpi::PrintProgress( "[BEG] Skeletonization ...", tree.GetComm() );
3991  beg = omp_get_wtime();
3992  tree.DependencyCleanUp();
3994  gofmm::SkeletonKIJTask<NNPRUNE, NODE, T> seqGETMTXtask;
3995  mpigofmm::DistSkeletonKIJTask<NNPRUNE, MPINODE, T> mpiGETMTXtask;
3996  mpigofmm::SkeletonizeTask<NODE, T> seqSKELtask;
3997  mpigofmm::DistSkeletonizeTask<MPINODE, T> mpiSKELtask;
3998  tree.LocaTraverseUp( seqGETMTXtask, seqSKELtask );
3999  //tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
4001  gofmm::InterpolateTask<NODE> seqPROJtask;
4002  mpigofmm::InterpolateTask<MPINODE> mpiPROJtask;
4003  tree.LocaTraverseUnOrdered( seqPROJtask );
4004  //tree.DistTraverseUnOrdered( mpiPROJtask );
4005 
4008  //tree.LocaTraverseLeafs( seqNEARKIJtask );
4009 
4010  tree.ExecuteAllTasks();
4011  skel_time = omp_get_wtime() - beg;
4012 
4013  beg = omp_get_wtime();
4014  tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
4015  tree.DistTraverseUnOrdered( mpiPROJtask );
4016  tree.ExecuteAllTasks();
4017  mpi_skel_time = omp_get_wtime() - beg;
4018  mpi::PrintProgress( "[END] Skeletonization ...", tree.GetComm() );
4019 
4020 
4021 
4023  ExchangeLET( tree, string( "skelgids" ) );
4024 
4025  beg = omp_get_wtime();
4027  //mpigofmm::CacheNearNodesTask<NNPRUNE, NODE> seqNEARKIJtask;
4028  //tree.LocaTraverseLeafs( seqNEARKIJtask );
4032  //tree.LocaTraverseUnOrdered( seqFARKIJtask );
4033  //tree.DistTraverseUnOrdered( mpiFARKIJtask );
4034  cachefarnodes_time = omp_get_wtime() - beg;
4035  tree.ExecuteAllTasks();
4036  cachefarnodes_time = omp_get_wtime() - beg;
4037 
4038 
4039 
4041  auto ratio = NonCompressedRatio( tree );
4042 
4043  double exact_ratio = (double) m / n;
4044 
4045  if ( rank == 0 && REPORT_COMPRESS_STATUS )
4046  {
4047  compress_time += ann_time;
4048  compress_time += tree_time;
4049  compress_time += exchange_neighbor_time;
4050  compress_time += symmetrize_time;
4051  compress_time += skel_time;
4052  compress_time += mpi_skel_time;
4053  compress_time += mergefarnodes_time;
4054  compress_time += cachefarnodes_time;
4055  time_ratio = 100.0 / compress_time;
4056  printf( "========================================================\n");
4057  printf( "GOFMM compression phase\n" );
4058  printf( "========================================================\n");
4059  printf( "NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
4060  printf( "TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
4061  printf( "ExchangeNeighbors --------------------- %5.2lfs (%5.1lf%%)\n", exchange_neighbor_time, exchange_neighbor_time * time_ratio );
4062  printf( "MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
4063  printf( "Symmetrize ---------------------------- %5.2lfs (%5.1lf%%)\n", symmetrize_time, symmetrize_time * time_ratio );
4064  printf( "Skeletonization (HMLP Runtime ) ----- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
4065  printf( "Skeletonization (MPI ) ----- %5.2lfs (%5.1lf%%)\n", mpi_skel_time, mpi_skel_time * time_ratio );
4066  printf( "Cache KIJ ----------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
4067  printf( "========================================================\n");
4068  printf( "%5.3lf%% and %5.3lf%% uncompressed--------- %5.2lfs (%5.1lf%%)\n",
4069  100 * ratio.first, 100 * ratio.second, compress_time, compress_time * time_ratio );
4070  printf( "========================================================\n\n");
4071  }
4072 
4074  tree_ptr->DependencyCleanUp();
4076  mpi::Barrier( tree.GetComm() );
4077 
4078  return tree_ptr;
4079  }
4080  catch ( const exception & e )
4081  {
4082  cout << e.what() << endl;
4083  exit( 1 );
4084  }
4085 };
4089 template<typename TREE, typename T>
4090 pair<T, T> ComputeError( TREE &tree, size_t gid, Data<T> potentials )
4091 {
4092  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
4093  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
4094 
4096  pair<T, T> ret( 0, 0 );
4097 
4098  auto &K = *tree.setup.K;
4099  auto &w = *tree.setup.w;
4100 
4101  auto I = vector<size_t>( 1, gid );
4102  auto &J = tree.treelist[ 0 ]->gids;
4103 
4105  K.BcastIndices( I, gid % comm_size, tree.GetComm() );
4106 
4107  Data<T> Kab = K( I, J );
4108 
4109  auto loc_exact = potentials;
4110  auto glb_exact = potentials;
4111 
4112  xgemm( "N", "N", Kab.row(), w.col(), w.row(),
4113  1.0, Kab.data(), Kab.row(),
4114  w.data(), w.row(),
4115  0.0, loc_exact.data(), loc_exact.row() );
4116  //gemm::xgemm( (T)1.0, Kab, w, (T)0.0, loc_exact );
4117 
4118 
4119 
4120 
4122  mpi::Allreduce( loc_exact.data(), glb_exact.data(),
4123  loc_exact.size(), MPI_SUM, tree.GetComm() );
4124 
4125  for ( uint64_t j = 0; j < w.col(); j ++ )
4126  {
4127  T exac = glb_exact[ j ];
4128  T pred = potentials[ j ];
4130  ret.first += ( pred - exac ) * ( pred - exac );
4131  ret.second += exac * exac;
4132  }
4133 
4134  return ret;
4135 };
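/**
 *  Editorial note: for one testing gid the pair returned above accumulates,
 *  over all right-hand sides, the squared error sum ( pred - exac )^2 and the
 *  squared reference sum exac^2; callers such as SelfTesting() below form the
 *  relative error as sqrt( ret.first / ret.second ).
 */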
4146 template<typename TREE>
4147 void SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
4148 {
4150  using T = typename TREE::T;
4152  int rank; mpi::Comm_rank( tree.GetComm(), &rank );
4153  int size; mpi::Comm_size( tree.GetComm(), &size );
4155  size_t n = tree.n;
4157  if ( ntest > n ) ntest = n;
4159  vector<size_t> all_rhs( nrhs );
4160  for ( size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;
4161 
4162  //auto A = tree.CheckAllInteractions();
4163 
4165  DistData<RIDS, STAR, T> w_rids( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
4166  DistData<RBLK, STAR, T> u_rblk( n, nrhs, tree.GetComm() );
4168  w_rids.randn();
4170  auto u_rids = mpigofmm::Evaluate<true>( tree, w_rids );
4172  assert( !u_rids.HasIllegalValue() );
4174  u_rblk = u_rids;
4176  if ( rank == 0 )
4177  {
4178  printf( "========================================================\n");
4179  printf( "Accuracy report\n" );
4180  printf( "========================================================\n");
4181  }
4183  T nnerr_avg = 0.0, nonnerr_avg = 0.0, fmmerr_avg = 0.0;
4184  T sse_2norm = 0.0, ssv_2norm = 0.0;
4186  for ( size_t i = 0; i < ntest; i ++ )
4187  {
4188  size_t tar = i * n / ntest;
4189  Data<T> potentials( (size_t)1, nrhs );
4190  if ( rank == ( tar % size ) ) potentials = u_rblk( vector<size_t>( 1, tar ), all_rhs );
4192  mpi::Bcast( potentials.data(), nrhs, tar % size, tree.GetComm() );
4194  auto sse_ssv = mpigofmm::ComputeError( tree, tar, potentials );
4196  auto fmmerr = sqrt( sse_ssv.first / sse_ssv.second );
4198  fmmerr_avg += fmmerr;
4200  sse_2norm += sse_ssv.first;
4201  ssv_2norm += sse_ssv.second;
4203  if ( i < 10 && rank == 0 )
4204  {
4205  printf( "gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
4206  tar, 0.0, 0.0, fmmerr );
4207  }
4208  }
4209  if ( rank == 0 )
4210  {
4211  printf( "========================================================\n");
4212  printf( "Elementwise ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
4213  nnerr_avg / ntest , nonnerr_avg / ntest, fmmerr_avg / ntest );
4214  printf( "F-norm ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
4215  0.0, 0.0, sqrt( sse_2norm / ssv_2norm ) );
4216  printf( "========================================================\n");
4217  }
4218 
4220  T lambda = 10.0;
4221  mpigofmm::DistFactorize( tree, lambda );
4222  mpigofmm::ComputeError( tree, lambda, w_rids, u_rids );
4223 };
4227 template<typename SPDMATRIX>
4228 void LaunchHelper( SPDMATRIX &K, gofmm::CommandLineHelper &cmd, mpi::Comm CommGOFMM )
4229 {
4230  using T = typename SPDMATRIX::T;
4231  const int N_CHILDREN = 2;
4233  using SPLITTER = mpigofmm::centersplit<SPDMATRIX, N_CHILDREN, T>;
4235  using RKDTSPLITTER = mpigofmm::randomsplit<SPDMATRIX, N_CHILDREN, T>;
4236  SPLITTER splitter( K );
4237  splitter.Kptr = &K;
4238  splitter.metric = cmd.metric;
4240  RKDTSPLITTER rkdtsplitter( K );
4241  rkdtsplitter.Kptr = &K;
4242  rkdtsplitter.metric = cmd.metric;
4244  gofmm::Configuration<T> config( cmd.metric,
4245  cmd.n, cmd.m, cmd.k, cmd.s, cmd.stol, cmd.budget );
4247  DistData<STAR, CBLK, pair<T, size_t>> NN( 0, cmd.n, CommGOFMM );
4249  auto *tree_ptr = mpigofmm::Compress( K, NN, splitter, rkdtsplitter, config, CommGOFMM );
4250  auto &tree = *tree_ptr;
4251 
4253  mpigofmm::SelfTesting( tree, 100, cmd.nrhs );
4254 
4256  delete tree_ptr;
4257 };
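/**
 *  A minimal driver sketch (editorial; the SPD matrix type, the
 *  CommandLineHelper constructor, and the runtime bring-up are assumptions
 *  about the surrounding HMLP application code, not guarantees of this header):
 *
 *    int main( int argc, char *argv[] )
 *    {
 *      int provided;
 *      MPI_Init_thread( &argc, &argv, MPI_THREAD_MULTIPLE, &provided );
 *      {
 *        gofmm::CommandLineHelper cmd( argc, argv );  // assumed parsing ctor
 *        UserSPDMatrix K( cmd.n, cmd.n );             // hypothetical SPD matrix type
 *        // the HMLP runtime itself must also be initialized; see the library examples
 *        mpigofmm::LaunchHelper( K, cmd, MPI_COMM_WORLD );
 *      }
 *      MPI_Finalize();
 *      return 0;
 *    }
 */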
4260 };
4261 };
4263 #endif