/** Task that splits a distributed tree node (class declaration abridged in this excerpt). */
template<typename NODE>
/* ... */

    void Set( NODE *user_arg )
    {
      /* ... */
      name  = string( "DistSplit" );
      label = to_string( arg->treelist_id );
      /* Rough cost estimates used by the runtime scheduler. */
      double flops = 6.0 * arg->n;
      double mops  = 6.0 * arg->n;
      /* ... */
      event.Set( label + name, flops, mops );
      /* ... */
    };

    void DependencyAnalysis()
    {
      arg->DependencyAnalysis( R, this );
      /* ... */
      if ( arg->GetCommSize() > 1 )
      {
        /* Distributed node: the split writes the single distributed child. */
        assert( arg->child );
        arg->child->DependencyAnalysis( RW, this );
      }
      else
      {
        /* Local node: the split writes both local children. */
        assert( arg->lchild && arg->rchild );
        arg->lchild->DependencyAnalysis( RW, this );
        arg->rchild->DependencyAnalysis( RW, this );
      }
      /* ... */
    };

    void Execute( Worker* user_worker ) { arg->Split(); };
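/* --------------------------------------------------------------------------
 * Standalone sketch (not the HMLP API). DependencyAnalysis() above touches a
 * single distributed child when the node's communicator still spans several
 * ranks, and both local children otherwise. The toy program below, with
 * made-up names (ToyNode, children_touched_by_split), only mirrors that
 * branching.
 * -------------------------------------------------------------------------- */
#include <cassert>
#include <cstdio>

struct ToyNode
{
  int comm_size = 1;            /* ranks sharing this node                 */
  ToyNode *child  = nullptr;    /* distributed child (when comm_size > 1)  */
  ToyNode *lchild = nullptr;    /* local children (when comm_size == 1)    */
  ToyNode *rchild = nullptr;
};

/* How many child nodes a split of `node` writes to. */
int children_touched_by_split( const ToyNode *node )
{
  if ( node->comm_size > 1 )
  {
    assert( node->child );
    return 1;                   /* only the single distributed child       */
  }
  assert( node->lchild && node->rchild );
  return 2;                     /* both local children                     */
}

int main()
{
  ToyNode l, r, parent;
  parent.lchild = &l;
  parent.rchild = &r;
  std::printf( "%d\n", children_touched_by_split( &parent ) );   /* prints 2 */
  return 0;
}
/* ----------------------- end of standalone sketch ------------------------ */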
/** Data and setup that are shared with all nodes (class declaration abridged). */
template<typename SPLITTER, typename DATATYPE>
/* ... */

    /** Check whether this node contains any of the query gids, using Morton ids. */
    vector<size_t> ContainAny( vector<size_t> &queries, size_t target )
    {
      vector<size_t> validation( queries.size(), 0 );
      if ( !morton.size() )
      {
        printf( "Morton id was not initialized.\n" );
        /* ... */
      }
      for ( size_t i = 0; i < queries.size(); i ++ )
      {
        /* ... */
        if ( MortonHelper::IsMyParent( morton[ queries[ i ] ], target ) )
        {
          /* ... (mark validation[ i ]) ... */
        }
      }
      /* ... */
    };

    /* ... */
    size_t max_depth = 15;
    /* ... */
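/* --------------------------------------------------------------------------
 * Standalone sketch. ContainAny() above leans on MortonHelper::IsMyParent()
 * to test whether `target` is an ancestor of a point's Morton id. The test
 * below shows one common way to implement such a check, under my own assumed
 * encoding (path bits in the high part, depth in the low bits); HMLP's
 * MortonHelper may encode ids differently.
 * -------------------------------------------------------------------------- */
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr uint64_t DEPTH_BITS = 5;   /* hypothetical: id = (path << 5) | depth */

uint64_t make_id( uint64_t path, uint64_t depth )
{
  return ( path << DEPTH_BITS ) | depth;
}

/* True iff target's root-to-node path is a prefix of me's path. */
bool is_my_parent( uint64_t me, uint64_t target )
{
  uint64_t my_depth     = me     & ( ( 1ull << DEPTH_BITS ) - 1 );
  uint64_t target_depth = target & ( ( 1ull << DEPTH_BITS ) - 1 );
  if ( target_depth > my_depth ) return false;
  uint64_t my_path     = me     >> DEPTH_BITS;
  uint64_t target_path = target >> DEPTH_BITS;
  return ( my_path >> ( my_depth - target_depth ) ) == target_path;
}

int main()
{
  uint64_t node  = make_id( 0b10,  2 );   /* a node two levels below the root */
  uint64_t point = make_id( 0b101, 3 );   /* a point one level deeper         */
  assert(  is_my_parent( point, node ) );
  assert( !is_my_parent( node, point ) );
  std::printf( "ancestry checks passed\n" );
  return 0;
}
/* ----------------------- end of standalone sketch ------------------------ */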
/** Task that rebuilds each internal node's gids from its children, i.e. permutes them into leaf order (class declaration abridged). */
template<typename NODE>
/* ... */

    void Set( NODE *user_arg )
    {
      /* ... */
      name = std::string( "Permutation" );
      /* ... */
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };

    void Execute( Worker* user_worker )
    {
      if ( !arg->isleaf && !arg->child )
      {
        auto &gids  = arg->gids;
        auto &lgids = arg->lchild->gids;
        auto &rgids = arg->rchild->gids;
        /* ... (gids take the left-child block first) ... */
        gids.insert( gids.end(), rgids.begin(), rgids.end() );
      }
    };
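/* --------------------------------------------------------------------------
 * Standalone sketch. The permutation task rebuilds an internal node's gids
 * from its two children in leaf order; the elided line above presumably
 * copies the left-child block before the right one is appended. A minimal
 * version of that concatenation:
 * -------------------------------------------------------------------------- */
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
  std::vector<std::size_t> lgids = { 0, 2, 4 };   /* left leaf block  */
  std::vector<std::size_t> rgids = { 1, 3, 5 };   /* right leaf block */

  std::vector<std::size_t> gids = lgids;                   /* left block first */
  gids.insert( gids.end(), rgids.begin(), rgids.end() );   /* then the right   */

  for ( auto g : gids ) std::printf( "%zu ", g );
  std::printf( "\n" );                                     /* 0 2 4 1 3 5 */
  return 0;
}
/* ----------------------- end of standalone sketch ------------------------ */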
template<typename SETUP, typename NODEDATA>
class Node : public tree::Node<SETUP, NODEDATA>
{
  public:
    /* ... */
    typedef typename SETUP::T T;
    /* ... */
    static const int N_CHILDREN = 2;
    /* ... */

    /** Constructor for an inner (distributed) node. */
    Node( SETUP *setup, size_t n, size_t l, Node *parent,
          unordered_map<size_t, tree::Node<SETUP, NODEDATA>*> *morton2node,
          Lock *treelock, mpi::Comm comm )
      : tree::Node<SETUP, NODEDATA>( setup, n, l,
          parent, morton2node, treelock )
    {
      /* ... */
      mpi::Comm_size( comm, &size );
      mpi::Comm_rank( comm, &rank );
      /* ... */
    };
    /** Constructor that also receives the gids owned by this node. */
    Node( SETUP *setup, size_t n, size_t l, vector<size_t> &gids, Node *parent,
          unordered_map<size_t, tree::Node<SETUP, NODEDATA>*> *morton2node,
          Lock *treelock, mpi::Comm comm )
      : Node<SETUP, NODEDATA>( setup, n, l, parent,
          morton2node, treelock, comm )
    {
      /* ... */
    };

    /** Constructor for LET nodes, which are identified only by their Morton id. */
    Node( size_t morton ) : tree::Node<SETUP, NODEDATA>( morton )
    {
      /* ... */
    };
    /** Split the points of this node between the two halves of its communicator. */
    void Split()
    {
      /* ... */
      int num_points_total = 0;
      int num_points_owned = (this->gids).size();
      /* The node size n is the sum of the gids owned by every rank in comm. */
      mpi::Allreduce( &num_points_owned, &num_points_total, 1, MPI_SUM, comm );
      this->n = num_points_total;
      /* ... */
      auto split = this->setup->splitter( this->gids, comm );
      /* ... */
      int partner_rank = 0;
      int sent_size = 0;
      /* ... */
      vector<size_t> &kept_gids = child->gids;
      vector<int> sent_gids;
      vector<int> recv_gids;
      /* ... */
      if ( rank < size / 2 )
      {
        /* Lower half: keep split[ 0 ], send split[ 1 ] to rank + size / 2. */
        partner_rank = rank + size / 2;
        /* ... */
        kept_gids.resize( split[ 0 ].size() );
        for ( size_t i = 0; i < kept_gids.size(); i ++ )
          kept_gids[ i ] = this->gids[ split[ 0 ][ i ] ];
        sent_gids.resize( split[ 1 ].size() );
        sent_size = sent_gids.size();
        for ( size_t i = 0; i < sent_gids.size(); i ++ )
          sent_gids[ i ] = this->gids[ split[ 1 ][ i ] ];
      }
      else
      {
        /* Upper half: keep split[ 1 ], send split[ 0 ] to rank - size / 2. */
        partner_rank = rank - size / 2;
        /* ... */
        kept_gids.resize( split[ 1 ].size() );
        for ( size_t i = 0; i < kept_gids.size(); i ++ )
          kept_gids[ i ] = this->gids[ split[ 1 ][ i ] ];
        sent_gids.resize( split[ 0 ].size() );
        sent_size = sent_gids.size();
        for ( size_t i = 0; i < sent_gids.size(); i ++ )
          sent_gids[ i ] = this->gids[ split[ 0 ][ i ] ];
      }
      assert( partner_rank >= 0 );
      /* ... */
      mpi::ExchangeVector( sent_gids, partner_rank, 20,
                           recv_gids, partner_rank, 20, comm, &status );
      /* Append whatever the partner sent to the gids this rank keeps. */
      for ( auto it : recv_gids ) kept_gids.push_back( it );
      /* ... */
      mpi::Barrier( comm );
    };
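/* --------------------------------------------------------------------------
 * Standalone sketch. Split() pairs rank r in the lower half of the
 * communicator with rank r + size/2 in the upper half and swaps the gids
 * each side gives away. The program below reproduces that exchange with raw
 * MPI calls, assuming an even (typically power-of-two) communicator size, as
 * the distributed split does; the hmlp::mpi wrappers are assumed to behave
 * like their MPI counterparts.
 * -------------------------------------------------------------------------- */
#include <mpi.h>
#include <cstdio>
#include <vector>

int main( int argc, char **argv )
{
  MPI_Init( &argc, &argv );
  int rank, size;
  MPI_Comm_rank( MPI_COMM_WORLD, &rank );
  MPI_Comm_size( MPI_COMM_WORLD, &size );

  /* Pair the lower half of the communicator with the upper half. */
  int partner = ( rank < size / 2 ) ? rank + size / 2 : rank - size / 2;

  /* Pretend these are the gids handed over to the other side. */
  std::vector<int> sent_gids( rank + 1, rank );

  /* Exchange the sizes first so the receive buffer can be allocated. */
  int sent_size = (int)sent_gids.size(), recv_size = 0;
  MPI_Sendrecv( &sent_size, 1, MPI_INT, partner, 10,
                &recv_size, 1, MPI_INT, partner, 10,
                MPI_COMM_WORLD, MPI_STATUS_IGNORE );

  std::vector<int> recv_gids( recv_size );
  MPI_Sendrecv( sent_gids.data(), sent_size, MPI_INT, partner, 20,
                recv_gids.data(), recv_size, MPI_INT, partner, 20,
                MPI_COMM_WORLD, MPI_STATUS_IGNORE );

  std::printf( "rank %d received %d gids from rank %d\n",
               rank, recv_size, partner );
  MPI_Finalize();
  return 0;
}
/* ----------------------- end of standalone sketch ------------------------ */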
    /** Support dependency analysis: this node is written, its children are read. */
    void DependOnChildren( Task *task )
    {
      this->DependencyAnalysis( RW, task );
      if ( this->lchild ) this->lchild->DependencyAnalysis( R, task );
      if ( this->rchild ) this->rchild->DependencyAnalysis( R, task );
      if ( child ) child->DependencyAnalysis( R, task );
      /* ... */
    };

    /** Support dependency analysis: this node is read, its children are written. */
    void DependOnParent( Task *task )
    {
      this->DependencyAnalysis( R, task );
      if ( this->lchild ) this->lchild->DependencyAnalysis( RW, task );
      if ( this->rchild ) this->rchild->DependencyAnalysis( RW, task );
      if ( child ) child->DependencyAnalysis( RW, task );
      /* ... */
    };

    /* ... Print(), GetComm(), GetCommSize(), GetCommRank(), child, ... */

    /** The MPI communicator attached to this node. */
    mpi::Comm comm = MPI_COMM_WORLD;
    /* ... */
}; /** end class Node */
/** This distributed tree inherits the shared-memory tree and adds the MPI data structures and functions below. */
template<class SETUP, class NODEDATA>
class Tree /* ... base classes elided ... */
{
  public:
    typedef typename SETUP::T T;
    /* ... (typedefs NODE for local nodes, MPINODE for distributed nodes) ... */

    Tree( mpi::Comm comm ) : /* ... */ mpi::MPIObject( comm )
    {
      /* One ReadWrite object per rank guards the near/far receive buffers. */
      NearRecvFrom.resize( this->GetCommSize() );
      FarRecvFrom.resize( this->GetCommSize() );
      /* ... */
    };
    /** The destructor frees the distributed nodes owned by this tree. */
    ~Tree()
    {
      /* ... */
      if ( mpitreelists.size() )
      {
        /* All but the last entry (the local tree root, also freed via treelist). */
        for ( size_t i = 0; i < mpitreelists.size() - 1; i ++ )
          if ( mpitreelists[ i ] ) delete mpitreelists[ i ];
        mpitreelists.clear();
      }
      /* ... */
    };

    /** Free all tree nodes: local nodes, distributed nodes, and LET nodes. */
    void CleanUp()
    {
      /* Free the local tree nodes. */
      if ( this->treelist.size() )
      {
        for ( size_t i = 0; i < this->treelist.size(); i ++ )
          if ( this->treelist[ i ] ) delete this->treelist[ i ];
      }
      this->treelist.clear();
      /* Free the distributed tree nodes (again skipping the shared local root). */
      if ( mpitreelists.size() )
      {
        for ( size_t i = 0; i < mpitreelists.size() - 1; i ++ )
          if ( mpitreelists[ i ] ) delete mpitreelists[ i ];
      }
      mpitreelists.clear();
      /* ... */
    };
    /** Allocate the distributed (upper) part of the tree. */
    void AllocateNodes( vector<size_t> &gids )
    {
      /* ... */
      auto mycomm = this->GetComm();
      /* ... */
      int mysize = this->GetCommSize();
      /* ... */
      int myrank = this->GetCommRank();
      /* ... (mylevel, mycolor, ierr, childcomm, ... declared here) ... */
      /* The distributed root owns every gid assigned to this rank. */
      auto *root = new MPINODE( &(this->setup),
          this->n, mylevel, gids, NULL,
          &(this->morton2node), &(this->lock), mycomm );
      /* ... */
      mpitreelists.push_back( root );
      /* ... while more than one rank shares the current communicator ... */
        /* Split the communicator into a lower and an upper half. */
        mycolor = ( myrank < mysize / 2 ) ? 0 : 1;
        /* ... */
        ierr = mpi::Comm_split( mycomm, mycolor, myrank, &(childcomm) );
        /* ... (descend into the child communicator) ... */
        mpi::Comm_size( mycomm, &mysize );
        mpi::Comm_rank( mycomm, &myrank );
        /* ... */
        auto *parent = mpitreelists.back();
        auto *child = new MPINODE( &(this->setup),
            (size_t)0, mylevel, parent,
            &(this->morton2node), &(this->lock), mycomm );
        /* ... */
        /* The sibling on the other half is represented by a LET node. */
        child->sibling = new NODE( (size_t)0 );
        /* ... */
        parent->kids[ 0 ] = child;
        parent->child = child;
        /* ... */
        mpitreelists.push_back( child );
      /* ... (end of the communicator-splitting loop) ... */
      auto *local_tree_root = mpitreelists.back();
      /* ... */
    };

    /** Gather the leaf-order gid permutation onto rank 0. */
    vector<size_t> GetPermutation()
    {
      /* ... */
      vector<size_t> perm_loc, perm_glb;
      /* ... */
      mpi::GatherVector( perm_loc, perm_glb, 0, this->GetComm() );
      /* ... */
    };
    /** Randomized all-nearest-neighbor search over n_tree metric trees. */
    template<typename KNNTASK>
    DistData<STAR, CBLK, pair<T, size_t>> AllNearestNeighbor( size_t n_tree,
        size_t n, size_t k, pair<T, size_t> initNN, KNNTASK &dummy )
    {
      mpi::PrintProgress( "[BEG] NeighborSearch ...", this->GetComm() );
      /* ... */
      /* Use larger leaves for the search: 4k points, but at least 512. */
      this->setup.m = 4 * k;
      if ( this->setup.m < 512 ) this->setup.m = 512;
      this->m = this->setup.m;
      /* ... */
      /* Cyclic ownership: rank r owns gids r, r + P, r + 2P, ... */
      num_points_owned = n / this->GetCommSize();
      /* The first (n % P) ranks own one extra point. */
      if ( this->GetCommRank() < ( n % this->GetCommSize() ) )
        num_points_owned += 1;
      /* ... */
      vector<size_t> gids( num_points_owned, 0 );
      for ( size_t i = 0; i < num_points_owned; i ++ )
        gids[ i ] = i * this->GetCommSize() + this->GetCommRank();
      /* ... */
      AllocateNodes( gids );
      /* ... */
      for ( size_t t = 0; t < n_tree; t ++ )
      {
        /* Rebuild the tree partition ... */
        DistTraverseDown( mpisplittask );
        LocaTraverseDown( seqsplittask );
        /* ... then run the kNN task on every local leaf. */
        this->setup.NN = &Q_cids;
        LocaTraverseLeafs( dummy );
        /* ... */
      }
      /* ... */
      assert( Q_cblk.col_owned() == NN.col_owned() );
      /* ... */
      /* Sanity check: every neighbor gid must be a valid column index. */
      for ( auto &neig : NN )
      {
        if ( neig.second < 0 || neig.second >= NN.col() )
        {
          printf( "Illegal neighbor gid %lu\n", neig.second );
          /* ... */
        }
      }
      /* ... */
      mpi::PrintProgress( "[END] NeighborSearch ...", this->GetComm() );
      /* ... */
    };
    /** Partition the n points with a distributed binary tree. */
    void TreePartition()
    {
      mpi::PrintProgress( "[BEG] TreePartitioning ...", this->GetComm() );
      /* ... */
      this->n = this->setup.ProblemSize();
      this->m = this->setup.LeafNodeSize();
      /* ... */
      /* Initial cyclic ownership of the gids. */
      for ( size_t i = this->GetCommRank(); i < this->n; i += this->GetCommSize() )
        this->global_indices.push_back( i );
      /* ... */
      num_points_owned = this->global_indices.size();
      /* ... */
      AllocateNodes( this->global_indices );
      /* ... */
      /* Split top-down: first the distributed levels, then the local tree. */
      DistTraverseDown( mpiSPLITtask );
      LocaTraverseDown( seqSPLITtask );
      /* ... */
      /* Permute the gids into leaf order bottom-up. */
      LocaTraverseUp( seqINDXtask );
      /* ... */
      DistTraverseUp( mpiINDXtask );
      /* ... */
      /* Assign a Morton id to every point and every node. */
      (this->setup).morton.resize( this->n );
      /* ... */
      RecursiveMorton( mpitreelists[ 0 ], MortonHelper::Root() );
      /* ... */
      this->morton2node.clear();
      /* ... */
      /* Register local nodes, distributed nodes, and their siblings. */
      for ( auto node : this->treelist ) this->morton2node[ node->morton ] = node;
      /* ... */
      for ( auto node : this->mpitreelists )
      {
        this->morton2node[ node->morton ] = node;
        auto *sibling = node->sibling;
        if ( node->l ) this->morton2node[ sibling->morton ] = sibling;
      }
      /* ... */
      mpi::PrintProgress( "[END] TreePartitioning ...", this->GetComm() );
    };
    /** Assign Morton ids to node and its subtree, recursing over the distributed levels. */
    void RecursiveMorton( MPINODE *node, MortonHelper::Recursor r )
    {
      /* ... */
      int comm_size = this->GetCommSize();
      int comm_rank = this->GetCommRank();
      /* ... (node_size and node_rank of node's own communicator) ... */
      /* The node and its (LET) sibling get their Morton ids from the recursor. */
      node->morton = MortonHelper::MortonID( r );
      /* ... */
      if ( node->sibling )
        node->sibling->morton = MortonHelper::SiblingMortonID( r );
      /* ... */
      if ( node_size < 2 )
      {
        /* ... base case: label the local tree, then share every point's
               Morton id with all ranks ... */
        auto &gids = this->treelist[ 0 ]->gids;
        vector<int> recv_size( comm_size, 0 );
        vector<int> recv_disp( comm_size, 0 );
        vector<pair<size_t, size_t>> send_pairs;
        vector<pair<size_t, size_t>> recv_pairs( this->n );
        /* ... */
        /* Pack the (gid, morton) pairs owned by this rank. */
        for ( auto it : gids )
        {
          send_pairs.push_back(
              pair<size_t, size_t>( it, this->setup.morton[ it ] ) );
        }
        /* ... */
        /* Gather the counts, compute the displacements, and check the total. */
        int send_size = send_pairs.size();
        mpi::Allgather( &send_size, 1, recv_size.data(), 1, this->GetComm() );
        for ( size_t p = 1; p < comm_size; p ++ )
          recv_disp[ p ] = recv_disp[ p - 1 ] + recv_size[ p - 1 ];
        /* ... */
        size_t total_gids = 0;
        for ( size_t p = 0; p < comm_size; p ++ )
          total_gids += recv_size[ p ];
        assert( total_gids == this->n );
        /* ... */
        /* Every rank receives every (gid, morton) pair. */
        mpi::Allgatherv( send_pairs.data(), send_size,
            recv_pairs.data(), recv_size.data(), recv_disp.data(), this->GetComm() );
        /* ... */
        for ( auto it : recv_pairs ) this->setup.morton[ it.first ] = it.second;
      }
      else
      {
        /* Recurse into the half of the communicator this rank belongs to. */
        if ( node_rank < node_size / 2 )
          RecursiveMorton( node->child, MortonHelper::RecurLeft( r ) );
        else
          RecursiveMorton( node->child, MortonHelper::RecurRight( r ) );
      }
    };
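/* --------------------------------------------------------------------------
 * Standalone sketch. The base case above makes every rank's (gid, morton)
 * pairs globally known: an Allgather of the counts, displacement prefix
 * sums, then an Allgatherv of the payloads. The program below reproduces
 * that two-step pattern with raw MPI calls and plain unsigned long payloads
 * (the hmlp::mpi wrappers are assumed to mirror the MPI routines).
 * -------------------------------------------------------------------------- */
#include <mpi.h>
#include <cstdio>
#include <vector>

int main( int argc, char **argv )
{
  MPI_Init( &argc, &argv );
  int rank, size;
  MPI_Comm_rank( MPI_COMM_WORLD, &rank );
  MPI_Comm_size( MPI_COMM_WORLD, &size );

  /* Each rank contributes a different number of items. */
  std::vector<unsigned long> send( rank + 1, (unsigned long)rank );
  int send_size = (int)send.size();

  /* Gather the counts, then build the displacements. */
  std::vector<int> recv_size( size, 0 ), recv_disp( size, 0 );
  MPI_Allgather( &send_size, 1, MPI_INT,
                 recv_size.data(), 1, MPI_INT, MPI_COMM_WORLD );
  for ( int p = 1; p < size; p ++ )
    recv_disp[ p ] = recv_disp[ p - 1 ] + recv_size[ p - 1 ];

  int total = recv_disp[ size - 1 ] + recv_size[ size - 1 ];
  std::vector<unsigned long> recv( total );
  MPI_Allgatherv( send.data(), send_size, MPI_UNSIGNED_LONG,
                  recv.data(), recv_size.data(), recv_disp.data(),
                  MPI_UNSIGNED_LONG, MPI_COMM_WORLD );

  if ( rank == 0 ) std::printf( "gathered %d items in total\n", total );
  MPI_Finalize();
  return 0;
}
/* ----------------------- end of standalone sketch ------------------------ */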
    /** Collect and print the leaf-by-leaf near/far interaction pattern. */
    Data<int> CheckAllInteractions()
    {
      /* ... */
      int total_depth = this->treelist.back()->l;
      /* ... */
      int num_leafs = 1 << total_depth;
      /* ... (A and B are num_leafs-by-num_leafs interaction matrices) ... */

      /* Local tree nodes. */
      for ( int t = 1; t < this->treelist.size(); t ++ )
      {
        auto *node = this->treelist[ t ];
        /* ... */
        for ( int p = 0; p < this->GetCommSize(); p ++ )
        {
          for ( auto & it : node->DistNear[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
            for ( auto i : I )
              for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                /* ... */
              }
          }
          /* ... */
          for ( auto & it : node->DistFar[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
            for ( auto i : I )
              for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                /* ... */
              }
          }
        }
      }

      /* Distributed tree nodes: the same checks as above. */
      for ( auto *node : mpitreelists )
      {
        /* ... */
        for ( int p = 0; p < this->GetCommSize(); p ++ )
        {
          for ( auto & it : node->DistNear[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
            for ( auto i : I )
              for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                /* ... */
              }
          }
          /* ... */
          for ( auto & it : node->DistFar[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets( it.first, total_depth );
            for ( auto i : I )
              for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                /* ... */
              }
          }
        }
      }

      /* Reduce the local matrices onto rank 0 and print the pattern. */
      mpi::Reduce( A.data(), B.data(), A.size(), MPI_SUM, 0, this->GetComm() );
      if ( this->GetCommRank() == 0 )
      {
        for ( size_t i = 0; i < num_leafs; i ++ )
        {
          for ( size_t j = 0; j < num_leafs; j ++ ) printf( "%d", B( i, j ) );
          /* ... */
        }
      }
      /* ... */
    };
    /** Map a Morton id to the MPI rank that owns it. */
    int Morton2Rank( size_t it )
    {
      return MortonHelper::Morton2Rank( it, this->GetCommSize() );
    };

    /** Map a point gid to its owning rank through its Morton id. */
    int Index2Rank( size_t gid )
    {
      return Morton2Rank( this->setup.morton[ gid ] );
    };
    /** Bottom-up, level-by-level traversal of the local tree. */
    template<typename TASK, typename... Args>
    void LocaTraverseUp( TASK &dummy, Args &... args )
    {
      assert( this->treelist.size() );
      /* ... */
      for ( int l = this->depth; l >= 1; l -- )
      {
        /* Level l holds 2^l nodes starting at offset 2^l - 1. */
        size_t n_nodes = 1 << l;
        auto level_beg = this->treelist.begin() + n_nodes - 1;
        /* ... */
        for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
        {
          auto *node = *(level_beg + node_ind);
          /* ... (create or execute the task for node) ... */
        }
      }
    };
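/* --------------------------------------------------------------------------
 * Standalone sketch. LocaTraverseUp() (and the downward and leaf variants
 * below) rely on the implicit heap-ordered layout of treelist: level l holds
 * 2^l nodes starting at offset 2^l - 1, so each level is one flat loop. The
 * index arithmetic for a tree of depth 3:
 * -------------------------------------------------------------------------- */
#include <cstdio>

int main()
{
  const int depth = 3;
  for ( int l = depth; l >= 0; l -- )
  {
    int n_nodes   = 1 << l;
    int level_beg = n_nodes - 1;   /* offset of the first node on level l */
    std::printf( "level %d: treelist[ %d .. %d ]\n",
                 l, level_beg, level_beg + n_nodes - 1 );
  }
  return 0;
}
/* ----------------------- end of standalone sketch ------------------------ */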
    /** Bottom-up traversal of the distributed levels (from the local root up). */
    template<typename TASK, typename... Args>
    void DistTraverseUp( TASK &dummy, Args &... args )
    {
      MPINODE *node = mpitreelists.back();
      /* ... (walk up toward the distributed root) ... */
        if ( this->DoOutOfOrder() ) RecuTaskSubmit( node, dummy, args... );
        /* ... */
        node = (MPINODE*)node->parent;
      /* ... */
    };
    /** Top-down, level-by-level traversal of the local tree. */
    template<typename TASK, typename... Args>
    void LocaTraverseDown( TASK &dummy, Args &... args )
    {
      assert( this->treelist.size() );
      /* ... */
      for ( int l = 1; l <= this->depth; l ++ )
      {
        size_t n_nodes = 1 << l;
        auto level_beg = this->treelist.begin() + n_nodes - 1;
        /* ... */
        for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
        {
          auto *node = *(level_beg + node_ind);
          /* ... (create or execute the task for node) ... */
        }
      }
    };
    /** Top-down traversal of the distributed levels (from the distributed root down). */
    template<typename TASK, typename... Args>
    void DistTraverseDown( TASK &dummy, Args &... args )
    {
      auto *node = mpitreelists.front();
      /* ... (walk down toward the local root) ... */
        if ( this->DoOutOfOrder() ) RecuTaskSubmit( node, dummy, args... );
        /* ... */
    };
    /** Traverse only the leaf level of the local tree. */
    template<typename TASK, typename... Args>
    void LocaTraverseLeafs( TASK &dummy, Args &... args )
    {
      assert( this->treelist.size() );
      /* The leaves are the last 2^depth entries of the heap-ordered treelist. */
      int n_nodes = 1 << this->depth;
      auto level_beg = this->treelist.begin() + n_nodes - 1;
      /* ... */
      for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
      {
        auto *node = *(level_beg + node_ind);
        /* ... (create or execute the task for node) ... */
      }
    };
    /** For unordered traversal, just call the local downward traversal. */
    template<typename TASK, typename... Args>
    void LocaTraverseUnOrdered( TASK &dummy, Args &... args )
    {
      LocaTraverseDown( dummy, args... );
    };

    /** For unordered traversal, just call the distributed downward traversal. */
    template<typename TASK, typename... Args>
    void DistTraverseUnOrdered( TASK &dummy, Args &... args )
    {
      DistTraverseDown( dummy, args... );
    };
    /** Clear all dependencies recorded on the tree nodes and receive buffers. */
    void DependencyCleanUp()
    {
      for ( auto node : mpitreelists ) node->DependencyCleanUp();
      /* ... */
      for ( auto &p : NearRecvFrom ) p.DependencyCleanUp();
      for ( auto &p : FarRecvFrom ) p.DependencyCleanUp();
      /* ... */
    };

    /** Run every queued task, then clear all dependencies. */
    void ExecuteAllTasks()
    {
      /* ... */
      DependencyCleanUp();
    };
    void DependOnNearInteractions( int p, Task *task )
    {
      for ( auto it : NearSentToRank[ p ] )
      {
        auto *node = this->morton2node[ it ];
        node->DependencyAnalysis( R, task );
      }
    };

    void DependOnFarInteractions( int p, Task *task )
    {
      for ( auto it : FarSentToRank[ p ] )
      {
        auto *node = this->morton2node[ it ];
        node->DependencyAnalysis( R, task );
      }
    };
    /* ... (NearSentToRank and other per-rank bookkeeping declared above) ... */
    vector<map<size_t, int>> NearRecvFromRank;
    vector<vector<size_t>> FarSentToRank;
    vector<map<size_t, int>> FarRecvFromRank;
    /* ReadWrite objects guarding the per-rank near/far receive buffers. */
    vector<ReadWrite> NearRecvFrom;
    vector<ReadWrite> FarRecvFrom;
    /* ... */
    size_t num_points_owned = 0;
    /* ... */
}; /** end class Tree */