27 #include <type_traits> 42 #include <hmlp_base.hpp> 44 #include <primitives/combinatorics.hpp> 49 #define REPORT_ANN_STATUS 0 54 bool has_uneven_split =
false;
/** Recursion cursor: .first is the left(0)/right(1) path bits, .second is the depth. */
using Recursor = std::pair<std::size_t, std::size_t>;

/** Cursor for the root: empty path at depth zero. */
static Recursor Root()
{
  const Recursor origin( 0, 0 );
  return origin;
}

/** Step to the left child: append a 0 bit to the path and go one level deeper. */
static Recursor RecurLeft( Recursor r )
{
  const std::size_t path = r.first * 2;
  const std::size_t level = r.second + 1;
  return Recursor( path, level );
}

/** Step to the right child: append a 1 bit to the path and go one level deeper. */
static Recursor RecurRight( Recursor r )
{
  const std::size_t path = r.first * 2 + 1;
  const std::size_t level = r.second + 1;
  return Recursor( path, level );
}
86 size_t shift = Shift( r.second );
88 return ( r.first << shift ) + r.second;
94 size_t shift = Shift( r.second );
97 return ( ( r.first - 1 ) << shift ) + r.second;
99 return ( ( r.first + 1 ) << shift ) + r.second;
105 size_t itdepth = Depth( it );
107 while ( size >>= 1 ) mpidepth ++;
108 if ( itdepth > mpidepth ) itdepth = mpidepth;
109 size_t itshift = Shift( itdepth );
110 return ( it >> itshift ) << ( mpidepth - itdepth );
115 if ( r.second == depth )
117 offsets.push_back( r.first );
121 Morton2Offsets( RecurLeft( r ), depth, offsets );
122 Morton2Offsets( RecurRight( r ), depth, offsets );
129 vector<size_t> offsets;
130 size_t mydepth = Depth( me );
131 assert( mydepth <= depth );
132 Recursor r( me >> Shift( mydepth ), mydepth );
133 Morton2Offsets( r, depth, offsets );
151 size_t itlevel = Depth( it );
152 size_t mylevel = Depth( me );
153 size_t itshift = Shift( itlevel );
154 bool is_my_parent = !( ( me ^ it ) >> itshift ) && ( itlevel <= mylevel );
156 hmlp_print_binary( me );
157 hmlp_print_binary( it );
158 hmlp_print_binary( ( me ^ it ) >> itshift );
159 printf(
"ismyparent %d itlevel %lu mylevel %lu shift %lu fixed shift %d\n",
160 is_my_parent, itlevel, mylevel, itshift, 1 << LEVELOFFSET );
166 template<
typename TQUERY>
169 for (
auto & q : querys )
170 if ( IsMyParent( q, target ) )
return true;
177 static size_t Depth(
size_t it )
179 size_t filter = ( 1 << LEVELOFFSET ) - 1;
/** Number of low-order bits of a Morton ID that hold the node depth. */
static const int LEVELOFFSET = 4;

/**
 *  Bit position where a node's path bits are placed inside its Morton ID:
 *  ( 2^LEVELOFFSET + LEVELOFFSET ) - depth. A deeper node has a longer path,
 *  so its path begins at a lower bit.
 */
static size_t Shift( size_t depth )
{
  const size_t top = ( size_t( 1 ) << LEVELOFFSET ) + LEVELOFFSET;
  return top - depth;
}
// less_first: orders (value, index) pairs by their .first member only.
// MergeNeighbors uses it to order the de-duplicated candidate list, keeping the
// k smallest. NOTE(review): .first appears to be a neighbor distance — confirm.
194 bool less_first(
const pair<T, size_t> &a,
const pair<T, size_t> &b )
196 return ( a.first < b.first );
// less_second: orders pairs by their .second (index) member only. MergeNeighbors
// sorts by it so entries sharing an index become adjacent before std::unique.
199 bool less_second(
const pair<T, size_t> &a,
const pair<T, size_t> &b )
201 return ( a.second < b.second );
// equal_second: equality on .second only, so std::unique collapses entries with
// the same index regardless of their .first values.
204 bool equal_second(
const pair<T, size_t> &a,
const pair<T, size_t> &b )
206 return ( a.second == b.second );
213 pair<T, size_t> *B, vector<pair<T, size_t>> &aux )
216 if ( aux.size() != 2 * k ) aux.resize( 2 * k );
218 for (
size_t i = 0; i < k; i++ ) aux[ i ] = A[ i ];
219 for (
size_t i = 0; i < k; i++ ) aux[ k + i ] = B[ i ];
221 sort( aux.begin(), aux.end(), less_second<T> );
222 auto it = unique( aux.begin(), aux.end(), equal_second<T> );
223 sort( aux.begin(), it, less_first<T> );
225 for (
size_t i = 0; i < k; i++ ) A[ i ] = aux[ i ];
231 vector<pair<T, size_t>> &A, vector<pair<T, size_t>> &B )
233 assert( A.size() >= n * k && B.size() >= n * k );
236 vector<pair<T, size_t> > aux( 2 * k );
238 for(
size_t i = 0; i < n; i++ )
262 template<
typename NODE>
269 void Set( NODE *user_arg )
271 name = string(
"Permutation" );
277 void DependencyAnalysis()
279 arg->DependencyAnalysis( RW,
this );
282 arg->lchild->DependencyAnalysis( R,
this );
283 arg->rchild->DependencyAnalysis( R,
this );
289 void Execute(
Worker* user_worker )
291 auto &gids = arg->gids;
292 auto *lchild = arg->lchild;
293 auto *rchild = arg->rchild;
297 auto &lgids = lchild->gids;
298 auto &rgids = rchild->gids;
300 gids.insert( gids.end(), rgids.begin(), rgids.end() );
311 template<
typename NODE>
318 void Set( NODE *user_arg )
320 name = string(
"Split" );
// Dependency registration is delegated to the node's DependOnParent(), which in
// this file declares a read (R) on the node itself and read-write (RW) on its
// lchild/rchild.
326 void DependencyAnalysis() { arg->DependOnParent(
this ); };
// Execute: run Node::Split() on the bound node, which distributes its gids to
// its children. The user_worker argument is not used here.
328 void Execute(
Worker* user_worker ) { arg->Split(); };
594 template<
typename SETUP,
typename NODEDATA>
600 typedef typename SETUP::T
T;
602 static const int N_CHILDREN = 2;
604 Node( SETUP* setup,
size_t n,
size_t l,
605 Node *parent, unordered_map<size_t, Node*> *morton2node,
Lock *treelock )
611 this->treelist_id = 0;
612 this->gids.resize( n );
613 this->isleaf =
false;
614 this->parent = parent;
617 this->morton2node = morton2node;
618 this->treelock = treelock;
619 for (
int i = 0; i < N_CHILDREN; i++ ) kids[ i ] = NULL;
622 Node( SETUP *setup,
int n,
int l, vector<size_t> gids,
623 Node *parent, unordered_map<size_t, Node*> *morton2node,
Lock *treelock )
629 this->treelist_id = 0;
631 this->isleaf =
false;
632 this->parent = parent;
635 this->morton2node = morton2node;
636 this->treelock = treelock;
637 for (
int i = 0; i < N_CHILDREN; i++ ) kids[ i ] = NULL;
// Minimal constructor that only records a Morton ID; any member not assigned
// here keeps whatever in-class default it declares.
// NOTE(review): not marked explicit, so a size_t converts implicitly to a Node —
// confirm that this conversion is intended.
645 Node(
size_t morton ) { this->morton = morton; };
662 if ( isleaf )
return;
665 int max_depth = setup->max_depth;
667 double beg = omp_get_wtime();
668 auto split = setup->splitter( gids );
669 double splitter_time = omp_get_wtime() - beg;
672 if ( std::abs( (
int)split[ 0 ].size() - (
int)split[ 1 ].size() ) > 1 )
674 if ( !has_uneven_split )
676 printf(
"\n\nWARNING! uneven split. Using random split instead %lu %lu\n\n",
677 split[ 0 ].size(), split[ 1 ].size() );
678 has_uneven_split =
true;
682 split[ 0 ].resize( gids.size() / 2 );
683 split[ 1 ].resize( gids.size() - ( gids.size() / 2 ) );
685 for (
size_t i = 0; i < gids.size(); i ++ )
687 if ( i < gids.size() / 2 ) split[ 0 ][ i ] = i;
688 else split[ 1 ][ i - ( gids.size() / 2 ) ] = i;
692 for (
size_t i = 0; i < N_CHILDREN; i ++ )
694 int nchild = split[ i ].size();
697 kids[ i ]->Resize( nchild );
698 for (
int j = 0; j < nchild; j ++ )
700 kids[ i ]->gids[ j ] = gids[ split[ i ][ j ] ];
704 catch (
const exception & e )
706 cout << e.what() << endl;
719 if ( !setup->morton.size() )
721 printf(
"Morton id was not initialized.\n" );
724 for (
size_t i = 0; i < queries.size(); i ++ )
726 if ( MortonHelper::IsMyParent( setup->morton[ queries[ i ] ], morton ) )
730 hmlp_print_binary( setup->morton[ queries[ i ] ] );
731 hmlp_print_binary( morton );
744 if ( !setup->morton.size() )
746 printf(
"Morton id was not initialized.\n" );
749 for (
auto it = querys.begin(); it != querys.end(); it ++ )
751 if ( MortonHelper::IsMyParent( (*it)->morton, morton ) )
763 printf(
"l %lu offset %lu n %lu\n", this->l, this->offset, this->n );
764 hmlp_print_binary( this->morton );
771 if ( this->lchild ) this->lchild->DependencyAnalysis( R, task );
772 if ( this->rchild ) this->rchild->DependencyAnalysis( R, task );
773 this->DependencyAnalysis( RW, task );
780 this->DependencyAnalysis( R, task );
781 if ( this->lchild ) this->lchild->DependencyAnalysis( RW, task );
782 if ( this->rchild ) this->rchild->DependencyAnalysis( RW, task );
789 this->DependencyAnalysis( RW, task );
819 set<size_t> FarNodeMortonIDs;
823 set<Node*> NearNodes;
824 set<size_t> NearNodeMortonIDs;
828 set<Node*> NNFarNodes;
829 set<Node*> ProposedNNFarNodes;
830 set<size_t> NNFarNodeMortonIDs;
834 set<Node*> NNNearNodes;
835 set<Node*> ProposedNNNearNodes;
836 set<size_t> NNNearNodeMortonIDs;
847 vector<map<size_t, Data<T>>> DistNear;
859 Node *sibling = NULL;
861 unordered_map<size_t, Node*> *morton2node = NULL;
875 template<
typename SPLITTER,
typename DATATYPE>
890 size_t max_depth = 15;
912 vector<size_t>
ContainAny( vector<size_t> &queries,
size_t target )
914 vector<size_t> validation( queries.size(), 0 );
916 if ( !morton.size() )
918 printf(
"Morton id was not initialized.\n" );
922 for (
size_t i = 0; i < queries.size(); i ++ )
934 if ( MortonHelper::IsMyParent( morton[ queries[ i ] ], target ) )
954 template<
class SETUP,
class NODEDATA>
959 typedef typename SETUP::T T;
963 static const int N_CHILDREN = 2;
1000 for (
int i = 0; i < treelist.size(); i ++ )
1002 if ( treelist[ i ] )
delete treelist[ i ];
1004 morton2node.clear();
1013 node->offset = offset;
1016 Offset( node->lchild, offset + 0 );
1017 Offset( node->rchild, offset + node->lchild->gids.size() );
1026 if ( !node )
return;
1028 node->
morton = MortonHelper::MortonID( r );
1030 RecursiveMorton( node->lchild, MortonHelper::RecurLeft( r ) );
1031 RecursiveMorton( node->rchild, MortonHelper::RecurRight( r ) );
1033 if ( !node->lchild )
1035 for (
auto it : node->gids ) setup.morton[ it ] = node->
morton;
1051 int glb_depth = std::ceil( std::log2( (
double)n / m ) );
1052 if ( glb_depth > setup.max_depth ) glb_depth = setup.max_depth;
1054 depth = glb_depth - root->
l;
1060 for (
auto node_ptr : treelist )
delete node_ptr;
1062 morton2node.clear();
1063 treelist.reserve( 1 << ( depth + 1 ) );
1064 deque<NODE*> treequeue;
1066 treequeue.push_back( root );
1070 while (
auto *node = treequeue.front() )
1073 node->treelist_id = treelist.size();
1075 if ( node->l < glb_depth )
1077 for (
int i = 0; i < N_CHILDREN; i ++ )
1079 node->kids[ i ] =
new NODE( &setup, 0, node->l + 1, node, &morton2node, &lock );
1080 treequeue.push_back( node->kids[ i ] );
1082 node->lchild = node->kids[ 0 ];
1083 node->rchild = node->kids[ 1 ];
1084 if ( node->lchild ) node->lchild->sibling = node->rchild;
1085 if ( node->rchild ) node->rchild->sibling = node->lchild;
1090 node->isleaf =
true;
1091 treequeue.push_back( NULL );
1093 treelist.push_back( node );
1094 treequeue.pop_front();
1104 double beg, alloc_time, split_time, morton_time, permute_time;
1106 this->n = setup.ProblemSize();
1107 this->m = setup.LeafNodeSize();
1110 global_indices.clear();
1111 for (
size_t i = 0; i < n; i ++ ) global_indices.push_back( i );
1114 has_uneven_split =
false;
1117 beg = omp_get_wtime();
1118 AllocateNodes(
new NODE( &setup, n, 0, global_indices, NULL, &morton2node, &lock ) );
1119 alloc_time = omp_get_wtime() - beg;
1122 beg = omp_get_wtime();
1124 TraverseDown( splittask );
1126 split_time = omp_get_wtime() - beg;
1130 setup.morton.resize( n );
1132 RecursiveMorton( treelist[ 0 ], MortonHelper::Root() );
1135 Offset( treelist[ 0 ], 0 );
1138 morton2node.clear();
1139 for (
size_t i = 0; i < treelist.size(); i ++ )
1141 morton2node[ treelist[ i ]->morton ] = treelist[ i ];
1146 TraverseUp( indexpermutetask );
1155 int n_nodes = 1 << this->depth;
1156 auto level_beg = this->treelist.begin() + n_nodes - 1;
1158 vector<size_t> perm;
1160 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1162 auto *node = *(level_beg + node_ind);
1163 auto gids = node->gids;
1164 perm.insert( perm.end(), gids.begin(), gids.end() );
1177 template<
typename KNNTASK>
1179 size_t max_depth, pair<T, size_t> initNN,
1187 if ( setup.m < 32 ) setup.m = 32;
1190 if ( REPORT_ANN_STATUS )
1192 printf(
"========================================================\n");
1196 for (
int t = 0; t < n_tree; t ++ )
1199 double knn_acc = 0.0;
1203 TraverseLeafs( dummy );
1206 size_t n_nodes = 1 << depth;
1207 auto level_beg = treelist.begin() + n_nodes - 1;
1208 for (
size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1210 auto *node = *(level_beg + node_ind);
1211 knn_acc += node->data.knn_acc;
1212 num_acc += node->data.num_acc;
1214 if ( REPORT_ANN_STATUS )
1216 printf(
"ANN iter %2d, average accuracy %.2lf%% (over %4lu samples)\n",
1217 t, knn_acc / num_acc, num_acc );
1221 if ( knn_acc / num_acc < 0.8 )
1223 if ( 2.0 * setup.m < 2048 ) setup.m = 2.0 * setup.m;
1229 printf(
"Iter %2d NN 0 ", t );
1230 for (
size_t i = 0; i < NN.row(); i ++ )
1232 printf(
"%E(%lu) ", NN[ i ].first, NN[ i ].second );
1238 if ( REPORT_ANN_STATUS )
1240 printf(
"========================================================\n\n");
1244 #pragma omp parallel for 1245 for (
size_t j = 0; j < NN.col(); j ++ )
1246 sort( NN.data() + j * NN.row(), NN.data() + ( j + 1 ) * NN.row() );
1249 for (
auto &neig : NN )
1251 if ( neig.second < 0 || neig.second >= NN.col() )
1253 printf(
"Illegle neighbor gid %lu\n", neig.second );
1267 int total_depth = treelist.back()->l;
1269 int num_leafs = 1 << total_depth;
1273 for (
int t = 1; t < treelist.size(); t ++ )
1275 auto *node = treelist[ t ];
1277 for (
auto *it : node->NNNearNodes )
1279 auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
1280 auto J = MortonHelper::Morton2Offsets( it->morton, total_depth );
1281 for (
auto i : I )
for (
auto j : J ) A( i, j ) += 1;
1284 for (
auto *it : node->NNFarNodes )
1286 auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
1287 auto J = MortonHelper::Morton2Offsets( it->morton, total_depth );
1288 for (
auto i : I )
for (
auto j : J ) A( i, j ) += 1;
1292 for (
size_t i = 0; i < num_leafs; i ++ )
1294 for (
size_t j = 0; j < num_leafs; j ++ ) printf(
"%d", A( i, j ) );
1310 template<
typename TASK,
typename... Args>
1314 assert( this->treelist.size() );
1316 int n_nodes = 1 << this->depth;
1317 auto level_beg = this->treelist.begin() + n_nodes - 1;
1319 if ( out_of_order_traversal )
1321 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1323 auto *node = *(level_beg + node_ind);
1329 int nthd_glb = omp_get_max_threads();
1331 #pragma omp parallel for if ( n_nodes > nthd_glb / 2 ) schedule( dynamic ) 1332 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1334 auto *node = *(level_beg + node_ind);
1344 template<
typename TASK,
typename... Args>
1348 assert( this->treelist.size() );
1358 int local_begin_level = ( treelist[ 0 ]->l ) ? 1 : 0;
1361 for (
int l = this->depth; l >= local_begin_level; l -- )
1363 size_t n_nodes = 1 << l;
1364 auto level_beg = this->treelist.begin() + n_nodes - 1;
1367 if ( out_of_order_traversal )
1370 for (
size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1372 auto *node = *(level_beg + node_ind);
1378 int nthd_glb = omp_get_max_threads();
1380 #pragma omp parallel for if ( n_nodes > nthd_glb / 2 ) schedule( dynamic ) 1381 for (
size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1383 auto *node = *(level_beg + node_ind);
1395 template<
typename TASK,
typename... Args>
1399 assert( this->treelist.size() );
1408 int local_begin_level = ( treelist[ 0 ]->l ) ? 1 : 0;
1410 for (
int l = local_begin_level; l <= this->depth; l ++ )
1412 size_t n_nodes = 1 << l;
1413 auto level_beg = this->treelist.begin() + n_nodes - 1;
1415 if ( out_of_order_traversal )
1418 for (
size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1420 auto *node = *(level_beg + node_ind);
1426 int nthd_glb = omp_get_max_threads();
1428 #pragma omp parallel for if ( n_nodes > nthd_glb / 2 ) schedule( dynamic ) 1429 for (
size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1431 auto *node = *(level_beg + node_ind);
1444 template<
typename TASK,
typename... Args>
1447 TraverseDown( dummy, args... );
1457 for (
auto node : treelist ) node->DependencyCleanUp();
1459 for (
auto it : morton2node )
1461 auto *node = it.second;
1462 if ( node ) node->DependencyCleanUp();
1469 DependencyCleanUp();
1477 template<
typename SUMMARY>
1480 assert( N_CHILDREN == 2 );
1482 for ( std::size_t l = 0; l <= depth; l ++ )
1484 size_t n_nodes = 1 << l;
1485 auto level_beg = treelist.begin() + n_nodes - 1;
1486 for (
size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1488 auto *node = *(level_beg + node_ind);
1497 bool out_of_order_traversal =
true;
1501 vector<size_t> global_indices;
void Split()
Definition: tree.hpp:657
vector< NODE * > treelist
Definition: tree.hpp:984
size_t n
Definition: tree.hpp:802
void DependOnChildren(Task *task)
Definition: tree.hpp:769
static bool IsMyParent(size_t me, size_t it)
Check if `it` is `me`'s ancestor by checking two facts: 1) itlevel <= mylevel and 2) the morton bits above itlevel are identical.
Definition: tree.hpp:149
~Tree()
Definition: tree.hpp:996
void TreePartition()
Shared-memory tree partition.
Definition: tree.hpp:1102
void Offset(NODE *node, size_t offset)
Definition: tree.hpp:1009
Data< pair< T, size_t > > AllNearestNeighbor(size_t n_tree, size_t k, size_t max_depth, pair< T, size_t > initNN, KNNTASK &dummy)
Definition: tree.hpp:1178
void AllocateNodes(NODE *root)
Allocate the local tree using the local root with n points and depth l.
Definition: tree.hpp:1048
Permute the order of gids for each internal node to the order of leaf nodes.
Definition: tree.hpp:263
void DependencyCleanUp()
Definition: tree.hpp:1451
SETUP::T T
Definition: tree.hpp:600
void RecuTaskSubmit(ARG *arg)
Recursive task submission (base case).
Definition: runtime.hpp:446
void ExecuteAllTasks()
Definition: tree.hpp:1466
size_t morton
Definition: tree.hpp:808
size_t l
Definition: tree.hpp:805
~Node()
Definition: tree.hpp:648
set< size_t > NearIDs
Definition: tree.hpp:822
void DependOnParent(Task *task)
Definition: tree.hpp:778
NODEDATA data
Definition: tree.hpp:799
Tree()
Definition: tree.hpp:993
This class provides the ability to perform dependency analysis.
Definition: runtime.hpp:498
vector< size_t > morton
Definition: tree.hpp:899
set< size_t > FarIDs
Definition: tree.hpp:817
bool TryEnqueue()
Try to dispatch the task if there is no dependency left.
Definition: runtime.cpp:339
set< size_t > NNNearIDs
Definition: tree.hpp:833
static vector< size_t > Morton2Offsets(size_t me, size_t depth)
Definition: tree.hpp:127
Data and setup that are shared with all nodes.
Definition: tree.hpp:876
void RecursiveMorton(NODE *node, MortonHelper::Recursor r)
Definition: tree.hpp:1023
static bool ContainAny(size_t target, TQUERY &querys)
Definition: tree.hpp:167
SETUP setup
Definition: tree.hpp:968
bool ContainAny(set< Node * > &querys)
Definition: tree.hpp:742
bool DoOutOfOrder()
Definition: tree.hpp:1473
static size_t SiblingMortonID(Recursor r)
Definition: tree.hpp:91
This is the default ball tree splitter. Given coordinates, compute the direction from the two most fa...
Definition: tree.hpp:595
Node(size_t morton)
Definition: tree.hpp:645
void DependOnNoOne(Task *task)
Definition: tree.hpp:787
void TraverseLeafs(TASK &dummy, Args &...args)
Definition: tree.hpp:1311
Wrapper for omp or pthread mutex.
Definition: tci.hpp:50
void TraverseUnOrdered(TASK &dummy, Args &...args)
For unordered traversal, we just call local downward traversal.
Definition: tree.hpp:1445
unordered_map< size_t, NODE * > morton2node
Definition: tree.hpp:990
void MergeNeighbors(size_t k, pair< T, size_t > *A, pair< T, size_t > *B, vector< pair< T, size_t >> &aux)
Definition: tree.hpp:212
void RecuTaskExecute(ARG *arg)
Recursive task execution (base case).
Definition: runtime.hpp:469
size_t treelist_id
Definition: tree.hpp:812
set< size_t > NNFarIDs
Definition: tree.hpp:827
static int Morton2Rank(size_t it, int size)
Return the MPI rank that owns `it`.
Definition: tree.hpp:103
Node< SETUP, NODEDATA > NODE
Definition: tree.hpp:961
Data< int > CheckAllInteractions()
Definition: tree.hpp:1264
void Summary(SUMMARY &summary)
Summarize all events in each level.
Definition: tree.hpp:1478
static size_t MortonID(Recursor r)
Definition: tree.hpp:83
void TraverseUp(TASK &dummy, Args &...args)
Definition: tree.hpp:1345
vector< map< size_t, Data< T > > > DistFar
Definition: tree.hpp:846
vector< size_t > GetPermutation()
Definition: tree.hpp:1153
bool ContainAny(vector< size_t > &queries)
Check if this node contains any query using morton. Notice that queries[] contains gids; thus...
Definition: tree.hpp:717
Lock lock
Definition: tree.hpp:981
vector< size_t > ContainAny(vector< size_t > &queries, size_t target)
Check if this node contains any query using morton. Notice that queries[] contains gids; thus...
Definition: tree.hpp:912
static void Morton2Offsets(Recursor r, size_t depth, vector< size_t > &offsets)
Definition: tree.hpp:113
bool less_first(const pair< T, size_t > &a, const pair< T, size_t > &b)
Definition: tree.hpp:194
void TraverseDown(TASK &dummy, Args &...args)
Definition: tree.hpp:1396
void Print()
Definition: tree.hpp:761
SPLITTER splitter
Definition: tree.hpp:902
Definition: runtime.hpp:174
Definition: thread.hpp:166