GCC Code Coverage Report

Directory:  .
File:       gofmm/tree_mpi.hpp
Date:       2019-01-14

              Exec    Total    Coverage
Lines:           0      280       0.0 %
Branches:        0     2261       0.0 %

Source:
/**
 *  HMLP (High-Performance Machine Learning Primitives)
 *
 *  Copyright (C) 2014-2017, The University of Texas at Austin
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see the LICENSE file.
 *
 **/

#ifndef MPITREE_HPP
#define MPITREE_HPP

/** Inherit most of the classes from shared-memory GOFMM. */
#include <tree.hpp>
/** Use distributed matrices inspired by the Elemental notation. */
//#include <DistData.hpp>
/** Use STL and HMLP namespaces. */
using namespace std;
using namespace hmlp;


namespace hmlp
{
namespace mpitree
{

///**
// *  @brief This is the default ball tree splitter. Given coordinates,
// *         compute the direction between the two points farthest apart.
// *         Project all points onto this line and split into two groups
// *         using a median select.
// *
// *  @param gids
// *
// *  @TODO  Need to exploit the parallelism.
// */
//template<int N_SPLIT, typename T>
//struct centersplit
//{
//  // closure
//  Data<T> *Coordinate;
//
//  inline vector<vector<size_t> > operator()
//  (
//    vector<size_t>& gids
//  ) const
//  {
//    assert( N_SPLIT == 2 );
//
//    Data<T> &X = *Coordinate;
//    size_t d = X.row();
//    size_t n = gids.size();
//
//    T rcx0 = 0.0, rx01 = 0.0;
//    size_t x0, x1;
//    vector<vector<size_t> > split( N_SPLIT );
//
//    vector<T> centroid = combinatorics::Mean( d, n, X, gids );
//    vector<T> direction( d );
//    vector<T> projection( n, 0.0 );
//
//    //printf( "After Mean\n" );
//
//    // Compute the farthest point x0 from the centroid.
//    for ( int i = 0; i < n; i ++ )
//    {
//      T rcx = 0.0;
//      for ( int p = 0; p < d; p ++ )
//      {
//        T tmp = X[ gids[ i ] * d + p ] - centroid[ p ];
//        rcx += tmp * tmp;
//      }
//      //printf( "\n" );
//      if ( rcx > rcx0 )
//      {
//        rcx0 = rcx;
//        x0 = i;
//      }
//    }
//
//    // Compute the farthest point x1 from x0.
//    for ( int i = 0; i < n; i ++ )
//    {
//      T rxx = 0.0;
//      for ( int p = 0; p < d; p ++ )
//      {
//        T tmp = X[ gids[ i ] * d + p ] - X[ gids[ x0 ] * d + p ];
//        rxx += tmp * tmp;
//      }
//      if ( rxx > rx01 )
//      {
//        rx01 = rxx;
//        x1 = i;
//      }
//    }
//
//    // Compute direction
//    for ( int p = 0; p < d; p ++ )
//    {
//      direction[ p ] = X[ gids[ x1 ] * d + p ] - X[ gids[ x0 ] * d + p ];
//    }
//
//    // Compute projection
//    projection.resize( n, 0.0 );
//    for ( int i = 0; i < n; i ++ )
//      for ( int p = 0; p < d; p ++ )
//        projection[ i ] += X[ gids[ i ] * d + p ] * direction[ p ];
//
//    /** Parallel median search */
//    T median;
//
//    if ( 1 )
//    {
//      median = hmlp::combinatorics::Select( n, n / 2, projection );
//    }
//    else
//    {
//      auto proj_copy = projection;
//      std::sort( proj_copy.begin(), proj_copy.end() );
//      median = proj_copy[ n / 2 ];
//    }
//
//    split[ 0 ].reserve( n / 2 + 1 );
//    split[ 1 ].reserve( n / 2 + 1 );
//
//    /** TODO: Can be parallelized */
//    std::vector<std::size_t> middle;
//    for ( size_t i = 0; i < n; i ++ )
//    {
//      if      ( projection[ i ] < median ) split[ 0 ].push_back( i );
//      else if ( projection[ i ] > median ) split[ 1 ].push_back( i );
//      else                                 middle.push_back( i );
//    }
//
//    for ( size_t i = 0; i < middle.size(); i ++ )
//    {
//      if ( split[ 0 ].size() <= split[ 1 ].size() ) split[ 0 ].push_back( middle[ i ] );
//      else                                          split[ 1 ].push_back( middle[ i ] );
//    }
//
//    return split;
//  };
//
//
//  inline std::vector<std::vector<size_t> > operator()
//  (
//    std::vector<size_t>& gids,
//    hmlp::mpi::Comm comm
//  ) const
//  {
//    std::vector<std::vector<size_t> > split( N_SPLIT );
//
//    return split;
//  };
//
//};
//
//
//
//template<int N_SPLIT, typename T>
//struct randomsplit
//{
//  Data<T> *Coordinate = NULL;
//
//  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
//  {
//    vector<vector<size_t> > split( N_SPLIT );
//    return split;
//  };
//
//  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
//  {
//    vector<vector<size_t> > split( N_SPLIT );
//    return split;
//  };
//};
//


template<typename NODE>
class DistSplitTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "DistSplit" );
      label = to_string( arg->treelist_id );

      double flops = 6.0 * arg->n;
      double  mops = 6.0 * arg->n;

      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** Assume computation bound */
      cost = mops / 1E+9;
      /** "HIGH" priority */
      priority = true;
    };


    void DependencyAnalysis()
    {
      arg->DependencyAnalysis( R, this );

      if ( !arg->isleaf )
      {
        if ( arg->GetCommSize() > 1 )
        {
          assert( arg->child );
          arg->child->DependencyAnalysis( RW, this );
        }
        else
        {
          assert( arg->lchild && arg->rchild );
          arg->lchild->DependencyAnalysis( RW, this );
          arg->rchild->DependencyAnalysis( RW, this );
        }
      }
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker ) { arg->Split(); };

}; /** end class DistSplitTask */
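
/*
 *  [Editor's note; not part of the original header] DistSplitTask is the
 *  runtime wrapper around mpitree::Node::Split(): Set() records a rough
 *  flop/mop estimate used as the scheduling cost, DependencyAnalysis()
 *  registers a read on the node itself and a read-write on whichever
 *  children exist (the single distributed child when the node's local
 *  communicator has more than one rank, otherwise the two shared-memory
 *  children), and Execute() performs the actual split.
 */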


/**
 *  @brief Data and setup that are shared with all nodes.
 */
template<typename SPLITTER, typename DATATYPE>
class Setup
{
  public:

    typedef DATATYPE T;

    Setup() {};

    ~Setup() {};


    /**
     *  @brief Check whether this node contains any query, using MortonIDs.
     *         Notice that queries[] contains gids; thus, morton[]
     *         needs to be accessed using gids.
     */
    vector<size_t> ContainAny( vector<size_t> &queries, size_t target )
    {
      vector<size_t> validation( queries.size(), 0 );

      if ( !morton.size() )
      {
        printf( "Morton id was not initialized.\n" );
        exit( 1 );
      }

      for ( size_t i = 0; i < queries.size(); i ++ )
      {
        /** notice that setup->morton only contains local morton ids */
        //auto it = this->setup->morton.find( queries[ i ] );
        //if ( it != this->setup->morton.end() )
        //{
        //  if ( tree::IsMyParent( *it, this->morton ) ) validation[ i ] = 1;
        //}

        if ( MortonHelper::IsMyParent( morton[ queries[ i ] ], target ) )
          validation[ i ] = 1;
      }
      return validation;
    }; /** end ContainAny() */


    /** maximum leaf node size */
    size_t m;

    /** by default we use 4 bits = 0-15 levels */
    size_t max_depth = 15;

    /** coordinates (accessed with gids) */
    //DistData<STAR, CBLK, T> *X_cblk = NULL;
    //DistData<STAR, CIDS, T> *X      = NULL;

    /** neighbors<distance, gid> (accessed with gids) */
    DistData<STAR, CBLK, pair<T, size_t>> *NN_cblk = NULL;
    DistData<STAR, CIDS, pair<T, size_t>> *NN      = NULL;

    /** morton ids */
    vector<size_t> morton;

    /** tree splitter */
    SPLITTER splitter;

}; /** end class Setup */
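
/*
 *  [Editor's note; not part of the original header] A minimal usage sketch
 *  for Setup::ContainAny(), assuming setup.morton has already been filled
 *  (e.g., by Tree::RecursiveMorton()). The names "setup", "queries", and
 *  "node" below are illustrative only:
 *
 *    vector<size_t> queries = { 0, 17, 42 };
 *    auto hits = setup.ContainAny( queries, node->morton );
 *    // hits[ i ] == 1 iff MortonHelper::IsMyParent() reports that the node
 *    // identified by node->morton contains the point queries[ i ].
 */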


template<typename NODE>
class DistIndexPermuteTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      name = std::string( "Permutation" );
      arg = user_arg;
      // Need an accurate cost model.
      cost = 1.0;
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };
    //{
    //  arg->DependencyAnalysis( hmlp::ReadWriteType::RW, this );
    //  if ( !arg->isleaf && !arg->child )
    //  {
    //    arg->lchild->DependencyAnalysis( hmlp::ReadWriteType::R, this );
    //    arg->rchild->DependencyAnalysis( hmlp::ReadWriteType::R, this );
    //  }
    //  this->TryEnqueue();
    //};


    void Execute( Worker* user_worker )
    {
      if ( !arg->isleaf && !arg->child )
      {
        auto &gids = arg->gids;
        auto &lgids = arg->lchild->gids;
        auto &rgids = arg->rchild->gids;
        gids = lgids;
        gids.insert( gids.end(), rgids.begin(), rgids.end() );
      }
    };

}; /** end class DistIndexPermuteTask */


/**
 *  @brief Distributed tree node: a tree::Node that also carries a local
 *         MPI communicator and (at most) one distributed child.
 */
//template<typename SETUP, int N_CHILDREN, typename NODEDATA>
template<typename SETUP, typename NODEDATA>
class Node : public tree::Node<SETUP, NODEDATA>
{
  public:

    /** Deduce data type from SETUP. */
    typedef typename SETUP::T T;

    static const int N_CHILDREN = 2;

    /** Inherit all parameters from tree::Node */


    /** (Default) constructor for inner nodes (gids and n unassigned) */
    Node( SETUP *setup, size_t n, size_t l,
        Node *parent,
        unordered_map<size_t, tree::Node<SETUP, NODEDATA>*> *morton2node,
        Lock *treelock, mpi::Comm comm )
      : tree::Node<SETUP, NODEDATA>( setup, n, l,
          parent, morton2node, treelock )
    {
      /** Local communicator */
      this->comm = comm;
      /** Get MPI size and rank. */
      mpi::Comm_size( comm, &size );
      mpi::Comm_rank( comm, &rank );
    };

    /** (Default) constructor for root. */
    Node( SETUP *setup, size_t n, size_t l, vector<size_t> &gids,
        Node *parent,
        unordered_map<size_t, tree::Node<SETUP, NODEDATA>*> *morton2node,
        Lock *treelock, mpi::Comm comm )
      : Node<SETUP, NODEDATA>( setup, n, l, parent,
          morton2node, treelock, comm )
    {
      /** Notice that "gids.size() < n". */
      this->gids = gids;
    };

    /** (Default) constructor for LET nodes. */
    Node( size_t morton ) : tree::Node<SETUP, NODEDATA>( morton )
    {
    };


    //void SetupChild( class Node *child )
    //{
    //  this->kids[ 0 ] = child;
    //  this->child = child;
    //};


    /** @brief Split this->gids across the two halves of the local communicator. */
    void Split()
    {
      /** Reduce to get the total size of gids. */
      int num_points_total = 0;
      int num_points_owned = (this->gids).size();
      /** n = sum( num_points_owned ) over all MPI processes in comm. */
      mpi::Allreduce( &num_points_owned, &num_points_total, 1, MPI_SUM, comm );
      this->n = num_points_total;

      if ( child )
      {
        /** The local communicator of this node contains at least 2 processes. */
        assert( size > 1 );

        /** Invoke distributed splitter. */
        auto split = this->setup->splitter( this->gids, comm );

        /** Get partner MPI rank. */
        int partner_rank = 0;
        int sent_size = 0;
        int recv_size = 0;
        vector<size_t> &kept_gids = child->gids;
        vector<int>     sent_gids;
        vector<int>     recv_gids;

        if ( rank < size / 2 )
        {
          /** left child */
          partner_rank = rank + size / 2;
          /** MPI ranks 0:size/2-1 keep split[ 0 ] */
          kept_gids.resize( split[ 0 ].size() );
          for ( size_t i = 0; i < kept_gids.size(); i ++ )
            kept_gids[ i ] = this->gids[ split[ 0 ][ i ] ];
          /** MPI ranks 0:size/2-1 send split[ 1 ] */
          sent_gids.resize( split[ 1 ].size() );
          sent_size = sent_gids.size();
          for ( size_t i = 0; i < sent_gids.size(); i ++ )
            sent_gids[ i ] = this->gids[ split[ 1 ][ i ] ];
        }
        else
        {
          /** right child */
          partner_rank = rank - size / 2;
          /** MPI ranks size/2:size-1 keep split[ 1 ] */
          kept_gids.resize( split[ 1 ].size() );
          for ( size_t i = 0; i < kept_gids.size(); i ++ )
            kept_gids[ i ] = this->gids[ split[ 1 ][ i ] ];
          /** MPI ranks size/2:size-1 send split[ 0 ] */
          sent_gids.resize( split[ 0 ].size() );
          sent_size = sent_gids.size();
          for ( size_t i = 0; i < sent_gids.size(); i ++ )
            sent_gids[ i ] = this->gids[ split[ 0 ][ i ] ];
        }
        assert( partner_rank >= 0 );


        ///** Exchange recv_gids.size(). */
        //mpi::Sendrecv( &sent_size, 1, partner_rank, 10,
        //               &recv_size, 1, partner_rank, 10, comm, &status );

        //printf( "rank %d kept_size %lu sent_size %d recv_size %d\n",
        //    rank, kept_gids.size(), sent_size, recv_size ); fflush( stdout );

        ///** resize recv_gids */
        //recv_gids.resize( recv_size );

        ///** Exchange recv_gids.size() */
        //mpi::Sendrecv(
        //    sent_gids.data(), sent_size, MPI_INT, partner_rank, 20,
        //    recv_gids.data(), recv_size, MPI_INT, partner_rank, 20,
        //    comm, &status );

        /** Exchange gids with my partner. */
        mpi::ExchangeVector( sent_gids, partner_rank, 20,
                             recv_gids, partner_rank, 20, comm, &status );
        /** kept_gids += recv_gids. */
        for ( auto it : recv_gids ) kept_gids.push_back( it );
        //kept_gids.reserve( kept_gids.size() + recv_gids.size() );
        //for ( size_t i = 0; i < recv_gids.size(); i ++ )
        //  kept_gids.push_back( recv_gids[ i ] );

      }
      else
      {
        tree::Node<SETUP, NODEDATA>::Split();
      }
      /** Synchronize within local communicator */
      mpi::Barrier( comm );
    }; /** end Split() */


    /** @brief Support dependency analysis. */
    void DependOnChildren( Task *task )
    {
      this->DependencyAnalysis( RW, task );
      if ( size < 2 )
      {
        if ( this->lchild ) this->lchild->DependencyAnalysis( R, task );
        if ( this->rchild ) this->rchild->DependencyAnalysis( R, task );
      }
      else
      {
        if ( child ) child->DependencyAnalysis( R, task );
      }
      /** Try to enqueue if there is no dependency. */
      task->TryEnqueue();
    }; /** end DependOnChildren() */

    /** @brief Support dependency analysis (read this node, write its children). */
    void DependOnParent( Task *task )
    {
      this->DependencyAnalysis( R, task );
      if ( size < 2 )
      {
        if ( this->lchild ) this->lchild->DependencyAnalysis( RW, task );
        if ( this->rchild ) this->rchild->DependencyAnalysis( RW, task );
      }
      else
      {
        if ( child ) child->DependencyAnalysis( RW, task );
      }
      /** Try to enqueue if there is no dependency. */
      task->TryEnqueue();
    }; /** end DependOnParent() */


    /** Preserve for debugging. */
    void Print() {};

    /** Return local MPI communicator. */
    mpi::Comm GetComm() { return comm; };
    /** Return local MPI size. */
    int GetCommSize() { return size; };
    /** Return local MPI rank. */
    int GetCommRank() { return rank; };

    /** Distributed tree nodes only have one child. */
    Node *child = NULL;

  private:

    /** Initialize with all processes. */
    mpi::Comm comm = MPI_COMM_WORLD;
    /** MPI status. */
    mpi::Status status;
    /** Subcommunicator size. */
    int size = 1;
    /** Subcommunicator rank. */
    int rank = 0;

}; /** end class Node */
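
/*
 *  [Editor's note; not part of the original header] In the distributed case
 *  Split() halves the node's communicator by rank. For example, with
 *  size == 4, ranks 0-1 keep split[ 0 ] and ranks 2-3 keep split[ 1 ];
 *  rank r exchanges the gids it is not keeping with partner rank
 *  r + size/2 (or r - size/2) via mpi::ExchangeVector(), then appends what
 *  it receives to child->gids. With size == 1 the call falls back to the
 *  shared-memory tree::Node::Split().
 */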


/**
 *  @brief This distributed tree inherits the shared-memory tree
 *         with additional MPI data structures and function calls.
 */
template<class SETUP, class NODEDATA>
class Tree : public tree::Tree<SETUP, NODEDATA>,
             public mpi::MPIObject
{
  public:

    typedef typename SETUP::T T;

    /**
     *  Inherit parameters n, m, and depth; local treelists and morton2node map.
     *
     *  Explanation for the morton2node map in the distributed tree:
     *
     *  morton2node has type map<size_t, tree::Node>, but it actually contains
     *  "three" different kinds of tree nodes.
     *
     *  1. Local tree nodes (exactly in type tree::Node)
     *
     *  2. Distributed tree nodes (in type mpitree::Node)
     *
     *  3. Local essential nodes (in type tree::Node with essential data)
     */

    /**
     *  Define local tree node type as NODE. Notice that all pointers in the
     *  interaction lists and morton2node map will be in this type.
     */
    typedef tree::Node<SETUP, NODEDATA> NODE;

    /** Define distributed tree node type as MPINODE. */
    typedef Node<SETUP, NODEDATA> MPINODE;

    /**
     *  Distributed tree nodes in top-down order. Notice that
     *  mpitreelists.back() is the root of the local tree,
     *  i.e. mpitreelists.back() == treelist.front();
     */
    vector<MPINODE*> mpitreelists;

    /** (Default) Tree constructor */
    Tree( mpi::Comm comm ) : tree::Tree<SETUP, NODEDATA>::Tree(),
                             mpi::MPIObject( comm )
    {
      //this->comm = comm;
      /** Get size and rank */
      //mpi::Comm_size( comm, &size );
      //mpi::Comm_rank( comm, &rank );
      /** Create a ReadWrite object per rank */
      //NearRecvFrom.resize( size );
      NearRecvFrom.resize( this->GetCommSize() );
      //FarRecvFrom.resize( size );
      FarRecvFrom.resize( this->GetCommSize() );
    };

    /** (Default) Tree destructor. */
    ~Tree()
    {
      //printf( "~Tree() distributed, mpitreelists.size() %lu\n",
      //    mpitreelists.size() ); fflush( stdout );
      /**
       *  We do not free the last tree node; it will be deleted by
       *  hmlp::tree::Tree::~Tree().
       */
      if ( mpitreelists.size() )
      {
        for ( size_t i = 0; i < mpitreelists.size() - 1; i ++ )
          if ( mpitreelists[ i ] ) delete mpitreelists[ i ];
        mpitreelists.clear();
      }
      //printf( "end ~Tree() distributed\n" ); fflush( stdout );
    };


    /**
     *  @brief Free all tree nodes including local tree nodes,
     *         distributed tree nodes, and LET nodes.
     */
    void CleanUp()
    {
      /** Free all local tree nodes */
      if ( this->treelist.size() )
      {
        for ( size_t i = 0; i < this->treelist.size(); i ++ )
          if ( this->treelist[ i ] ) delete this->treelist[ i ];
      }
      this->treelist.clear();

      /** Free all distributed tree nodes */
      if ( mpitreelists.size() )
      {
        for ( size_t i = 0; i < mpitreelists.size() - 1; i ++ )
          if ( mpitreelists[ i ] ) delete mpitreelists[ i ];
      }
      mpitreelists.clear();

    }; /** end CleanUp() */


    /** @brief Allocate all distributed tree nodes. */
    void AllocateNodes( vector<size_t> &gids )
    {
      /** Decide the depth of the distributed tree according to mpi size. */
      //auto mycomm  = comm;
      auto mycomm  = this->GetComm();
      //int mysize  = size;
      int mysize  = this->GetCommSize();
      //int myrank  = rank;
      int myrank  = this->GetCommRank();
      int mycolor = 0;
      size_t mylevel = 0;

      /** Allocate root( setup, n = 0, l = 0, parent = NULL ). */
      auto *root = new MPINODE( &(this->setup),
          this->n, mylevel, gids, NULL,
          &(this->morton2node), &(this->lock), mycomm );

      /** Push root to the mpi treelist. */
      mpitreelists.push_back( root );

      /** Recursively split the communicator. */
      while ( mysize > 1 )
      {
        mpi::Comm childcomm;

        /** Increase level. */
        mylevel += 1;
        /** Left color = 0, right color = 1. */
        mycolor = ( myrank < mysize / 2 ) ? 0 : 1;
        /** Split and assign the subcommunicators for children. */
        ierr = mpi::Comm_split( mycomm, mycolor, myrank, &(childcomm) );
        /** Update mycomm, mysize, and myrank to proceed to the next iteration. */
        mycomm = childcomm;
        mpi::Comm_size( mycomm, &mysize );
        mpi::Comm_rank( mycomm, &myrank );

        /** Create the child node. */
        auto *parent = mpitreelists.back();
        auto *child  = new MPINODE( &(this->setup),
            (size_t)0, mylevel, parent,
            &(this->morton2node), &(this->lock), mycomm );

        /** Create the sibling in type NODE but not MPINODE. */
        child->sibling = new NODE( (size_t)0 ); // Node morton is computed later.
        /** Setup parent's children */
        //parent->SetupChild( child );
        parent->kids[ 0 ] = child;
        parent->child = child;
        /** Push to the mpi treelist */
        mpitreelists.push_back( child );
      }
      /** Global synchronization. */
      this->Barrier();

      /** Allocate local tree nodes. */
      auto *local_tree_root = mpitreelists.back();
      tree::Tree<SETUP, NODEDATA>::AllocateNodes( local_tree_root );

    }; /** end AllocateNodes() */


    vector<size_t> GetPermutation()
    {
      vector<size_t> perm_loc, perm_glb;
      perm_loc = tree::Tree<SETUP, NODEDATA>::GetPermutation();
      mpi::GatherVector( perm_loc, perm_glb, 0, this->GetComm() );

      //if ( rank == 0 )
      //{
      //  /** Sanity check using an 0:N-1 table. */
      //  vector<bool> Table( this->n, false );
      //  for ( size_t i = 0; i < perm_glb.size(); i ++ )
      //    Table[ perm_glb[ i ] ] = true;
      //  for ( size_t i = 0; i < Table.size(); i ++ ) assert( Table[ i ] );
      //}

      return perm_glb;
    }; /** end GetPermutation() */


    /** Perform approximate k-nearest neighbor search. */
    template<typename KNNTASK>
    DistData<STAR, CBLK, pair<T, size_t>>
    AllNearestNeighbor( size_t n_tree, size_t n, size_t k,
      pair<T, size_t> initNN, KNNTASK &dummy )
    {
      mpi::PrintProgress( "[BEG] NeighborSearch ...", this->GetComm() );

      /** Get the problem size from setup->K->row(). */
      this->n = n;
      /** k-by-N, column major. */
      DistData<STAR, CBLK, pair<T, size_t>> NN( k, n, initNN, this->GetComm() );
      /** Use leaf size = 4 * k (but at least 512). */
      this->setup.m = 4 * k;
      if ( this->setup.m < 512 ) this->setup.m = 512;
      this->m = this->setup.m;


      ///** Local problem size (assuming Round-Robin) */
      ////num_points_owned = ( n - 1 ) / size + 1;
      //num_points_owned = ( n - 1 ) / this->GetCommSize() + 1;

      ///** Edge case */
      //if ( n % this->GetCommSize() )
      //{
      //  //if ( rank >= ( n % size ) ) num_points_owned -= 1;
      //  if ( this->GetCommRank() >= ( n % this->GetCommSize() ) )
      //    num_points_owned -= 1;
      //}

      /** Local problem size (assuming Round-Robin) */
      num_points_owned = n / this->GetCommSize();
      /** Edge case */
      if ( this->GetCommRank() < ( n % this->GetCommSize() ) )
         num_points_owned += 1;


      /** Initial gids distribution (assuming Round-Robin) */
      vector<size_t> gids( num_points_owned, 0 );
      for ( size_t i = 0; i < num_points_owned; i ++ )
        //gids[ i ] = i * size + rank;
        gids[ i ] = i * this->GetCommSize() + this->GetCommRank();

      /** Allocate distributed tree nodes in advance. */
      AllocateNodes( gids );


      /** Metric tree partitioning. */
      DistSplitTask<MPINODE> mpisplittask;
      tree::SplitTask<NODE>  seqsplittask;
      for ( size_t t = 0; t < n_tree; t ++ )
      {
        DistTraverseDown( mpisplittask );
        LocaTraverseDown( seqsplittask );
        ExecuteAllTasks();

        /** Query neighbors computed in CIDS distribution. */
        DistData<STAR, CIDS, pair<T, size_t>> Q_cids( k, this->n, this->treelist[ 0 ]->gids, initNN, this->GetComm() );
        /** Pass in neighbor pointer. */
        this->setup.NN = &Q_cids;
        LocaTraverseLeafs( dummy );
        ExecuteAllTasks();

        /** Queries computed in CBLK distribution */
        DistData<STAR, CBLK, pair<T, size_t>> Q_cblk( k, this->n, this->GetComm() );
        /** Redistribute from CIDS to CBLK */
        Q_cblk = Q_cids;
        /** Merge Q_cblk into NN (sort and remove duplication) */
        assert( Q_cblk.col_owned() == NN.col_owned() );
        MergeNeighbors( k, NN.col_owned(), NN, Q_cblk );
      }


//      /** Metric tree partitioning. */
//      DistSplitTask<MPINODE> mpisplittask;
//      tree::SplitTask<NODE>  seqsplittask;
//      DependencyCleanUp();
//      DistTraverseDown( mpisplittask );
//      LocaTraverseDown( seqsplittask );
//      ExecuteAllTasks();
//
//
//      for ( size_t t = 0; t < n_tree; t ++ )
//      {
//        this->Barrier();
//        //if ( this->GetCommRank() == 0 ) printf( "Iteration #%lu\n", t );
//
//        /** Query neighbors computed in CIDS distribution.  */
//        DistData<STAR, CIDS, pair<T, size_t>> Q_cids( k, this->n,
//            this->treelist[ 0 ]->gids, initNN, this->GetComm() );
//        /** Pass in neighbor pointer. */
//        this->setup.NN = &Q_cids;
//        /** Overlap */
//        if ( t != n_tree - 1 )
//        {
//          //DependencyCleanUp();
//          DistTraverseDown( mpisplittask );
//          ExecuteAllTasks();
//        }
//        mpi::PrintProgress( "[MID] Here ...", this->GetComm() );
//        DependencyCleanUp();
//        LocaTraverseLeafs( dummy );
//        LocaTraverseDown( seqsplittask );
//        ExecuteAllTasks();
//        mpi::PrintProgress( "[MID] Here 22...", this->GetComm() );
//
//        if ( t == 0 )
//        {
//          /** Redistribute from CIDS to CBLK */
//          NN = Q_cids;
//        }
//        else
//        {
//          /** Queries computed in CBLK distribution */
//          DistData<STAR, CBLK, pair<T, size_t>> Q_cblk( k, this->n, this->GetComm() );
//          /** Redistribute from CIDS to CBLK */
//          Q_cblk = Q_cids;
//          /** Merge Q_cblk into NN (sort and remove duplication) */
//          assert( Q_cblk.col_owned() == NN.col_owned() );
//          MergeNeighbors( k, NN.col_owned(), NN, Q_cblk );
//        }
//
//        //double mer_time = omp_get_wtime() - beg;
//
//        //if ( rank == 0 )
//        //printf( "%lfs %lfs %lfs\n", mpi_time, seq_time, mer_time ); fflush( stdout );
//      }

      /** Check for illegal values. */
      for ( auto &neig : NN )
      {
        if ( neig.second < 0 || neig.second >= NN.col() )
        {
          printf( "Illegal neighbor gid %lu\n", neig.second );
          break;
        }
      }

      mpi::PrintProgress( "[END] NeighborSearch ...", this->GetComm() );
      return NN;
    }; /** end AllNearestNeighbor() */
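
/*
 *  [Editor's note; not part of the original header] A hedged calling sketch
 *  for AllNearestNeighbor(); the setup, node-data, and task types below are
 *  placeholders for whatever the GOFMM driver actually instantiates:
 *
 *    mpitree::Tree<SomeSetup, SomeNodeData> tree( MPI_COMM_WORLD );
 *    SomeKNNTask knntask;
 *    auto NN = tree.AllNearestNeighbor( 10, n, k,
 *        pair<float, size_t>( numeric_limits<float>::max(), n ), knntask );
 *    // NN is a k-by-n DistData<STAR, CBLK, ...>: column j holds the current
 *    // k approximate neighbors of point j in cyclic (CBLK) distribution.
 */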


    /** @brief Partition n points using a distributed binary tree. */
    void TreePartition()
    {
      mpi::PrintProgress( "[BEG] TreePartitioning ...", this->GetComm() );

      /** Set up total problem size n and leaf node size m. */
      this->n = this->setup.ProblemSize();
      this->m = this->setup.LeafNodeSize();

      /** Initial gids distribution (assuming Round-Robin). */
      //for ( size_t i = rank; i < this->n; i += size )
      for ( size_t i = this->GetCommRank(); i < this->n; i += this->GetCommSize() )
        this->global_indices.push_back( i );
      /** Local problem size (assuming Round-Robin). */
      num_points_owned = this->global_indices.size();
      /** Allocate distributed tree nodes in advance. */
      AllocateNodes( this->global_indices );


      DependencyCleanUp();


      DistSplitTask<MPINODE> mpiSPLITtask;
      tree::SplitTask<NODE> seqSPLITtask;
      DistTraverseDown( mpiSPLITtask );
      LocaTraverseDown( seqSPLITtask );
      ExecuteAllTasks();


      tree::IndexPermuteTask<NODE> seqINDXtask;
      LocaTraverseUp( seqINDXtask );
      DistIndexPermuteTask<MPINODE> mpiINDXtask;
      DistTraverseUp( mpiINDXtask );
      ExecuteAllTasks();

      //printf( "rank %d finish split\n", rank ); fflush( stdout );


      /** Allocate space for point MortonID. */
      (this->setup).morton.resize( this->n );

      /** Compute MortonID for both distributed and local trees. */
      RecursiveMorton( mpitreelists[ 0 ], MortonHelper::Root() );

      /** Clean up the map. */
      this->morton2node.clear();

      /** Construct morton2node map for the local tree. */
      for ( auto node : this->treelist ) this->morton2node[ node->morton ] = node;

      /** Construct morton2node map for the distributed tree. */
      for ( auto node : this->mpitreelists )
      {
        this->morton2node[ node->morton ] = node;
        auto *sibling = node->sibling;
        if ( node->l ) this->morton2node[ sibling->morton ] = sibling;
      }

      this->Barrier();
      mpi::PrintProgress( "[END] TreePartitioning ...", this->GetComm() );
    }; /** end TreePartition() */
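
/*
 *  [Editor's note; not part of the original header] TreePartition() thus
 *  proceeds in four stages: (1) split the points top-down with
 *  DistSplitTask / tree::SplitTask, (2) merge children's gids bottom-up with
 *  the index-permute tasks, (3) assign MortonIDs to every node and point via
 *  RecursiveMorton(), and (4) rebuild the morton2node map for both the local
 *  and the distributed tree nodes (plus their siblings).
 */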


    /** Assign MortonID to each node recursively. */
    void RecursiveMorton( MPINODE *node, MortonHelper::Recursor r )
    {
      /** MPI Support. */
      int comm_size = this->GetCommSize();
      int comm_rank = this->GetCommRank();
      int node_size = node->GetCommSize();
      int node_rank = node->GetCommRank();

      /** Set the node MortonID. */
      node->morton = MortonHelper::MortonID( r );
      /** Set my sibling's MortonID. */
      if ( node->sibling )
        node->sibling->morton = MortonHelper::SiblingMortonID( r );

      if ( node_size < 2 )
      {
        /** Compute MortonID recursively for the local tree. */
        tree::Tree<SETUP, NODEDATA>::RecursiveMorton( node, r );
        /** Prepare to exchange all <gid,MortonID> pairs. */
        auto &gids = this->treelist[ 0 ]->gids;
        vector<int> recv_size( comm_size, 0 );
        vector<int> recv_disp( comm_size, 0 );
        vector<pair<size_t, size_t>> send_pairs;
        vector<pair<size_t, size_t>> recv_pairs( this->n );

        /** Gather pairs I own. */
        for ( auto it : gids )
        {
          send_pairs.push_back(
              pair<size_t, size_t>( it, this->setup.morton[ it ]) );
        }

        /** Exchange send_pairs.size(). */
        int send_size = send_pairs.size();
        mpi::Allgather( &send_size, 1, recv_size.data(), 1, this->GetComm() );
        /** Compute displacement for Allgatherv. */
        for ( size_t p = 1; p < comm_size; p ++ )
        {
          recv_disp[ p ] = recv_disp[ p - 1 ] + recv_size[ p - 1 ];
        }
        /** Sanity check for the total size. */
        size_t total_gids = 0;
        for ( size_t p = 0; p < comm_size; p ++ )
        {
          total_gids += recv_size[ p ];
        }
        assert( total_gids == this->n );
        /** Exchange all pairs. */
        mpi::Allgatherv( send_pairs.data(), send_size,
            recv_pairs.data(), recv_size.data(), recv_disp.data(), this->GetComm() );
        /** Fill in all MortonIDs. */
        for ( auto it : recv_pairs ) this->setup.morton[ it.first ] = it.second;
      }
      else
      {
        if ( node_rank < node_size / 2 )
        {
          RecursiveMorton( node->child, MortonHelper::RecurLeft( r ) );
        }
        else
        {
          RecursiveMorton( node->child, MortonHelper::RecurRight( r ) );
        }
      }
    }; /** end RecursiveMorton() */
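
/*
 *  [Editor's note; not part of the original header] Once the recursion
 *  reaches a single-rank communicator, each rank knows the MortonIDs of
 *  only the gids it owns. The Allgather/Allgatherv pair above broadcasts
 *  every <gid, MortonID> pair so that setup.morton becomes a complete
 *  gid-to-MortonID table on every rank, which Morton2Rank() and
 *  Index2Rank() below rely on.
 */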


    Data<int> CheckAllInteractions()
    {
      /** Get the total depth of the tree. */
      int total_depth = this->treelist.back()->l;
      /** Number of total leaf nodes. */
      int num_leafs = 1 << total_depth;
      /** Create a 2^l-by-2^l table to check all interactions. */
      Data<int> A( num_leafs, num_leafs, 0 );
      Data<int> B( num_leafs, num_leafs, 0 );
      /** Now traverse all tree nodes (excluding the root). */
      for ( int t = 1; t < this->treelist.size(); t ++ )
      {
        auto *node = this->treelist[ t ];
        ///** Loop over all near interactions. */
        //for ( auto it : node->NNNearNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets(   it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}
        ///** Loop over all far interactions. */
        //for ( auto it : node->NNFarNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets(   it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}

        for ( int p = 0; p < this->GetCommSize(); p ++ )
        {
          if ( node->isleaf )
          {
            for ( auto & it : node->DistNear[ p ] )
            {
              auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
              auto J = MortonHelper::Morton2Offsets(   it.first, total_depth );
              for ( auto i : I ) for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                A( i, j ) += 1;
              }
            }
          }
          for ( auto & it : node->DistFar[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets(   it.first, total_depth );
            for ( auto i : I ) for ( auto j : J )
            {
              assert( i < num_leafs && j < num_leafs );
              A( i, j ) += 1;
            }
          }
        }
      }

      for ( auto *node : mpitreelists )
      {
        ///** Loop over all near interactions. */
        //for ( auto it : node->NNNearNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets(   it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}
        ///** Loop over all far interactions. */
        //for ( auto it : node->NNFarNodeMortonIDs )
        //{
        //  auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
        //  auto J = MortonHelper::Morton2Offsets(   it, total_depth );
        //  for ( auto i : I ) for ( auto j : J ) A( i, j ) += 1;
        //}
        for ( int p = 0; p < this->GetCommSize(); p ++ )
        {
          if ( node->isleaf )
          {
            for ( auto & it : node->DistNear[ p ] )
            {
              auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
              auto J = MortonHelper::Morton2Offsets(   it.first, total_depth );
              for ( auto i : I ) for ( auto j : J )
              {
                assert( i < num_leafs && j < num_leafs );
                A( i, j ) += 1;
              }
            }
          }
          for ( auto & it : node->DistFar[ p ] )
          {
            auto I = MortonHelper::Morton2Offsets( node->morton, total_depth );
            auto J = MortonHelper::Morton2Offsets(   it.first, total_depth );
            for ( auto i : I ) for ( auto j : J )
            {
              assert( i < num_leafs && j < num_leafs );
              A( i, j ) += 1;
            }
          }
        }
      }

      /** Reduce */
      mpi::Reduce( A.data(), B.data(), A.size(), MPI_SUM, 0, this->GetComm() );

      if ( this->GetCommRank() == 0 )
      {
        for ( size_t i = 0; i < num_leafs; i ++ )
        {
          for ( size_t j = 0; j < num_leafs; j ++ ) printf( "%d", B( i, j ) );
          printf( "\n" );
        }
      }

      return B;
    }; /** end CheckAllInteractions() */


    /** @brief Map a MortonID to the MPI rank that owns it. */
    int Morton2Rank( size_t it )
    {
      return MortonHelper::Morton2Rank( it, this->GetCommSize() );
    }; /** end Morton2Rank() */

    int Index2Rank( size_t gid )
    {
       return Morton2Rank( this->setup.morton[ gid ] );
    }; /** end Index2Rank() */


    template<typename TASK, typename... Args>
    void LocaTraverseUp( TASK &dummy, Args&... args )
    {
      /** Contains at least one tree node. */
      assert( this->treelist.size() );

      /**
       *  Traverse the local tree without the root.
       *
       *  IMPORTANT: the local root is an alias of the distributed leaf node.
       *  IMPORTANT: here l must be int; size_t would wrap around.
       */

      //printf( "depth %lu\n", this->depth ); fflush( stdout );

      for ( int l = this->depth; l >= 1; l -- )
      {
        size_t n_nodes = 1 << l;
        auto level_beg = this->treelist.begin() + n_nodes - 1;

        /** loop over each node at level-l */
        for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
        {
          auto *node = *(level_beg + node_ind);
          RecuTaskSubmit( node, dummy, args... );
        }
      }
    }; /** end LocaTraverseUp() */


    template<typename TASK, typename... Args>
    void DistTraverseUp( TASK &dummy, Args&... args )
    {
      MPINODE *node = mpitreelists.back();
      while ( node )
      {
        if ( this->DoOutOfOrder() ) RecuTaskSubmit(  node, dummy, args... );
        else                        RecuTaskExecute( node, dummy, args... );
        /** move to its parent */
        node = (MPINODE*)node->parent;
      }
    }; /** end DistTraverseUp() */


    template<typename TASK, typename... Args>
    void LocaTraverseDown( TASK &dummy, Args&... args )
    {
      /** Contains at least one tree node. */
      assert( this->treelist.size() );

      /**
       *  Traverse the local tree without the root.
       *
       *  IMPORTANT: the local root is an alias of the distributed leaf node.
       *  IMPORTANT: here l must be int; size_t would wrap around.
       */
      for ( int l = 1; l <= this->depth; l ++ )
      {
        size_t n_nodes = 1 << l;
        auto level_beg = this->treelist.begin() + n_nodes - 1;

        for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
        {
          auto *node = *(level_beg + node_ind);
          RecuTaskSubmit( node, dummy, args... );
        }
      }
    }; /** end LocaTraverseDown() */


    template<typename TASK, typename... Args>
    void DistTraverseDown( TASK &dummy, Args&... args )
    {
      auto *node = mpitreelists.front();
      while ( node )
      {
        //printf( "now at level %lu\n", node->l ); fflush( stdout );
        if ( this->DoOutOfOrder() ) RecuTaskSubmit(  node, dummy, args... );
        else                        RecuTaskExecute( node, dummy, args... );
        //printf( "RecuTaskSubmit at level %lu\n", node->l ); fflush( stdout );

        /**
         *  Move to its child.
         *  IMPORTANT: here we need to cast the pointer back to mpitree::Node*.
         */
        node = node->child;
      }
    }; /** end DistTraverseDown() */


    template<typename TASK, typename... Args>
    void LocaTraverseLeafs( TASK &dummy, Args&... args )
    {
      /** Contains at least one tree node. */
      assert( this->treelist.size() );

      int n_nodes = 1 << this->depth;
      auto level_beg = this->treelist.begin() + n_nodes - 1;

      for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
      {
        auto *node = *(level_beg + node_ind);
        RecuTaskSubmit( node, dummy, args... );
      }
    }; /** end LocaTraverseLeafs() */


    /**
     *  @brief For unordered traversal, we just call local
     *         downward traversal.
     */
    template<typename TASK, typename... Args>
    void LocaTraverseUnOrdered( TASK &dummy, Args&... args )
    {
      LocaTraverseDown( dummy, args... );
    }; /** end LocaTraverseUnOrdered() */


    /**
     *  @brief For unordered traversal, we just call distributed
     *         downward traversal.
     */
    template<typename TASK, typename... Args>
    void DistTraverseUnOrdered( TASK &dummy, Args&... args )
    {
      DistTraverseDown( dummy, args... );
    }; /** end DistTraverseUnOrdered() */


    void DependencyCleanUp()
    {
      for ( auto node : mpitreelists ) node->DependencyCleanUp();
      //for ( size_t i = 0; i < mpitreelists.size(); i ++ )
      //{
      //  mpitreelists[ i ]->DependencyCleanUp();
      //}

      tree::Tree<SETUP, NODEDATA>::DependencyCleanUp();

      for ( auto p : NearRecvFrom ) p.DependencyCleanUp();
      for ( auto p :  FarRecvFrom ) p.DependencyCleanUp();

      /** TODO: also clean up the LET nodes. */

    }; /** end DependencyCleanUp() */


    /** @brief Run all submitted tasks, synchronize, then reset dependencies. */
    void ExecuteAllTasks()
    {
      hmlp_run();
      this->Barrier();
      DependencyCleanUp();
    }; /** end ExecuteAllTasks() */

    /** @brief Describe dependencies on the near interactions sent to rank p. */
    void DependOnNearInteractions( int p, Task *task )
    {
      /** Describe the dependencies of rank p. */
      for ( auto it : NearSentToRank[ p ] )
      {
        auto *node = this->morton2node[ it ];
        node->DependencyAnalysis( R, task );
      }
      /** Try to enqueue if there is no dependency. */
      task->TryEnqueue();
    }; /** end DependOnNearInteractions() */

    /** @brief Describe dependencies on the far interactions sent to rank p. */
    void DependOnFarInteractions( int p, Task *task )
    {
      /** Describe the dependencies of rank p. */
      for ( auto it : FarSentToRank[ p ] )
      {
        auto *node = this->morton2node[ it ];
        node->DependencyAnalysis( R, task );
      }
      /** Try to enqueue if there is no dependency. */
      task->TryEnqueue();
    }; /** end DependOnFarInteractions() */


    /**
     *  Interaction lists per rank
     *
     *  NearSentToRank[ p ]   contains all near node MortonIDs sent   to rank p.
     *  NearRecvFromRank[ p ] contains all near node MortonIDs recv from rank p.
     *  NearRecvFromRank[ p ][ morton ] = offset in the received vector.
     */
    vector<vector<size_t>>   NearSentToRank;
    vector<map<size_t, int>> NearRecvFromRank;
    vector<vector<size_t>>   FarSentToRank;
    vector<map<size_t, int>> FarRecvFromRank;

    vector<ReadWrite> NearRecvFrom;
    vector<ReadWrite> FarRecvFrom;

  private:

    /** Error code returned by MPI calls on the global communicator. */
    int ierr = 0;
    /** Number of locally owned points; n = sum( num_points_owned ) over all MPI processes. */
    size_t num_points_owned = 0;

}; /** end class Tree */


}; /** end namespace mpitree */
}; /** end namespace hmlp */

#endif /** define MPITREE_HPP */