HMLP: High-performance Machine Learning Primitives
tree_mpi.hpp
#ifndef MPITREE_HPP
#define MPITREE_HPP

#include <tree.hpp>
//#include <DistData.hpp>

using namespace std;
using namespace hmlp;


namespace hmlp
{
namespace mpitree
{

40 
// * @brief This is the default ball tree splitter. Given coordinates,
// *        compute the direction spanned by the two farthest points,
// *        project all points onto this line, and split them into two
// *        groups using a median select.
// *
// * @para
// *
// * @TODO Need to exploit the parallelism.
// */
//template<int N_SPLIT, typename T>
//struct centersplit
//{
//  // closure
//  Data<T> *Coordinate;
//
//  inline vector<vector<size_t> > operator()
//  (
//    vector<size_t>& gids
//  ) const
//  {
//    assert( N_SPLIT == 2 );
//
//    Data<T> &X = *Coordinate;
//    size_t d = X.row();
//    size_t n = gids.size();
//
//    T rcx0 = 0.0, rx01 = 0.0;
//    size_t x0, x1;
//    vector<vector<size_t> > split( N_SPLIT );
//
//
//    vector<T> centroid = combinatorics::Mean( d, n, X, gids );
//    vector<T> direction( d );
//    vector<T> projection( n, 0.0 );
//
//    //printf( "After Mean\n" );
//
//    // Compute the farthest point x0 from the centroid.
//    for ( int i = 0; i < n; i ++ )
//    {
//      T rcx = 0.0;
//      for ( int p = 0; p < d; p ++ )
//      {
//        T tmp = X[ gids[ i ] * d + p ] - centroid[ p ];
//        rcx += tmp * tmp;
//      }
//      //printf( "\n" );
//      if ( rcx > rcx0 )
//      {
//        rcx0 = rcx;
//        x0 = i;
//      }
//    }
//
//
//    // Compute the farthest point x1 from x0.
//    for ( int i = 0; i < n; i ++ )
//    {
//      T rxx = 0.0;
//      for ( int p = 0; p < d; p ++ )
//      {
//        T tmp = X[ gids[ i ] * d + p ] - X[ gids[ x0 ] * d + p ];
//        rxx += tmp * tmp;
//      }
//      if ( rxx > rx01 )
//      {
//        rx01 = rxx;
//        x1 = i;
//      }
//    }
//
//    // Compute direction
//    for ( int p = 0; p < d; p ++ )
//    {
//      direction[ p ] = X[ gids[ x1 ] * d + p ] - X[ gids[ x0 ] * d + p ];
//    }
//
//    // Compute projection
//    projection.resize( n, 0.0 );
//    for ( int i = 0; i < n; i ++ )
//      for ( int p = 0; p < d; p ++ )
//        projection[ i ] += X[ gids[ i ] * d + p ] * direction[ p ];
//
//    /** Parallel median search */
//    T median;
//
//    if ( 1 )
//    {
//      median = hmlp::combinatorics::Select( n, n / 2, projection );
//    }
//    else
//    {
//      auto proj_copy = projection;
//      std::sort( proj_copy.begin(), proj_copy.end() );
//      median = proj_copy[ n / 2 ];
//    }
//
//    split[ 0 ].reserve( n / 2 + 1 );
//    split[ 1 ].reserve( n / 2 + 1 );
//
//
//    /** TODO: Can be parallelized */
//    std::vector<std::size_t> middle;
//    for ( size_t i = 0; i < n; i ++ )
//    {
//      if      ( projection[ i ] < median ) split[ 0 ].push_back( i );
//      else if ( projection[ i ] > median ) split[ 1 ].push_back( i );
//      else                                 middle.push_back( i );
//    }
//
//    for ( size_t i = 0; i < middle.size(); i ++ )
//    {
//      if ( split[ 0 ].size() <= split[ 1 ].size() ) split[ 0 ].push_back( middle[ i ] );
//      else                                          split[ 1 ].push_back( middle[ i ] );
//    }
//
//
//    return split;
//  };
//
//
//  inline std::vector<std::vector<size_t> > operator()
//  (
//    std::vector<size_t>& gids,
//    hmlp::mpi::Comm comm
//  ) const
//  {
//    std::vector<std::vector<size_t> > split( N_SPLIT );
//
//    return split;
//  };
//
//};
//
//
//
//
//
//template<int N_SPLIT, typename T>
//struct randomsplit
//{
//  Data<T> *Coordinate = NULL;
//
//  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
//  {
//    vector<vector<size_t> > split( N_SPLIT );
//    return split;
//  };
//
//  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
//  {
//    vector<vector<size_t> > split( N_SPLIT );
//    return split;
//  };
//};
//

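/**
 *  A minimal serial sketch of the median step used by the disabled splitter
 *  above. That code selects the median of the projected values either with
 *  combinatorics::Select() or with a full sort; std::nth_element yields the
 *  same median in expected O(n) time instead of O(n log n). This helper is
 *  illustrative only (not part of the HMLP API) and assumes <algorithm> and
 *  <vector> are visible through tree.hpp.
 */
template<typename T>
T SketchMedianByNthElement( std::vector<T> values )
{
  /** Partially order the copy so the middle element lands in place. */
  auto mid = values.begin() + values.size() / 2;
  std::nth_element( values.begin(), mid, values.end() );
  return *mid;
}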



template<typename NODE>
class DistSplitTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "DistSplit" );
      label = to_string( arg->treelist_id );

      double flops = 6.0 * arg->n;
      double mops  = 6.0 * arg->n;

      event.Set( label + name, flops, mops );
      cost = mops / 1E+9;
      priority = true;
    };


    void DependencyAnalysis()
    {
      arg->DependencyAnalysis( R, this );

      if ( !arg->isleaf )
      {
        if ( arg->GetCommSize() > 1 )
        {
          assert( arg->child );
          arg->child->DependencyAnalysis( RW, this );
        }
        else
        {
          assert( arg->lchild && arg->rchild );
          arg->lchild->DependencyAnalysis( RW, this );
          arg->rchild->DependencyAnalysis( RW, this );
        }
      }
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker ) { arg->Split(); };

};
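/**
 *  Illustrative sketch (not HMLP API usage): the life cycle of a
 *  DistSplitTask. In practice the runtime drives these calls through
 *  DistTraverseDown() and the scheduler; invoking them directly as below is
 *  only meant to show the order of operations. Execute() ignores its Worker
 *  argument here, so passing NULL is safe in this sketch.
 */
template<typename NODE>
void SketchRunDistSplitTask( NODE *node )
{
  DistSplitTask<NODE> task;
  task.Set( node );           /** Bind the node; set name, label, and cost model. */
  task.DependencyAnalysis();  /** Register R/RW dependencies and try to enqueue. */
  task.Execute( NULL );       /** Perform the distributed split on the node. */
}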
/**
 *  @brief Data and setup that are shared with all nodes.
 */
template<typename SPLITTER, typename DATATYPE>
class Setup
{
  public:

    typedef DATATYPE T;

    Setup() {};

    ~Setup() {};


    /**
     *  @brief Check if this node contains any query using Morton IDs.
     *         Notice that queries[] contains gids.
     */
    vector<size_t> ContainAny( vector<size_t> &queries, size_t target )
    {
      vector<size_t> validation( queries.size(), 0 );

      if ( !morton.size() )
      {
        printf( "Morton id was not initialized.\n" );
        exit( 1 );
      }

      for ( size_t i = 0; i < queries.size(); i ++ )
      {
        //auto it = this->setup->morton.find( queries[ i ] );

        //if ( it != this->setup->morton.end() )
        //{
        //  if ( tree::IsMyParent( *it, this->morton ) ) validation[ i ] = 1;
        //}

        if ( MortonHelper::IsMyParent( morton[ queries[ i ] ], target ) )
          validation[ i ] = 1;

      }
      return validation;

    };

    /** Leaf node size. */
    size_t m;

    /** Maximum tree depth (15 levels by default). */
    size_t max_depth = 15;

    //DistData<STAR, CBLK, T> *X_cblk = NULL;
    //DistData<STAR, CIDS, T> *X = NULL;

    /** Morton ID of each gid. */
    vector<size_t> morton;

    /** Tree splitter. */
    SPLITTER splitter;

};
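/**
 *  Usage sketch for Setup (illustrative only): after the tree is partitioned,
 *  setup.morton[ gid ] holds the Morton ID of the leaf that owns gid, so
 *  ContainAny() reduces to one MortonHelper::IsMyParent() ancestor test per
 *  query. The helper below is hypothetical and not part of the HMLP API; it
 *  just isolates that single test.
 */
template<typename SPLITTER, typename T>
bool SketchNodeOwnsQuery( Setup<SPLITTER, T> &setup, size_t gid, size_t node_morton )
{
  /** True iff node_morton lies on the root-to-leaf path of gid's leaf. */
  return MortonHelper::IsMyParent( setup.morton[ gid ], node_morton );
}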
/**
 *  @brief Permute the order of gids of each internal node to match
 *         the order of its leaf nodes (distributed version).
 */
template<typename NODE>
class DistIndexPermuteTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      name = std::string( "Permutation" );
      arg = user_arg;
      // Need an accurate cost model.
      cost = 1.0;
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };
    //{
    //  arg->DependencyAnalysis( hmlp::ReadWriteType::RW, this );
    //  if ( !arg->isleaf && !arg->child )
    //  {
    //    arg->lchild->DependencyAnalysis( hmlp::ReadWriteType::R, this );
    //    arg->rchild->DependencyAnalysis( hmlp::ReadWriteType::R, this );
    //  }
    //  this->TryEnqueue();
    //};


    void Execute( Worker* user_worker )
    {
      if ( !arg->isleaf && !arg->child )
      {
        auto &gids  = arg->gids;
        auto &lgids = arg->lchild->gids;
        auto &rgids = arg->rchild->gids;
        gids = lgids;
        gids.insert( gids.end(), rgids.begin(), rgids.end() );
      }
    };

};
//template<typename SETUP, int N_CHILDREN, typename NODEDATA>
template<typename SETUP, typename NODEDATA>
class Node : public tree::Node<SETUP, NODEDATA>
{
  public:

    /** Deduce the data type from SETUP. */
    typedef typename SETUP::T T;

    static const int N_CHILDREN = 2;

    Node( SETUP *setup, size_t n, size_t l,
          Node *parent,
          unordered_map<size_t, tree::Node<SETUP, NODEDATA>*> *morton2node,
          Lock *treelock, mpi::Comm comm )
      : tree::Node<SETUP, NODEDATA>( setup, n, l,
          parent, morton2node, treelock )
    {
      /** Take over the local communicator and get its size and rank. */
      this->comm = comm;
      mpi::Comm_size( comm, &size );
      mpi::Comm_rank( comm, &rank );
    };

    Node( SETUP *setup, size_t n, size_t l, vector<size_t> &gids,
          Node *parent,
          unordered_map<size_t, tree::Node<SETUP, NODEDATA>*> *morton2node,
          Lock *treelock, mpi::Comm comm )
      : Node<SETUP, NODEDATA>( setup, n, l, parent,
          morton2node, treelock, comm )
    {
      /** This node also owns gids. */
      this->gids = gids;
    };

    Node( size_t morton ) : tree::Node<SETUP, NODEDATA>( morton )
    {
    };


    //void SetupChild( class Node *child )
    //{
    //  this->kids[ 0 ] = child;
    //  this->child = child;
    //};



    void Split()
    {
      /** Reduce the local gid counts to get the total problem size. */
      int num_points_total = 0;
      int num_points_owned = (this->gids).size();
      mpi::Allreduce( &num_points_owned, &num_points_total, 1, MPI_SUM, comm );
      this->n = num_points_total;

      if ( child )
      {
        /** Distributed case: this communicator has at least two ranks. */
        assert( size > 1 );

        /** Invoke the distributed splitter. */
        auto split = this->setup->splitter( this->gids, comm );

        /** Exchange gids with the partner rank in the other half. */
        int partner_rank = 0;
        int sent_size = 0;
        int recv_size = 0;
        vector<size_t> &kept_gids = child->gids;
        vector<int> sent_gids;
        vector<int> recv_gids;

        if ( rank < size / 2 )
        {
          /** The lower half keeps split[ 0 ] and sends split[ 1 ]. */
          partner_rank = rank + size / 2;
          kept_gids.resize( split[ 0 ].size() );
          for ( size_t i = 0; i < kept_gids.size(); i ++ )
            kept_gids[ i ] = this->gids[ split[ 0 ][ i ] ];
          sent_gids.resize( split[ 1 ].size() );
          sent_size = sent_gids.size();
          for ( size_t i = 0; i < sent_gids.size(); i ++ )
            sent_gids[ i ] = this->gids[ split[ 1 ][ i ] ];
        }
        else
        {
          /** The upper half keeps split[ 1 ] and sends split[ 0 ]. */
          partner_rank = rank - size / 2;
          kept_gids.resize( split[ 1 ].size() );
          for ( size_t i = 0; i < kept_gids.size(); i ++ )
            kept_gids[ i ] = this->gids[ split[ 1 ][ i ] ];
          sent_gids.resize( split[ 0 ].size() );
          sent_size = sent_gids.size();
          for ( size_t i = 0; i < sent_gids.size(); i ++ )
            sent_gids[ i ] = this->gids[ split[ 0 ][ i ] ];
        }
        assert( partner_rank >= 0 );




        //mpi::Sendrecv( &sent_size, 1, partner_rank, 10,
        //    &recv_size, 1, partner_rank, 10, comm, &status );

        //printf( "rank %d kept_size %lu sent_size %d recv_size %d\n",
        //    rank, kept_gids.size(), sent_size, recv_size ); fflush( stdout );

        //recv_gids.resize( recv_size );

        //mpi::Sendrecv(
        //    sent_gids.data(), sent_size, MPI_INT, partner_rank, 20,
        //    recv_gids.data(), recv_size, MPI_INT, partner_rank, 20,
        //    comm, &status );

        mpi::ExchangeVector( sent_gids, partner_rank, 20,
                             recv_gids, partner_rank, 20, comm, &status );
        /** Append the received gids to the child. */
        for ( auto it : recv_gids ) kept_gids.push_back( it );
        //kept_gids.reserve( kept_gids.size() + recv_gids.size() );
        //for ( size_t i = 0; i < recv_gids.size(); i ++ )
        //  kept_gids.push_back( recv_gids[ i ] );


      }
      else
      {
      }
      mpi::Barrier( comm );
    };

    /** Support dependency analysis. */
    void DependOnChildren( Task *task )
    {
      this->DependencyAnalysis( RW, task );
      if ( size < 2 )
      {
        if ( this->lchild ) this->lchild->DependencyAnalysis( R, task );
        if ( this->rchild ) this->rchild->DependencyAnalysis( R, task );
      }
      else
      {
        if ( child ) child->DependencyAnalysis( R, task );
      }
      /** Try to enqueue if there is no dependency left. */
      task->TryEnqueue();
    };

    void DependOnParent( Task *task )
    {
      this->DependencyAnalysis( R, task );
      if ( size < 2 )
      {
        if ( this->lchild ) this->lchild->DependencyAnalysis( RW, task );
        if ( this->rchild ) this->rchild->DependencyAnalysis( RW, task );
      }
      else
      {
        if ( child ) child->DependencyAnalysis( RW, task );
      }
      task->TryEnqueue();
    };

    /** Preserved for debugging. */
    void Print() {};

    mpi::Comm GetComm() { return comm; };
    int GetCommSize() { return size; };
    int GetCommRank() { return rank; };

    /** A distributed tree node has at most one distributed child. */
    Node *child = NULL;

  private:

    /** Initialized with all processes. */
    mpi::Comm comm = MPI_COMM_WORLD;
    mpi::Status status;
    /** Size and rank of the local communicator. */
    int size = 1;
    int rank = 0;

};
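/**
 *  Sketch of the size-then-payload exchange that Split() performs through
 *  mpi::ExchangeVector(). The disabled mpi::Sendrecv calls above spell out
 *  the same two-phase pattern; this raw-MPI version is illustrative only and
 *  assumes <mpi.h> is visible through HMLP's MPI wrapper.
 */
inline void SketchExchangeGids( std::vector<int> &sent_gids,
    std::vector<int> &recv_gids, int partner_rank, MPI_Comm comm )
{
  MPI_Status status;
  int sent_size = (int)sent_gids.size();
  int recv_size = 0;
  /** Phase 1: exchange sizes so the receive buffer can be allocated. */
  MPI_Sendrecv( &sent_size, 1, MPI_INT, partner_rank, 10,
                &recv_size, 1, MPI_INT, partner_rank, 10, comm, &status );
  recv_gids.resize( recv_size );
  /** Phase 2: exchange the gid payloads themselves. */
  MPI_Sendrecv( sent_gids.data(), sent_size, MPI_INT, partner_rank, 20,
                recv_gids.data(), recv_size, MPI_INT, partner_rank, 20,
                comm, &status );
}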
/**
 *  @brief This distributed tree inherits the shared-memory tree with
 *         some additional MPI data structures and functionality.
 */
template<class SETUP, class NODEDATA>
class Tree : public tree::Tree<SETUP, NODEDATA>,
             public mpi::MPIObject
{
  public:

    typedef typename SETUP::T T;

    /** Shared-memory tree node type and its distributed counterpart. */
    typedef tree::Node<SETUP, NODEDATA> NODE;
    typedef Node<SETUP, NODEDATA> MPINODE;

    /** Distributed tree nodes, one per level of the distributed tree. */
    vector<MPINODE*> mpitreelists;


    Tree( mpi::Comm comm ) : tree::Tree<SETUP, NODEDATA>::Tree(),
                             mpi::MPIObject( comm )
    {
      //this->comm = comm;
      //mpi::Comm_size( comm, &size );
      //mpi::Comm_rank( comm, &rank );
      //NearRecvFrom.resize( size );
      NearRecvFrom.resize( this->GetCommSize() );
      //FarRecvFrom.resize( size );
      FarRecvFrom.resize( this->GetCommSize() );
    };

    ~Tree()
    {
      //printf( "~Tree() distributed, mpitreelists.size() %lu\n",
      //    mpitreelists.size() ); fflush( stdout );
      /** The last distributed node is the local root, freed by the base class. */
      if ( mpitreelists.size() )
      {
        for ( size_t i = 0; i < mpitreelists.size() - 1; i ++ )
          if ( mpitreelists[ i ] ) delete mpitreelists[ i ];
        mpitreelists.clear();
      }
      //printf( "end ~Tree() distributed\n" ); fflush( stdout );
    };


    /**
     *  @brief Free all tree nodes, including local tree nodes,
     *         distributed tree nodes, and LET nodes.
     */
    void CleanUp()
    {
      /** Free local tree nodes. */
      if ( this->treelist.size() )
      {
        for ( size_t i = 0; i < this->treelist.size(); i ++ )
          if ( this->treelist[ i ] ) delete this->treelist[ i ];
      }
      this->treelist.clear();

      /** Free distributed tree nodes (the last one is the local root). */
      if ( mpitreelists.size() )
      {
        for ( size_t i = 0; i < mpitreelists.size() - 1; i ++ )
          if ( mpitreelists[ i ] ) delete mpitreelists[ i ];
      }
      mpitreelists.clear();

    };

    /** Allocate distributed tree nodes by recursive communicator splitting. */
    void AllocateNodes( vector<size_t> &gids )
    {
      //auto mycomm = comm;
      auto mycomm = this->GetComm();
      //int mysize = size;
      int mysize = this->GetCommSize();
      //int myrank = rank;
      int myrank = this->GetCommRank();
      int mycolor = 0;
      size_t mylevel = 0;

      /** Allocate the root of the distributed tree. */
      auto *root = new MPINODE( &(this->setup),
          this->n, mylevel, gids, NULL,
          &(this->morton2node), &(this->lock), mycomm );

      mpitreelists.push_back( root );

      /** Recursively split the communicator until one rank is left. */
      while ( mysize > 1 )
      {
        mpi::Comm childcomm;

        mylevel += 1;
        /** The lower half gets color 0; the upper half gets color 1. */
        mycolor = ( myrank < mysize / 2 ) ? 0 : 1;
        ierr = mpi::Comm_split( mycomm, mycolor, myrank, &(childcomm) );
        mycomm = childcomm;
        mpi::Comm_size( mycomm, &mysize );
        mpi::Comm_rank( mycomm, &myrank );

        /** Create the child node attached to the new subcommunicator. */
        auto *parent = mpitreelists.back();
        auto *child = new MPINODE( &(this->setup),
            (size_t)0, mylevel, parent,
            &(this->morton2node), &(this->lock), mycomm );

        child->sibling = new NODE( (size_t)0 ); // Node morton is computed later.
        //parent->SetupChild( child );
        parent->kids[ 0 ] = child;
        parent->child = child;
        mpitreelists.push_back( child );
      }
      this->Barrier();

      /** The last distributed node is the root of the local tree. */
      auto *local_tree_root = mpitreelists.back();

    };
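    /**
     *  Worked example of the splitting loop above with 8 ranks: rank 5 starts
     *  in { 0..7 } and picks color 1 (since 5 >= 8/2), landing in { 4..7 }
     *  with new rank 1; at size 4 it picks color 0 (1 < 2), landing in
     *  { 4, 5 } with rank 1; at size 2 it picks color 1 (1 >= 1), landing in
     *  { 5 } alone. The loop then stops, so mpitreelists holds one node per
     *  level 0..3 and the last node is the root of this rank's local tree.
     */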
    /** Gather the gid permutation of the whole tree onto rank 0. */
    vector<size_t> GetPermutation()
    {
      vector<size_t> perm_loc, perm_glb;
      mpi::GatherVector( perm_loc, perm_glb, 0, this->GetComm() );

      //if ( rank == 0 )
      //{
      //  /** Sanity check using an 0:N-1 table. */
      //  vector<bool> Table( this->n, false );
      //  for ( size_t i = 0; i < perm_glb.size(); i ++ )
      //    Table[ perm_glb[ i ] ] = true;
      //  for ( size_t i = 0; i < Table.size(); i ++ ) assert( Table[ i ] );
      //}

      return perm_glb;
    };

    /** Approximate all-nearest-neighbor search with n_tree randomized trees. */
    template<typename KNNTASK>
    DistData<STAR, CBLK, pair<T, size_t>>
    AllNearestNeighbor( size_t n_tree, size_t n, size_t k,
        pair<T, size_t> initNN, KNNTASK &dummy )
    {
      mpi::PrintProgress( "[BEG] NeighborSearch ...", this->GetComm() );

      /** Set the problem size and allocate neighbors in CBLK distribution. */
      this->n = n;
      DistData<STAR, CBLK, pair<T, size_t>> NN( k, n, initNN, this->GetComm() );
      /** Use a leaf node size of max( 4k, 512 ) for the search. */
      this->setup.m = 4 * k;
      if ( this->setup.m < 512 ) this->setup.m = 512;
      this->m = this->setup.m;


      //num_points_owned = ( n - 1 ) / this->GetCommSize() + 1;

      //if ( n % this->GetCommSize() )
      //{
      //  //if ( rank >= ( n % size ) ) num_points_owned -= 1;
      //  if ( this->GetCommRank() >= ( n % this->GetCommSize() ) )
      //    num_points_owned -= 1;
      //}

      /** Cyclic distribution; the first (n mod size) ranks own one extra point. */
      num_points_owned = n / this->GetCommSize();
      if ( this->GetCommRank() < ( n % this->GetCommSize() ) )
        num_points_owned += 1;



      /** Gather the gids owned by this rank. */
      vector<size_t> gids( num_points_owned, 0 );
      for ( size_t i = 0; i < num_points_owned; i ++ )
        //gids[ i ] = i * size + rank;
        gids[ i ] = i * this->GetCommSize() + this->GetCommRank();

      /** Allocate distributed tree nodes in advance. */
      AllocateNodes( gids );


      DistSplitTask<MPINODE> mpisplittask;
      tree::SplitTask<NODE> seqsplittask;
      for ( size_t t = 0; t < n_tree; t ++ )
      {
        DistTraverseDown( mpisplittask );
        LocaTraverseDown( seqsplittask );
        ExecuteAllTasks();

        /** Query neighbors are computed in CIDS distribution. */
        DistData<STAR, CIDS, pair<T, size_t>> Q_cids( k, this->n,
            this->treelist[ 0 ]->gids, initNN, this->GetComm() );
        /** Pass in the neighbor pointer. */
        this->setup.NN = &Q_cids;
        LocaTraverseLeafs( dummy );
        ExecuteAllTasks();

        /** Redistribute from CIDS to CBLK and merge the new candidates. */
        DistData<STAR, CBLK, pair<T, size_t>> Q_cblk( k, this->n, this->GetComm() );
        Q_cblk = Q_cids;
        assert( Q_cblk.col_owned() == NN.col_owned() );
        MergeNeighbors( k, NN.col_owned(), NN, Q_cblk );
      }






//      /** Metric tree partitioning. */
//      DistSplitTask<MPINODE> mpisplittask;
//      tree::SplitTask<NODE> seqsplittask;
//      DependencyCleanUp();
//      DistTraverseDown( mpisplittask );
//      LocaTraverseDown( seqsplittask );
//      ExecuteAllTasks();
//
//
//      for ( size_t t = 0; t < n_tree; t ++ )
//      {
//        this->Barrier();
//        //if ( this->GetCommRank() == 0 ) printf( "Iteration #%lu\n", t );
//
//        /** Query neighbors computed in CIDS distribution. */
//        DistData<STAR, CIDS, pair<T, size_t>> Q_cids( k, this->n,
//            this->treelist[ 0 ]->gids, initNN, this->GetComm() );
//        /** Pass in neighbor pointer. */
//        this->setup.NN = &Q_cids;
//        /** Overlap */
//        if ( t != n_tree - 1 )
//        {
//          //DependencyCleanUp();
//          DistTraverseDown( mpisplittask );
//          ExecuteAllTasks();
//        }
//        mpi::PrintProgress( "[MID] Here ...", this->GetComm() );
//        DependencyCleanUp();
//        LocaTraverseLeafs( dummy );
//        LocaTraverseDown( seqsplittask );
//        ExecuteAllTasks();
//        mpi::PrintProgress( "[MID] Here 22...", this->GetComm() );
//
//        if ( t == 0 )
//        {
//          /** Redistribute from CIDS to CBLK */
//          NN = Q_cids;
//        }
//        else
//        {
//          /** Queries computed in CBLK distribution */
//          DistData<STAR, CBLK, pair<T, size_t>> Q_cblk( k, this->n, this->GetComm() );
//          /** Redistribute from CIDS to CBLK */
//          Q_cblk = Q_cids;
//          /** Merge Q_cblk into NN (sort and remove duplication) */
//          assert( Q_cblk.col_owned() == NN.col_owned() );
//          MergeNeighbors( k, NN.col_owned(), NN, Q_cblk );
//        }
//
//        //double mer_time = omp_get_wtime() - beg;
//
//        //if ( rank == 0 )
//        //printf( "%lfs %lfs %lfs\n", mpi_time, seq_time, mer_time ); fflush( stdout );
//      }

      /** Sanity check: every neighbor gid must be in range. */
      for ( auto &neig : NN )
      {
        if ( neig.second < 0 || neig.second >= NN.col() )
        {
          printf( "Illegal neighbor gid %lu\n", neig.second );
          break;
        }
      }

      mpi::PrintProgress( "[END] NeighborSearch ...", this->GetComm() );
      return NN;
    };
    /** Partition n points using a distributed binary tree. */
    void TreePartition()
    {
      mpi::PrintProgress( "[BEG] TreePartitioning ...", this->GetComm() );

      /** Set up the problem size and leaf node size. */
      this->n = this->setup.ProblemSize();
      this->m = this->setup.LeafNodeSize();

      /** Initial gids distribution (cyclic). */
      //for ( size_t i = rank; i < this->n; i += size )
      for ( size_t i = this->GetCommRank(); i < this->n; i += this->GetCommSize() )
        this->global_indices.push_back( i );
      /** Local problem size. */
      num_points_owned = this->global_indices.size();
      /** Allocate distributed tree nodes in advance. */
      AllocateNodes( this->global_indices );



      DependencyCleanUp();




      DistSplitTask<MPINODE> mpiSPLITtask;
      tree::SplitTask<NODE> seqSPLITtask;
      DistTraverseDown( mpiSPLITtask );
      LocaTraverseDown( seqSPLITtask );
      ExecuteAllTasks();



      tree::IndexPermuteTask<NODE> seqINDXtask;
      LocaTraverseUp( seqINDXtask );
      DistIndexPermuteTask<MPINODE> mpiINDXtask;
      DistTraverseUp( mpiINDXtask );
      ExecuteAllTasks();

      //printf( "rank %d finish split\n", rank ); fflush( stdout );


      /** Allocate space for the Morton ID of each gid. */
      (this->setup).morton.resize( this->n );

      /** Compute Morton IDs recursively from the distributed root. */
      RecursiveMorton( mpitreelists[ 0 ], MortonHelper::Root() );

      /** Clear and rebuild the Morton-to-node map. */
      this->morton2node.clear();

      /** Map local tree nodes. */
      for ( auto node : this->treelist ) this->morton2node[ node->morton ] = node;

      /** Map distributed tree nodes and their siblings. */
      for ( auto node : this->mpitreelists )
      {
        this->morton2node[ node->morton ] = node;
        auto *sibling = node->sibling;
        if ( node->l ) this->morton2node[ sibling->morton ] = sibling;
      }

      this->Barrier();
      mpi::PrintProgress( "[END] TreePartitioning ...", this->GetComm() );
    };
    /** Assign Morton IDs to the distributed and local tree nodes recursively. */
    void RecursiveMorton( MPINODE *node, MortonHelper::Recursor r )
    {
      /** Sizes and ranks of the global and per-node communicators. */
      int comm_size = this->GetCommSize();
      int comm_rank = this->GetCommRank();
      int node_size = node->GetCommSize();
      int node_rank = node->GetCommRank();

      /** Set the node (and sibling) Morton ID. */
      node->morton = MortonHelper::MortonID( r );
      if ( node->sibling )
        node->sibling->morton = MortonHelper::SiblingMortonID( r );

      if ( node_size < 2 )
      {
        /** Base case: allgather the (gid, morton) pairs of all ranks. */
        auto &gids = this->treelist[ 0 ]->gids;
        vector<int> recv_size( comm_size, 0 );
        vector<int> recv_disp( comm_size, 0 );
        vector<pair<size_t, size_t>> send_pairs;
        vector<pair<size_t, size_t>> recv_pairs( this->n );

        /** Gather the (gid, morton) pairs of this rank. */
        for ( auto it : gids )
        {
          send_pairs.push_back(
              pair<size_t, size_t>( it, this->setup.morton[ it ] ) );
        }

        /** Exchange the pair counts and compute displacements. */
        int send_size = send_pairs.size();
        mpi::Allgather( &send_size, 1, recv_size.data(), 1, this->GetComm() );
        for ( size_t p = 1; p < comm_size; p ++ )
        {
          recv_disp[ p ] = recv_disp[ p - 1 ] + recv_size[ p - 1 ];
        }
        /** The total number of gids must match the problem size. */
        size_t total_gids = 0;
        for ( size_t p = 0; p < comm_size; p ++ )
        {
          total_gids += recv_size[ p ];
        }
        assert( total_gids == this->n );
        /** Allgather the pairs and scatter them into setup.morton. */
        mpi::Allgatherv( send_pairs.data(), send_size,
            recv_pairs.data(), recv_size.data(), recv_disp.data(), this->GetComm() );
        for ( auto it : recv_pairs ) this->setup.morton[ it.first ] = it.second;
      }
      else
      {
        /** Recur to the child on the left or right half of the communicator. */
        if ( node_rank < node_size / 2 )
        {
          RecursiveMorton( node->child, MortonHelper::RecurLeft( r ) );
        }
        else
        {
          RecursiveMorton( node->child, MortonHelper::RecurRight( r ) );
        }
      }
    };
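    /**
     *  Bookkeeping example for the Allgatherv above (hypothetical sizes):
     *  with 3 ranks and recv_size = { 3, 2, 4 }, the exclusive prefix sum
     *  gives recv_disp = { 0, 3, 5 }, so rank p's (gid, morton) pairs land in
     *  recv_pairs[ recv_disp[ p ] .. recv_disp[ p ] + recv_size[ p ] - 1 ],
     *  and total_gids = 3 + 2 + 4 = 9 must equal this->n.
     */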
    /** Aggregate all near/far interactions into a dense matrix for checking. */
    Data<int> CheckAllInteractions()
    {
      /** Get the total depth of the tree. */
      int total_depth = this->treelist.back()->l;
      /** Number of leaf nodes. */
      int num_leafs = 1 << total_depth;
      /** Local interaction matrix A and its reduction B. */
      Data<int> A( num_leafs, num_leafs, 0 );
      Data<int> B( num_leafs, num_leafs, 0 );
      /** Local tree nodes (skip the root at t = 0). */
      for ( int t = 1; t < this->treelist.size(); t ++ )
      {
        auto *node = this->treelist[ t ];
        //for ( auto it : node->NNNearNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets( it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}
        //for ( auto it : node->NNFarNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets( it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}

        for ( int p = 0; p < this->GetCommSize(); p ++ )
        {
          if ( node->isleaf )
          {
            for ( auto & it : node->DistNear[ p ] )
            {
              auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
              auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
              for ( auto i : I ) for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                A( i, j ) += 1;
              }
            }
          }
          for ( auto & it : node->DistFar[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
            for ( auto i : I ) for ( auto j : J )
            {
              assert( i < num_leafs && j < num_leafs );
              A( i, j ) += 1;
            }
          }
        }
      }

      /** Distributed tree nodes. */
      for ( auto *node : mpitreelists )
      {
        //for ( auto it : node->NNNearNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets( it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}
        //for ( auto it : node->NNFarNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets( it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}
        for ( int p = 0; p < this->GetCommSize(); p ++ )
        {
          if ( node->isleaf )
          {
            for ( auto & it : node->DistNear[ p ] )
            {
              auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
              auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
              for ( auto i : I ) for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                A( i, j ) += 1;
              }
            }
          }
          for ( auto & it : node->DistFar[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
            for ( auto i : I ) for ( auto j : J )
            {
              assert( i < num_leafs && j < num_leafs );
              A( i, j ) += 1;
            }
          }
        }
      }

      /** Reduce all interaction matrices to rank 0. */
      mpi::Reduce( A.data(), B.data(), A.size(), MPI_SUM, 0, this->GetComm() );

      if ( this->GetCommRank() == 0 )
      {
        for ( size_t i = 0; i < num_leafs; i ++ )
        {
          for ( size_t j = 0; j < num_leafs; j ++ ) printf( "%d", B( i, j ) );
          printf( "\n" );
        }
      }

      return B;
    };
    /** Translate a Morton ID to the rank that owns it. */
    int Morton2Rank( size_t it )
    {
      return MortonHelper::Morton2Rank( it, this->GetCommSize() );
    };

    /** Translate a gid to the rank that owns it (via its Morton ID). */
    int Index2Rank( size_t gid )
    {
      return Morton2Rank( this->setup.morton[ gid ] );
    };

    /** Bottom-up level-by-level traversal of the local tree. */
    template<typename TASK, typename... Args>
    void LocaTraverseUp( TASK &dummy, Args&... args )
    {
      /** The local root must exist. */
      assert( this->treelist.size() );

      //printf( "depth %lu\n", this->depth ); fflush( stdout );

      for ( int l = this->depth; l >= 1; l -- )
      {
        size_t n_nodes = 1 << l;
        auto level_beg = this->treelist.begin() + n_nodes - 1;

        for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
        {
          auto *node = *(level_beg + node_ind);
          RecuTaskSubmit( node, dummy, args... );
        }
      }
    };
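    /**
     *  Indexing example: treelist stores a complete binary tree in heap
     *  order, so level l occupies indices [ 2^l - 1, 2^(l+1) - 2 ]. The root
     *  is treelist[ 0 ], level 1 is treelist[ 1 .. 2 ], level 2 is
     *  treelist[ 3 .. 6 ], and the leaves at depth d start at index 2^d - 1,
     *  which is exactly the level_beg arithmetic used by these traversals.
     */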
    /** Bottom-up traversal of the distributed path (local root to global root). */
    template<typename TASK, typename... Args>
    void DistTraverseUp( TASK &dummy, Args&... args )
    {
      MPINODE *node = mpitreelists.back();
      while ( node )
      {
        if ( this->DoOutOfOrder() ) RecuTaskSubmit( node, dummy, args... );
        else RecuTaskExecute( node, dummy, args... );
        /** Move up to the parent. */
        node = (MPINODE*)node->parent;
      }
    };

    /** Top-down level-by-level traversal of the local tree. */
    template<typename TASK, typename... Args>
    void LocaTraverseDown( TASK &dummy, Args&... args )
    {
      /** The local root must exist. */
      assert( this->treelist.size() );

      for ( int l = 1; l <= this->depth; l ++ )
      {
        size_t n_nodes = 1 << l;
        auto level_beg = this->treelist.begin() + n_nodes - 1;

        for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
        {
          auto *node = *(level_beg + node_ind);
          RecuTaskSubmit( node, dummy, args... );
        }
      }
    };

    /** Top-down traversal of the distributed path (global root to local root). */
    template<typename TASK, typename... Args>
    void DistTraverseDown( TASK &dummy, Args&... args )
    {
      auto *node = mpitreelists.front();
      while ( node )
      {
        //printf( "now at level %lu\n", node->l ); fflush( stdout );
        if ( this->DoOutOfOrder() ) RecuTaskSubmit( node, dummy, args... );
        else RecuTaskExecute( node, dummy, args... );
        //printf( "RecuTaskSubmit at level %lu\n", node->l ); fflush( stdout );

        node = node->child;
      }
    };

    /** Submit a task to every leaf node of the local tree. */
    template<typename TASK, typename... Args>
    void LocaTraverseLeafs( TASK &dummy, Args&... args )
    {
      /** The local root must exist. */
      assert( this->treelist.size() );

      int n_nodes = 1 << this->depth;
      auto level_beg = this->treelist.begin() + n_nodes - 1;

      for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
      {
        auto *node = *(level_beg + node_ind);
        RecuTaskSubmit( node, dummy, args... );
      }
    };

    /** For unordered traversal, we just call the local downward traversal. */
    template<typename TASK, typename... Args>
    void LocaTraverseUnOrdered( TASK &dummy, Args&... args )
    {
      LocaTraverseDown( dummy, args... );
    };

    /** For unordered traversal, we just call the distributed downward traversal. */
    template<typename TASK, typename... Args>
    void DistTraverseUnOrdered( TASK &dummy, Args&... args )
    {
      DistTraverseDown( dummy, args... );
    };
    void DependencyCleanUp()
    {
      for ( auto node : mpitreelists ) node->DependencyCleanUp();
      //for ( size_t i = 0; i < mpitreelists.size(); i ++ )
      //{
      //  mpitreelists[ i ]->DependencyCleanUp();
      //}

      /** Iterate by reference; a by-value loop would only clean up copies. */
      for ( auto &p : NearRecvFrom ) p.DependencyCleanUp();
      for ( auto &p : FarRecvFrom ) p.DependencyCleanUp();

    };
    /** Run the runtime, wait at a barrier, and reset all dependencies. */
    void ExecuteAllTasks()
    {
      hmlp_run();
      this->Barrier();
      DependencyCleanUp();
    };
    void DependOnNearInteractions( int p, Task *task )
    {
      /** Describe the dependencies of all nodes sent to rank p. */
      for ( auto it : NearSentToRank[ p ] )
      {
        auto *node = this->morton2node[ it ];
        node->DependencyAnalysis( R, task );
      }
      /** Try to enqueue if there is no dependency left. */
      task->TryEnqueue();
    };

    void DependOnFarInteractions( int p, Task *task )
    {
      /** Describe the dependencies of all nodes sent to rank p. */
      for ( auto it : FarSentToRank[ p ] )
      {
        auto *node = this->morton2node[ it ];
        node->DependencyAnalysis( R, task );
      }
      /** Try to enqueue if there is no dependency left. */
      task->TryEnqueue();
    };

    /** Per-rank bookkeeping of near and far interactions. */
    vector<vector<size_t>> NearSentToRank;
    vector<map<size_t, int>> NearRecvFromRank;
    vector<vector<size_t>> FarSentToRank;
    vector<map<size_t, int>> FarRecvFromRank;

    vector<ReadWrite> NearRecvFrom;
    vector<ReadWrite> FarRecvFrom;

  private:

    /** Error code of MPI calls. */
    int ierr = 0;
    /** Number of points owned locally. */
    size_t num_points_owned = 0;

};
}; /** end namespace mpitree */
}; /** end namespace hmlp */
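/**
 *  End-to-end usage sketch (illustrative only; MySetup and MyNodeData are
 *  hypothetical placeholders for a concrete SETUP/NODEDATA pair such as the
 *  ones defined in gofmm.hpp, and knntask for a KNNTASK):
 *
 *    hmlp::mpitree::Tree<MySetup, MyNodeData> tree( MPI_COMM_WORLD );
 *    tree.TreePartition();                      // distributed partitioning
 *    auto NN = tree.AllNearestNeighbor( 10, n, k,
 *        std::pair<T, size_t>( std::numeric_limits<T>::max(), n ), knntask );
 */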
#endif /** define MPITREE_HPP */