HMLP: High-performance Machine Learning Primitives
tree.hpp
1 
21 #ifndef TREE_HPP
22 #define TREE_HPP
23 
24 
25 #include <assert.h>
26 #include <typeinfo>
27 #include <type_traits>
28 #include <algorithm>
29 #include <functional>
30 #include <set>
31 #include <vector>
32 #include <deque>
33 #include <iostream>
34 #include <random>
35 #include <cstdint>
36 
37 
38 
39 
41 #include <hmlp.h>
42 #include <hmlp_base.hpp>
44 #include <primitives/combinatorics.hpp>
46 using namespace std;
47 using namespace hmlp;
48 
49 #define REPORT_ANN_STATUS 0
50 
51 //#define DEBUG_TREE 1
52 
53 
/** Global flag raised the first time any splitter produces an uneven split (see Node::Split()).
 *  NOTE(review): a non-inline global variable defined in a header risks ODR violations if this
 *  header is included from multiple translation units — consider C++17 `inline` or moving the
 *  definition to a .cpp; confirm against the build setup. */
bool has_uneven_split = false;
55 
56 
57 
58 
59 namespace hmlp
60 {
61 
/**
 *  @brief Stateless helpers for computing and decoding the Morton (Z-order) ids
 *         used to address tree nodes.  A Morton id packs the root-to-node path
 *         (prefix) in the high bits and the node depth in the low LEVELOFFSET bits.
 *
 *  NOTE(review): the `class MortonHelper` declaration line was lost in a previous
 *  mangling of this file; it is reconstructed here from its call sites
 *  (e.g. MortonHelper::IsMyParent, MortonHelper::Root) — confirm against upstream.
 */
class MortonHelper
{
  public:

    /** (path-prefix, depth) pair used while recursing over the tree. */
    typedef std::pair<size_t, size_t> Recursor;

    /** Recursor of the root node: empty prefix at depth 0. */
    static Recursor Root()
    {
      return Recursor( 0, 0 );
    };

    /** Descend to the left child: append bit 0 to the prefix. */
    static Recursor RecurLeft( Recursor r )
    {
      return Recursor( ( r.first << 1 ) + 0, r.second + 1 );
    };

    /** Descend to the right child: append bit 1 to the prefix. */
    static Recursor RecurRight( Recursor r )
    {
      return Recursor( ( r.first << 1 ) + 1, r.second + 1 );
    };

    /** Encode (prefix, depth) into a single Morton id. */
    static size_t MortonID( Recursor r )
    {
      /* Move the prefix above the depth field. */
      size_t shift = Shift( r.second );
      return ( r.first << shift ) + r.second;
    };

    /** Morton id of the sibling of the node described by r. */
    static size_t SiblingMortonID( Recursor r )
    {
      /* Toggle the last bit of the prefix. */
      size_t shift = Shift( r.second );
      if ( r.first % 2 )
        return ( ( r.first - 1 ) << shift ) + r.second;
      else
        return ( ( r.first + 1 ) << shift ) + r.second;
    };

    /** @brief Return the MPI rank (out of `size`, assumed a power of two)
     *         that owns the node with Morton id `it`. */
    static int Morton2Rank( size_t it, int size )
    {
      size_t itdepth = Depth( it );
      size_t mpidepth = 0;
      while ( size >>= 1 ) mpidepth ++;
      /* Nodes deeper than the rank tree are owned by their depth-mpidepth ancestor's rank. */
      if ( itdepth > mpidepth ) itdepth = mpidepth;
      size_t itshift = Shift( itdepth );
      return ( it >> itshift ) << ( mpidepth - itdepth );
    };

    /** Collect (recursively) the prefixes of all depth-`depth` descendants of r. */
    static void Morton2Offsets( Recursor r, size_t depth, std::vector<size_t> &offsets )
    {
      if ( r.second == depth )
      {
        offsets.push_back( r.first );
      }
      else
      {
        Morton2Offsets( RecurLeft( r ), depth, offsets );
        Morton2Offsets( RecurRight( r ), depth, offsets );
      }
    };

    /** Return the prefixes of all depth-`depth` descendants of Morton id `me`. */
    static std::vector<size_t> Morton2Offsets( size_t me, size_t depth )
    {
      std::vector<size_t> offsets;
      size_t mydepth = Depth( me );
      assert( mydepth <= depth );
      Recursor r( me >> Shift( mydepth ), mydepth );
      Morton2Offsets( r, depth, offsets );
      return offsets;
    };

    /**
     *  @brief Check whether `it` is an ancestor of (or equal to) `me` using two facts:
     *         1) it's level must not exceed me's level, and
     *         2) the two ids must agree on every prefix bit above it's level.
     */
    static bool IsMyParent( size_t me, size_t it )
    {
      size_t itlevel = Depth( it );
      size_t mylevel = Depth( me );
      size_t itshift = Shift( itlevel );
      bool is_my_parent = !( ( me ^ it ) >> itshift ) && ( itlevel <= mylevel );
#ifdef DEBUG_TREE
      hmlp_print_binary( me );
      hmlp_print_binary( it );
      hmlp_print_binary( ( me ^ it ) >> itshift );
      printf( "ismyparent %d itlevel %lu mylevel %lu shift %lu fixed shift %d\n",
          is_my_parent, itlevel, mylevel, itshift, 1 << LEVELOFFSET );
#endif
      return is_my_parent;
    };

    /** True if any Morton id in `querys` has `target` as an ancestor. */
    template<typename TQUERY>
    static bool ContainAny( size_t target, TQUERY &querys )
    {
      for ( auto & q : querys )
        if ( IsMyParent( q, target ) ) return true;
      return false;
    };

  private:

    /** The depth is stored in the low LEVELOFFSET bits of a Morton id. */
    static size_t Depth( size_t it )
    {
      size_t filter = ( 1 << LEVELOFFSET ) - 1;
      return it & filter;
    };

    /** Bit position where the prefix of a depth-`depth` node begins. */
    static size_t Shift( size_t depth )
    {
      return ( 1 << LEVELOFFSET ) - depth + LEVELOFFSET;
    };

    /** Number of low bits reserved for the depth field. */
    const static int LEVELOFFSET = 4;

};
/** @brief Strict weak ordering on the distance (first) field of a (distance, gid) pair. */
template<typename T>
bool less_first( const std::pair<T, size_t> &a, const std::pair<T, size_t> &b )
{
  return a.first < b.first;
};
/** @brief Strict weak ordering on the gid (second) field of a (distance, gid) pair. */
template<typename T>
bool less_second( const std::pair<T, size_t> &a, const std::pair<T, size_t> &b )
{
  return a.second < b.second;
};
/** @brief Equality on the gid (second) field; used to drop duplicated neighbors. */
template<typename T>
bool equal_second( const std::pair<T, size_t> &a, const std::pair<T, size_t> &b )
{
  return a.second == b.second;
};
208 
209 
210 
/**
 *  @brief Merge two length-k neighbor lists A and B into A.
 *         Entries are (distance, gid) pairs; duplicated gids are removed and
 *         the k smallest distances among the unique candidates are kept.
 *  @param aux scratch buffer, resized to 2k on demand and reusable across calls.
 */
template<typename T>
void MergeNeighbors( size_t k, std::pair<T, size_t> *A,
    std::pair<T, size_t> *B, std::vector<std::pair<T, size_t>> &aux )
{
  if ( aux.size() != 2 * k ) aux.resize( 2 * k );

  /* Stage both lists in the scratch buffer. */
  std::copy( A, A + k, aux.begin() );
  std::copy( B, B + k, aux.begin() + k );

  /* Group by gid, drop duplicates, then rank the unique candidates by distance. */
  auto by_gid   = []( const std::pair<T, size_t> &x, const std::pair<T, size_t> &y )
                  { return x.second < y.second; };
  auto same_gid = []( const std::pair<T, size_t> &x, const std::pair<T, size_t> &y )
                  { return x.second == y.second; };
  auto by_dist  = []( const std::pair<T, size_t> &x, const std::pair<T, size_t> &y )
                  { return x.first < y.first; };
  std::sort( aux.begin(), aux.end(), by_gid );
  auto last = std::unique( aux.begin(), aux.end(), same_gid );
  std::sort( aux.begin(), last, by_dist );

  /* Write the k best back to A. */
  std::copy( aux.begin(), aux.begin() + k, A );
};
/**
 *  @brief Merge n pairs of length-k neighbor lists in parallel:
 *         for each i, merge B[ i*k .. i*k+k ) into A[ i*k .. i*k+k ).
 */
template<typename T>
void MergeNeighbors( size_t k, size_t n,
    std::vector<std::pair<T, size_t>> &A, std::vector<std::pair<T, size_t>> &B )
{
  assert( A.size() >= n * k && B.size() >= n * k );
  #pragma omp parallel
  {
    /* One scratch buffer per thread, reused across iterations. */
    std::vector<std::pair<T, size_t>> scratch( 2 * k );
    #pragma omp for
    for ( size_t i = 0; i < n; i ++ )
    {
      MergeNeighbors( k, A.data() + i * k, B.data() + i * k, scratch );
    }
  }
};
253 namespace tree
254 {
262 template<typename NODE>
263 class IndexPermuteTask : public Task
264 {
265  public:
266 
267  NODE *arg;
268 
269  void Set( NODE *user_arg )
270  {
271  name = string( "Permutation" );
272  arg = user_arg;
273  // Need an accurate cost model.
274  cost = 1.0;
275  };
276 
277  void DependencyAnalysis()
278  {
279  arg->DependencyAnalysis( RW, this );
280  if ( !arg->isleaf )
281  {
282  arg->lchild->DependencyAnalysis( R, this );
283  arg->rchild->DependencyAnalysis( R, this );
284  }
285  this->TryEnqueue();
286  };
287 
288 
289  void Execute( Worker* user_worker )
290  {
291  auto &gids = arg->gids;
292  auto *lchild = arg->lchild;
293  auto *rchild = arg->rchild;
294 
295  if ( !arg->isleaf )
296  {
297  auto &lgids = lchild->gids;
298  auto &rgids = rchild->gids;
299  gids = lgids;
300  gids.insert( gids.end(), rgids.begin(), rgids.end() );
301  }
302  };
303 
304 };
311 template<typename NODE>
312 class SplitTask : public Task
313 {
314  public:
315 
316  NODE *arg = NULL;
317 
318  void Set( NODE *user_arg )
319  {
320  name = string( "Split" );
321  arg = user_arg;
322  // Need an accurate cost model.
323  cost = 1.0;
324  };
325 
326  void DependencyAnalysis() { arg->DependOnParent( this ); };
327 
328  void Execute( Worker* user_worker ) { arg->Split(); };
329 
330 };
333 // * @brief This is the default ball tree splitter. Given coordinates,
335 // * compute the direction from the two most far away points.
336 // * Project all points to this line and split into two groups
337 // * using a median select.
338 // *
339 // * @para
340 // *
341 // * @TODO Need to explit the parallelism.
342 // */
343 //template<int N_SPLIT, typename T>
344 //struct centersplit
345 //{
346 // /** closure */
347 // Data<T> *Coordinate = NULL;
348 //
349 // inline vector<vector<size_t> > operator()
350 // (
351 // vector<size_t>& gids
352 // ) const
353 // {
354 // assert( N_SPLIT == 2 );
355 //
356 // Data<T> &X = *Coordinate;
357 // size_t d = X.row();
358 // size_t n = gids.size();
359 //
360 // T rcx0 = 0.0, rx01 = 0.0;
361 // size_t x0, x1;
362 // vector<vector<size_t> > split( N_SPLIT );
363 //
364 //
365 // vector<T> centroid = combinatorics::Mean( d, n, X, gids );
366 // vector<T> direction( d );
367 // vector<T> projection( n, 0.0 );
368 //
369 // //printf( "After Mean\n" );
370 //
371 // // Compute the farest x0 point from the centroid
372 // for ( size_t i = 0; i < n; i ++ )
373 // {
374 // T rcx = 0.0;
375 // for ( size_t p = 0; p < d; p ++ )
376 // {
377 // T tmp = X( p, gids[ i ] ) - centroid[ p ];
378 //
379 //
380 // rcx += tmp * tmp;
381 // }
382 // if ( rcx > rcx0 )
383 // {
384 // rcx0 = rcx;
385 // x0 = i;
386 // }
387 // }
388 //
389 // //printf( "After Farest\n" );
390 // //for ( int p = 0; p < d; p ++ )
391 // //{
392 // //}
393 // //printf( "\n" );
394 //
395 // // Compute the farest point x1 from x0
396 // for ( size_t i = 0; i < n; i ++ )
397 // {
398 // T rxx = 0.0;
399 // for ( size_t p = 0; p < d; p ++ )
400 // {
401 // T tmp = X( p, gids[ i ] ) - X( p, gids[ x0 ] );
402 // rxx += tmp * tmp;
403 // }
404 // if ( rxx > rx01 )
405 // {
406 // rx01 = rxx;
407 // x1 = i;
408 // }
409 // }
410 //
411 //
412 //
413 // // Compute direction
414 // for ( size_t p = 0; p < d; p ++ )
415 // {
416 // direction[ p ] = X( p, gids[ x1 ] ) - X( p, gids[ x0 ] );
417 // }
418 //
419 // //printf( "After Direction\n" );
420 // //for ( int p = 0; p < d; p ++ )
421 // //{
422 // // printf( "%5.2lf ", direction[ p ] );
423 // //}
424 // //printf( "\n" );
425 // //exit( 1 );
426 //
427 //
428 //
429 // // Compute projection
430 // projection.resize( n, 0.0 );
431 // for ( size_t i = 0; i < n; i ++ )
432 // for ( size_t p = 0; p < d; p ++ )
433 // projection[ i ] += X( p, gids[ i ] ) * direction[ p ];
434 //
435 // //printf( "After Projetion\n" );
436 // //for ( int p = 0; p < d; p ++ )
437 // //{
438 // // printf( "%5.2lf ", projec[ p ] );
439 // //}
440 // //printf( "\n" );
441 //
442 //
443 //
444 // /** Parallel median search */
445 // T median;
446 //
447 // if ( 1 )
448 // {
449 // median = combinatorics::Select( n, n / 2, projection );
450 // }
451 // else
452 // {
453 // auto proj_copy = projection;
454 // sort( proj_copy.begin(), proj_copy.end() );
455 // median = proj_copy[ n / 2 ];
456 // }
457 //
458 //
459 //
460 // split[ 0 ].reserve( n / 2 + 1 );
461 // split[ 1 ].reserve( n / 2 + 1 );
462 //
463 // /** TODO: Can be parallelized */
464 // vector<size_t> middle;
465 // for ( size_t i = 0; i < n; i ++ )
466 // {
467 // if ( projection[ i ] < median ) split[ 0 ].push_back( i );
468 // else if ( projection[ i ] > median ) split[ 1 ].push_back( i );
469 // else middle.push_back( i );
470 // }
471 //
472 // for ( size_t i = 0; i < middle.size(); i ++ )
473 // {
474 // if ( split[ 0 ].size() <= split[ 1 ].size() ) split[ 0 ].push_back( middle[ i ] );
475 // else split[ 1 ].push_back( middle[ i ] );
476 // }
477 //
478 //
479 // //printf( "split median %lf left %d right %d\n",
480 // // median,
481 // // (int)split[ 0 ].size(), (int)split[ 1 ].size() );
482 //
483 // //if ( split[ 0 ].size() > 0.6 * n ||
484 // // split[ 1 ].size() > 0.6 * n )
485 // //{
486 // // for ( int i = 0; i < n; i ++ )
487 // // {
488 // // printf( "%E ", projection[ i ] );
489 // // }
490 // // printf( "\n" );
491 // //}
492 //
493 //
494 // return split;
495 // };
496 //};
497 //
498 //
500 // * @brief This is the splitter used in the randomized tree. Given
501 // * coordinates, project all points onto a random direction
502 // * and split into two groups using a median select.
503 // *
504 // * @para
505 // *
506 // * @TODO Need to explit the parallelism.
507 // */
508 //template<int N_SPLIT, typename T>
509 //struct randomsplit
510 //{
511 // /** Closure */
512 // Data<T> *Coordinate = NULL;
513 //
514 // inline vector<vector<size_t> > operator()
515 // (
516 // vector<size_t>& gids
517 // ) const
518 // {
519 // assert( N_SPLIT == 2 );
520 //
521 // Data<T> &X = *Coordinate;
522 // size_t d = X.row();
523 // size_t n = gids.size();
524 //
525 // vector<vector<size_t> > split( N_SPLIT );
526 //
527 // vector<T> direction( d );
528 // vector<T> projection( n, 0.0 );
529 //
530 // // Compute random direction
531 // static default_random_engine generator;
532 // normal_distribution<T> distribution;
533 // for ( int p = 0; p < d; p ++ )
534 // {
535 // direction[ p ] = distribution( generator );
536 // }
537 //
538 // // Compute projection
539 // projection.resize( n, 0.0 );
540 // for ( size_t i = 0; i < n; i ++ )
541 // for ( size_t p = 0; p < d; p ++ )
542 // projection[ i ] += X( p, gids[ i ] ) * direction[ p ];
543 //
544 //
545 // // Parallel median search
546 // // T median = Select( n, n / 2, projection );
547 // auto proj_copy = projection;
548 // sort( proj_copy.begin(), proj_copy.end() );
549 // T median = proj_copy[ n / 2 ];
550 //
551 // split[ 0 ].reserve( n / 2 + 1 );
552 // split[ 1 ].reserve( n / 2 + 1 );
553 //
554 // /** TODO: Can be parallelized */
555 // vector<size_t> middle;
556 // for ( size_t i = 0; i < n; i ++ )
557 // {
558 // if ( projection[ i ] < median ) split[ 0 ].push_back( i );
559 // else if ( projection[ i ] > median ) split[ 1 ].push_back( i );
560 // else middle.push_back( i );
561 // }
562 //
563 // for ( size_t i = 0; i < middle.size(); i ++ )
564 // {
565 // if ( split[ 0 ].size() <= split[ 1 ].size() ) split[ 0 ].push_back( middle[ i ] );
566 // else split[ 1 ].push_back( middle[ i ] );
567 // }
568 //
569 //
570 // //printf( "split median %lf left %d right %d\n",
571 // // median,
572 // // (int)split[ 0 ].size(), (int)split[ 1 ].size() );
573 //
574 // //if ( split[ 0 ].size() > 0.6 * n ||
575 // // split[ 1 ].size() > 0.6 * n )
576 // //{
577 // // for ( int i = 0; i < n; i ++ )
578 // // {
579 // // printf( "%E ", projection[ i ] );
580 // // }
581 // // printf( "\n" );
582 // //}
583 //
584 //
585 // return split;
586 // };
587 //};
588 
589 
593 //template<typename SETUP, int N_CHILDREN, typename NODEDATA>
/**
 *  @brief A binary tree node.  Owns a contiguous set of point gids, pointers
 *         to its parent/children/sibling, a Morton id addressing its position
 *         in the tree, and the interaction lists used by the solvers.
 *         SETUP supplies shared parameters (splitter, leaf size, morton table);
 *         NODEDATA is per-node payload.
 */
template<typename SETUP, typename NODEDATA>
class Node : public ReadWrite
{
  public:

    /** Underlying scalar type, deduced from SETUP. */
    typedef typename SETUP::T T;

    /** This is a binary tree; every internal node has two children. */
    static const int N_CHILDREN = 2;

    /** Construct a node that owns n (yet-unassigned) gids at level l. */
    Node( SETUP* setup, size_t n, size_t l,
        Node *parent, unordered_map<size_t, Node*> *morton2node, Lock *treelock )
    {
      this->setup = setup;
      this->n = n;
      this->l = l;
      this->morton = 0;
      this->treelist_id = 0;
      this->gids.resize( n );
      this->isleaf = false;
      this->parent = parent;
      this->lchild = NULL;
      this->rchild = NULL;
      this->morton2node = morton2node;
      this->treelock = treelock;
      for ( int i = 0; i < N_CHILDREN; i++ ) kids[ i ] = NULL;
    };

    /** Construct a node with an explicit gid list (copied in). */
    Node( SETUP *setup, int n, int l, vector<size_t> gids,
        Node *parent, unordered_map<size_t, Node*> *morton2node, Lock *treelock )
    {
      this->setup = setup;
      this->n = n;
      this->l = l;
      this->morton = 0;
      this->treelist_id = 0;
      this->gids = gids;
      this->isleaf = false;
      this->parent = parent;
      this->lchild = NULL;
      this->rchild = NULL;
      this->morton2node = morton2node;
      this->treelock = treelock;
      for ( int i = 0; i < N_CHILDREN; i++ ) kids[ i ] = NULL;
    };

    /** Construct a dummy node carrying only a Morton id (used as a lookup key);
     *  all other members keep their defaults. */
    Node( size_t morton ) { this->morton = morton; };

    ~Node() {};

    /** Reset the number of points owned by this node (gids are resized). */
    void Resize( int n )
    {
      this->n = n;
      gids.resize( n );
    };

    /**
     *  @brief Partition this node's gids into its two children using
     *         setup->splitter.  split[ i ] holds LOCAL indices into gids.
     *         If the splitter is uneven by more than one point, fall back to
     *         a deterministic half/half split (and warn once globally).
     */
    void Split()
    {
      try
      {
        /* Leaf nodes do not split. */
        if ( isleaf ) return;

        /* NOTE(review): m, max_depth, and splitter_time are currently unused here. */
        int m = setup->m;
        int max_depth = setup->max_depth;

        double beg = omp_get_wtime();
        auto split = setup->splitter( gids );
        double splitter_time = omp_get_wtime() - beg;
        //printf( "splitter %5.3lfs\n", splitter_time );

        /* Uneven split: overwrite with a deterministic even split.
         *  (The warning text says "random", but the fallback is the
         *  first-half/second-half split below.) */
        if ( std::abs( (int)split[ 0 ].size() - (int)split[ 1 ].size() ) > 1 )
        {
          if ( !has_uneven_split )
          {
            printf( "\n\nWARNING! uneven split. Using random split instead %lu %lu\n\n",
                split[ 0 ].size(), split[ 1 ].size() );
            has_uneven_split = true;
          }
          split[ 0 ].resize( gids.size() / 2 );
          split[ 1 ].resize( gids.size() - ( gids.size() / 2 ) );
          //#pragma omp parallel for
          for ( size_t i = 0; i < gids.size(); i ++ )
          {
            if ( i < gids.size() / 2 ) split[ 0 ][ i ] = i;
            else split[ 1 ][ i - ( gids.size() / 2 ) ] = i;
          }
        }

        /* Scatter gids to the children (kids were allocated beforehand). */
        for ( size_t i = 0; i < N_CHILDREN; i ++ )
        {
          int nchild = split[ i ].size();

          kids[ i ]->Resize( nchild );
          for ( int j = 0; j < nchild; j ++ )
          {
            kids[ i ]->gids[ j ] = gids[ split[ i ][ j ] ];
          }
        }
      }
      catch ( const exception & e )
      {
        cout << e.what() << endl;
      }
    };

    /**
     *  @brief Check whether this node contains any query point.
     *         queries[] holds gids, mapped through setup->morton to leaf
     *         Morton ids before the ancestor test.  Requires setup->morton
     *         to be initialized (aborts otherwise).
     */
    bool ContainAny( vector<size_t> &queries )
    {
      if ( !setup->morton.size() )
      {
        printf( "Morton id was not initialized.\n" );
        exit( 1 );
      }
      for ( size_t i = 0; i < queries.size(); i ++ )
      {
        if ( MortonHelper::IsMyParent( setup->morton[ queries[ i ] ], morton ) )
        {
#ifdef DEBUG_TREE
          printf( "\n" );
          hmlp_print_binary( setup->morton[ queries[ i ] ] );
          hmlp_print_binary( morton );
          printf( "\n" );
#endif
          return true;
        }
      }
      return false;

    };

    /** Same test over a set of query nodes (compares node Morton ids). */
    bool ContainAny( set<Node*> &querys )
    {
      if ( !setup->morton.size() )
      {
        printf( "Morton id was not initialized.\n" );
        exit( 1 );
      }
      for ( auto it = querys.begin(); it != querys.end(); it ++ )
      {
        if ( MortonHelper::IsMyParent( (*it)->morton, morton ) )
        {
          return true;
        }
      }
      return false;

    };

    /** Print level, offset, size, and the Morton id in binary (debugging). */
    void Print()
    {
      printf( "l %lu offset %lu n %lu\n", this->l, this->offset, this->n );
      hmlp_print_binary( this->morton );
    };

    /** Bottom-up dependency: task reads both children, read-writes this node. */
    void DependOnChildren( Task *task )
    {
      if ( this->lchild ) this->lchild->DependencyAnalysis( R, task );
      if ( this->rchild ) this->rchild->DependencyAnalysis( R, task );
      this->DependencyAnalysis( RW, task );
      /* Dispatch the task if it has no remaining dependency. */
      task->TryEnqueue();
    };

    /** Top-down dependency: task reads this node, read-writes both children. */
    void DependOnParent( Task *task )
    {
      this->DependencyAnalysis( R, task );
      if ( this->lchild ) this->lchild->DependencyAnalysis( RW, task );
      if ( this->rchild ) this->rchild->DependencyAnalysis( RW, task );
      /* Dispatch the task if it has no remaining dependency. */
      task->TryEnqueue();
    };

    /** No structural dependency: task only read-writes this node. */
    void DependOnNoOne( Task *task )
    {
      this->DependencyAnalysis( RW, task );
      /* Dispatch the task if it has no remaining dependency. */
      task->TryEnqueue();
    };

    /** Shared setup (splitter, leaf size, morton table); not owned. */
    SETUP *setup = NULL;

    /** Per-node payload. */
    NODEDATA data;

    /** Number of points owned by this node. */
    size_t n;

    /** Level (depth) of this node in the tree. */
    size_t l;

    /** Morton id of this node. */
    size_t morton = 0;
    /** Offset of this node's gids in the leaf-ordered permutation. */
    size_t offset = 0;

    /** Position of this node in the level-order treelist. */
    size_t treelist_id;

    /** Global point indices owned by this node. */
    vector<size_t> gids;

    /** Far interaction lists (ids / nodes / Morton ids). */
    set<size_t> FarIDs;
    set<Node*> FarNodes;
    set<size_t> FarNodeMortonIDs;

    /** Near interaction lists. */
    set<size_t> NearIDs;
    set<Node*> NearNodes;
    set<size_t> NearNodeMortonIDs;

    /** Neighbor-based (NN) far interaction lists. */
    set<size_t> NNFarIDs;
    set<Node*> NNFarNodes;
    set<Node*> ProposedNNFarNodes;
    set<size_t> NNFarNodeMortonIDs;

    /** Neighbor-based (NN) near interaction lists. */
    set<size_t> NNNearIDs;
    set<Node*> NNNearNodes;
    set<Node*> ProposedNNNearNodes;
    set<size_t> NNNearNodeMortonIDs;

    /** Per-level cached interaction data, keyed by Morton id. */
    vector<map<size_t, Data<T>>> DistFar;
    vector<map<size_t, Data<T>>> DistNear;

    /** Mutex shared by the whole tree; not owned. */
    Lock *treelock = NULL;

    /** Children and relatives; all non-owning except kids (owned by the tree). */
    Node *kids[ N_CHILDREN ];
    Node *lchild = NULL;
    Node *rchild = NULL;
    Node *sibling = NULL;
    Node *parent = NULL;
    /** Shared Morton-id-to-node map; not owned. */
    unordered_map<size_t, Node*> *morton2node = NULL;

    /** True once the tree decides this node does not split further. */
    bool isleaf;

  private:

};
/**
 *  @brief Data and setup that are shared with all nodes of the tree.
 */
template<typename SPLITTER, typename DATATYPE>
class Setup
{
  public:

    /** Underlying scalar type. */
    typedef DATATYPE T;

    Setup() {};

    ~Setup() {};

    /** Leaf node size (maximum number of points per leaf). */
    size_t m = 0;

    /** Maximum tree depth. */
    size_t max_depth = 15;

    /** Neighbor table: k-by-n matrix of (distance, gid) pairs; may be NULL. */
    Data<pair<T, size_t>> *NN = NULL;

    /** Morton id of the owning leaf for each point, indexed by gid.
     *  Filled in by Tree::TreePartition(). */
    vector<size_t> morton;

    /** The splitting functor used to partition a node's gids. */
    SPLITTER splitter;

    /**
     *  @brief For each query gid, set validation[ i ] = 1 if the node with
     *         Morton id `target` is an ancestor of that query's leaf.
     *         Requires morton[] to be initialized (aborts otherwise).
     */
    vector<size_t> ContainAny( vector<size_t> &queries, size_t target )
    {
      vector<size_t> validation( queries.size(), 0 );

      if ( !morton.size() )
      {
        printf( "Morton id was not initialized.\n" );
        exit( 1 );
      }

      for ( size_t i = 0; i < queries.size(); i ++ )
      {
        if ( MortonHelper::IsMyParent( morton[ queries[ i ] ], target ) )
          validation[ i ] = 1;

      }
      return validation;

    };
};
954 template<class SETUP, class NODEDATA>
955 class Tree
956 {
957  public:
958 
959  typedef typename SETUP::T T;
962 
963  static const int N_CHILDREN = 2;
964 
965 
966 
968  SETUP setup;
969 
971  size_t n = 0;
972 
974  size_t m = 0;
975 
977  size_t depth = 0;
978 
979 
982 
984  vector<NODE*> treelist;
985 
990  unordered_map<size_t, NODE*> morton2node;
991 
993  Tree() {};
994 
997  {
998  //printf( "~Tree() shared treelist.size() %lu treequeue.size() %lu\n",
999  // treelist.size(), treequeue.size() );
1000  for ( int i = 0; i < treelist.size(); i ++ )
1001  {
1002  if ( treelist[ i ] ) delete treelist[ i ];
1003  }
1004  morton2node.clear();
1005  //printf( "end ~Tree() shared\n" );
1006  };
1007 
1009  void Offset( NODE *node, size_t offset )
1010  {
1011  if ( node )
1012  {
1013  node->offset = offset;
1014  if ( node->lchild )
1015  {
1016  Offset( node->lchild, offset + 0 );
1017  Offset( node->rchild, offset + node->lchild->gids.size() );
1018  }
1019  }
1020  };
1023  void RecursiveMorton( NODE *node, MortonHelper::Recursor r )
1024  {
1026  if ( !node ) return;
1028  node->morton = MortonHelper::MortonID( r );
1030  RecursiveMorton( node->lchild, MortonHelper::RecurLeft( r ) );
1031  RecursiveMorton( node->rchild, MortonHelper::RecurRight( r ) );
1033  if ( !node->lchild )
1034  {
1035  for ( auto it : node->gids ) setup.morton[ it ] = node->morton;
1036  }
1037  };
1038 
1039 
1048  void AllocateNodes( NODE *root )
1049  {
1051  int glb_depth = std::ceil( std::log2( (double)n / m ) );
1052  if ( glb_depth > setup.max_depth ) glb_depth = setup.max_depth;
1054  depth = glb_depth - root->l;
1055 
1056  //printf( "local AllocateNodes n %lu m %lu glb_depth %d loc_depth %lu\n",
1057  // n, m, glb_depth, depth );
1058 
1060  for ( auto node_ptr : treelist ) delete node_ptr;
1061  treelist.clear();
1062  morton2node.clear();
1063  treelist.reserve( 1 << ( depth + 1 ) );
1064  deque<NODE*> treequeue;
1066  treequeue.push_back( root );
1067 
1068 
1070  while ( auto *node = treequeue.front() )
1071  {
1073  node->treelist_id = treelist.size();
1075  if ( node->l < glb_depth )
1076  {
1077  for ( int i = 0; i < N_CHILDREN; i ++ )
1078  {
1079  node->kids[ i ] = new NODE( &setup, 0, node->l + 1, node, &morton2node, &lock );
1080  treequeue.push_back( node->kids[ i ] );
1081  }
1082  node->lchild = node->kids[ 0 ];
1083  node->rchild = node->kids[ 1 ];
1084  if ( node->lchild ) node->lchild->sibling = node->rchild;
1085  if ( node->rchild ) node->rchild->sibling = node->lchild;
1086  }
1087  else
1088  {
1090  node->isleaf = true;
1091  treequeue.push_back( NULL );
1092  }
1093  treelist.push_back( node );
1094  treequeue.pop_front();
1095  }
1096 
1097  };
1103  {
1104  double beg, alloc_time, split_time, morton_time, permute_time;
1105 
1106  this->n = setup.ProblemSize();
1107  this->m = setup.LeafNodeSize();
1108 
1110  global_indices.clear();
1111  for ( size_t i = 0; i < n; i ++ ) global_indices.push_back( i );
1112 
1114  has_uneven_split = false;
1115 
1117  beg = omp_get_wtime();
1118  AllocateNodes( new NODE( &setup, n, 0, global_indices, NULL, &morton2node, &lock ) );
1119  alloc_time = omp_get_wtime() - beg;
1120 
1122  beg = omp_get_wtime();
1123  SplitTask<NODE> splittask;
1124  TraverseDown( splittask );
1125  ExecuteAllTasks();
1126  split_time = omp_get_wtime() - beg;
1127 
1128 
1130  setup.morton.resize( n );
1132  RecursiveMorton( treelist[ 0 ], MortonHelper::Root() );
1133 
1134 
1135  Offset( treelist[ 0 ], 0 );
1136 
1138  morton2node.clear();
1139  for ( size_t i = 0; i < treelist.size(); i ++ )
1140  {
1141  morton2node[ treelist[ i ]->morton ] = treelist[ i ];
1142  }
1143 
1145  IndexPermuteTask<NODE> indexpermutetask;
1146  TraverseUp( indexpermutetask );
1147  ExecuteAllTasks();
1148 
1149  };
1153  vector<size_t> GetPermutation()
1154  {
1155  int n_nodes = 1 << this->depth;
1156  auto level_beg = this->treelist.begin() + n_nodes - 1;
1157 
1158  vector<size_t> perm;
1159 
1160  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1161  {
1162  auto *node = *(level_beg + node_ind);
1163  auto gids = node->gids;
1164  perm.insert( perm.end(), gids.begin(), gids.end() );
1165  }
1166 
1167  return perm;
1168  };
1177  template<typename KNNTASK>
1178  Data<pair<T, size_t>> AllNearestNeighbor( size_t n_tree, size_t k,
1179  size_t max_depth, pair<T, size_t> initNN,
1180  KNNTASK &dummy )
1181  {
1183  Data<pair<T, size_t>> NN( k, setup.ProblemSize(), initNN );
1184 
1186  setup.m = 4 * k;
1187  if ( setup.m < 32 ) setup.m = 32;
1188  setup.NN = &NN;
1189 
1190  if ( REPORT_ANN_STATUS )
1191  {
1192  printf( "========================================================\n");
1193  }
1194 
1196  for ( int t = 0; t < n_tree; t ++ )
1197  {
1199  double knn_acc = 0.0;
1200  size_t num_acc = 0;
1202  TreePartition();
1203  TraverseLeafs( dummy );
1204  ExecuteAllTasks();
1205 
1206  size_t n_nodes = 1 << depth;
1207  auto level_beg = treelist.begin() + n_nodes - 1;
1208  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1209  {
1210  auto *node = *(level_beg + node_ind);
1211  knn_acc += node->data.knn_acc;
1212  num_acc += node->data.num_acc;
1213  }
1214  if ( REPORT_ANN_STATUS )
1215  {
1216  printf( "ANN iter %2d, average accuracy %.2lf%% (over %4lu samples)\n",
1217  t, knn_acc / num_acc, num_acc );
1218  }
1219 
1221  if ( knn_acc / num_acc < 0.8 )
1222  {
1223  if ( 2.0 * setup.m < 2048 ) setup.m = 2.0 * setup.m;
1224  }
1225  else break;
1226 
1227 
1228 #ifdef DEBUG_TREE
1229  printf( "Iter %2d NN 0 ", t );
1230  for ( size_t i = 0; i < NN.row(); i ++ )
1231  {
1232  printf( "%E(%lu) ", NN[ i ].first, NN[ i ].second );
1233  }
1234  printf( "\n" );
1235 #endif
1236  }
1238  if ( REPORT_ANN_STATUS )
1239  {
1240  printf( "========================================================\n\n");
1241  }
1242 
1244  #pragma omp parallel for
1245  for ( size_t j = 0; j < NN.col(); j ++ )
1246  sort( NN.data() + j * NN.row(), NN.data() + ( j + 1 ) * NN.row() );
1247 
1249  for ( auto &neig : NN )
1250  {
1251  if ( neig.second < 0 || neig.second >= NN.col() )
1252  {
1253  printf( "Illegle neighbor gid %lu\n", neig.second );
1254  break;
1255  }
1256  }
1257 
1259  return NN;
1260 
1261  };
1265  {
1267  int total_depth = treelist.back()->l;
1269  int num_leafs = 1 << total_depth;
1271  Data<int> A( num_leafs, num_leafs, 0 );
1273  for ( int t = 1; t < treelist.size(); t ++ )
1274  {
1275  auto *node = treelist[ t ];
1277  for ( auto *it : node->NNNearNodes )
1278  {
1279  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
1280  auto J = MortonHelper::Morton2Offsets( it->morton, total_depth );
1281  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
1282  }
1284  for ( auto *it : node->NNFarNodes )
1285  {
1286  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
1287  auto J = MortonHelper::Morton2Offsets( it->morton, total_depth );
1288  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
1289  }
1290  }
1291 
1292  for ( size_t i = 0; i < num_leafs; i ++ )
1293  {
1294  for ( size_t j = 0; j < num_leafs; j ++ ) printf( "%d", A( i, j ) );
1295  printf( "\n" );
1296  }
1297 
1298 
1299  return A;
1300  };
1310  template<typename TASK, typename... Args>
1311  void TraverseLeafs( TASK &dummy, Args&... args )
1312  {
1314  assert( this->treelist.size() );
1315 
1316  int n_nodes = 1 << this->depth;
1317  auto level_beg = this->treelist.begin() + n_nodes - 1;
1318 
1319  if ( out_of_order_traversal )
1320  {
1321  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1322  {
1323  auto *node = *(level_beg + node_ind);
1324  RecuTaskSubmit( node, dummy, args... );
1325  }
1326  }
1327  else
1328  {
1329  int nthd_glb = omp_get_max_threads();
1331  #pragma omp parallel for if ( n_nodes > nthd_glb / 2 ) schedule( dynamic )
1332  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1333  {
1334  auto *node = *(level_beg + node_ind);
1335  RecuTaskExecute( node, dummy, args... );
1336  }
1337  }
1338  };
1344  template<typename TASK, typename... Args>
1345  void TraverseUp( TASK &dummy, Args&... args )
1346  {
1348  assert( this->treelist.size() );
1349 
1358  int local_begin_level = ( treelist[ 0 ]->l ) ? 1 : 0;
1359 
1361  for ( int l = this->depth; l >= local_begin_level; l -- )
1362  {
1363  size_t n_nodes = 1 << l;
1364  auto level_beg = this->treelist.begin() + n_nodes - 1;
1365 
1366 
1367  if ( out_of_order_traversal )
1368  {
1370  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1371  {
1372  auto *node = *(level_beg + node_ind);
1373  RecuTaskSubmit( node, dummy, args... );
1374  }
1375  }
1376  else
1377  {
1378  int nthd_glb = omp_get_max_threads();
1380  #pragma omp parallel for if ( n_nodes > nthd_glb / 2 ) schedule( dynamic )
1381  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1382  {
1383  auto *node = *(level_beg + node_ind);
1384  RecuTaskExecute( node, dummy, args... );
1385  }
1386  }
1387  }
1388  };
1395  template<typename TASK, typename... Args>
1396  void TraverseDown( TASK &dummy, Args&... args )
1397  {
1399  assert( this->treelist.size() );
1400 
1408  int local_begin_level = ( treelist[ 0 ]->l ) ? 1 : 0;
1409 
1410  for ( int l = local_begin_level; l <= this->depth; l ++ )
1411  {
1412  size_t n_nodes = 1 << l;
1413  auto level_beg = this->treelist.begin() + n_nodes - 1;
1414 
1415  if ( out_of_order_traversal )
1416  {
1418  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1419  {
1420  auto *node = *(level_beg + node_ind);
1421  RecuTaskSubmit( node, dummy, args... );
1422  }
1423  }
1424  else
1425  {
1426  int nthd_glb = omp_get_max_threads();
1428  #pragma omp parallel for if ( n_nodes > nthd_glb / 2 ) schedule( dynamic )
1429  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1430  {
1431  auto *node = *(level_beg + node_ind);
1432  RecuTaskExecute( node, dummy, args... );
1433  }
1434  }
1435  }
1436  };
1444  template<typename TASK, typename... Args>
1445  void TraverseUnOrdered( TASK &dummy, Args&... args )
1446  {
1447  TraverseDown( dummy, args... );
1448  };
1452  {
1453  //for ( size_t i = 0; i < treelist.size(); i ++ )
1454  //{
1455  // treelist[ i ]->DependencyCleanUp();
1456  //}
1457  for ( auto node : treelist ) node->DependencyCleanUp();
1458 
1459  for ( auto it : morton2node )
1460  {
1461  auto *node = it.second;
1462  if ( node ) node->DependencyCleanUp();
1463  }
1464  };
1467  {
1468  hmlp_run();
1469  DependencyCleanUp();
1470  };
1473  bool DoOutOfOrder() { return out_of_order_traversal; };
1474 
1475 
1477  template<typename SUMMARY>
1478  void Summary( SUMMARY &summary )
1479  {
1480  assert( N_CHILDREN == 2 );
1481 
1482  for ( std::size_t l = 0; l <= depth; l ++ )
1483  {
1484  size_t n_nodes = 1 << l;
1485  auto level_beg = treelist.begin() + n_nodes - 1;
1486  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
1487  {
1488  auto *node = *(level_beg + node_ind);
1489  summary( node );
1490  }
1491  }
1492  };
1495  private:
1496 
1497  bool out_of_order_traversal = true;
1498 
1499  protected:
1500 
1501  vector<size_t> global_indices;
1502 
1503 };
1504 };
1505 };
1507 #endif
void Split()
Definition: tree.hpp:657
Definition: tree.hpp:955
vector< NODE * > treelist
Definition: tree.hpp:984
size_t n
Definition: tree.hpp:802
void DependOnChildren(Task *task)
Definition: tree.hpp:769
static bool IsMyParent(size_t me, size_t it)
Check if it&#39;&#39; isme&#39;&#39;&#39;s ancestor by checking two facts. 1) itlevel >= mylevel and 2) morton above itle...
Definition: tree.hpp:149
~Tree()
Definition: tree.hpp:996
void TreePartition()
Shared-memory tree partition.
Definition: tree.hpp:1102
void Offset(NODE *node, size_t offset)
Definition: tree.hpp:1009
Data< pair< T, size_t > > AllNearestNeighbor(size_t n_tree, size_t k, size_t max_depth, pair< T, size_t > initNN, KNNTASK &dummy)
Definition: tree.hpp:1178
void AllocateNodes(NODE *root)
Allocate the local tree using the local root with n points and depth l.
Definition: tree.hpp:1048
Permuate the order of gids for each internal node to the order of leaf nodes.
Definition: tree.hpp:263
void DependencyCleanUp()
Definition: tree.hpp:1451
SETUP::T T
Definition: tree.hpp:600
void RecuTaskSubmit(ARG *arg)
Recursive task sibmission (base case).
Definition: runtime.hpp:446
Definition: tree.hpp:312
void ExecuteAllTasks()
Definition: tree.hpp:1466
size_t morton
Definition: tree.hpp:808
size_t l
Definition: tree.hpp:805
~Node()
Definition: tree.hpp:648
set< size_t > NearIDs
Definition: tree.hpp:822
void DependOnParent(Task *task)
Definition: tree.hpp:778
NODEDATA data
Definition: tree.hpp:799
Tree()
Definition: tree.hpp:993
This class provides the ability to perform dependency analysis.
Definition: runtime.hpp:498
vector< size_t > morton
Definition: tree.hpp:899
set< size_t > FarIDs
Definition: tree.hpp:817
bool TryEnqueue()
Try to dispatch the task if there is no dependency left.
Definition: runtime.cpp:339
set< size_t > NNNearIDs
Definition: tree.hpp:833
static vector< size_t > Morton2Offsets(size_t me, size_t depth)
Definition: tree.hpp:127
Data and setup that are shared with all nodes.
Definition: tree.hpp:876
void RecursiveMorton(NODE *node, MortonHelper::Recursor r)
Definition: tree.hpp:1023
static bool ContainAny(size_t target, TQUERY &querys)
Definition: tree.hpp:167
SETUP setup
Definition: tree.hpp:968
bool ContainAny(set< Node * > &querys)
Definition: tree.hpp:742
bool DoOutOfOrder()
Definition: tree.hpp:1473
static size_t SiblingMortonID(Recursor r)
Definition: tree.hpp:91
This is the default ball tree splitter. Given coordinates, compute the direction from the two most fa...
Definition: tree.hpp:595
Node(size_t morton)
Definition: tree.hpp:645
void DependOnNoOne(Task *task)
Definition: tree.hpp:787
void TraverseLeafs(TASK &dummy, Args &...args)
Definition: tree.hpp:1311
Wrapper for omp or pthread mutex.
Definition: tci.hpp:50
void TraverseUnOrdered(TASK &dummy, Args &...args)
For unordered traversal, we just call local downward traversal.
Definition: tree.hpp:1445
unordered_map< size_t, NODE * > morton2node
Definition: tree.hpp:990
void MergeNeighbors(size_t k, pair< T, size_t > *A, pair< T, size_t > *B, vector< pair< T, size_t >> &aux)
Definition: tree.hpp:212
Definition: tree.hpp:62
void RecuTaskExecute(ARG *arg)
Recursive task execution (base case).
Definition: runtime.hpp:469
size_t treelist_id
Definition: tree.hpp:812
set< size_t > NNFarIDs
Definition: tree.hpp:827
static int Morton2Rank(size_t it, int size)
return the MPI rank that owns it.
Definition: tree.hpp:103
Node< SETUP, NODEDATA > NODE
Definition: tree.hpp:961
Data< int > CheckAllInteractions()
Definition: tree.hpp:1264
void Summary(SUMMARY &summary)
Summarize all events in each level.
Definition: tree.hpp:1478
Definition: Data.hpp:134
static size_t MortonID(Recursor r)
Definition: tree.hpp:83
void TraverseUp(TASK &dummy, Args &...args)
Definition: tree.hpp:1345
vector< map< size_t, Data< T > > > DistFar
Definition: tree.hpp:846
vector< size_t > GetPermutation()
Definition: tree.hpp:1153
bool ContainAny(vector< size_t > &queries)
Check if this node contain any query using morton. Notice that queries[] contains gids; thus...
Definition: tree.hpp:717
Lock lock
Definition: tree.hpp:981
vector< size_t > ContainAny(vector< size_t > &queries, size_t target)
Check if this node contain any query using morton. Notice that queries[] contains gids; thus...
Definition: tree.hpp:912
static void Morton2Offsets(Recursor r, size_t depth, vector< size_t > &offsets)
Definition: tree.hpp:113
bool less_first(const pair< T, size_t > &a, const pair< T, size_t > &b)
Definition: tree.hpp:194
void TraverseDown(TASK &dummy, Args &...args)
Definition: tree.hpp:1396
Definition: gofmm.hpp:83
void Print()
Definition: tree.hpp:761
SPLITTER splitter
Definition: tree.hpp:902
Definition: runtime.hpp:174
Definition: thread.hpp:166