GCC Code Coverage Report
Directory: . Exec Total Coverage
File: gofmm/gofmm_mpi.hpp Lines: 0 1319 0.0 %
Date: 2019-01-14 Branches: 0 7158 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
#ifndef GOFMM_MPI_HPP
23
#define GOFMM_MPI_HPP
24
25
/** Inherit most of the classes from shared-memory GOFMM. */
26
#include <gofmm.hpp>
27
/** Use distributed metric trees. */
28
#include <tree_mpi.hpp>
29
#include <igofmm_mpi.hpp>
30
/** Use distributed matrices inspired by the Elemental notation. */
31
//#include <DistData.hpp>
32
/** Use STL and HMLP namespaces. */
33
using namespace std;
34
using namespace hmlp;
35
36
37
namespace hmlp
38
{
39
namespace mpigofmm
40
{
41
42
43
///**
44
// *  @biref This class does not have to inherit DistData, but it have to
45
// *         inherit DistVirtualMatrix<T>
46
// *
47
// */
48
//template<typename T>
49
//class DistSPDMatrix : public DistData<STAR, CBLK, T>
50
//{
51
//  public:
52
//
53
//    DistSPDMatrix( size_t m, size_t n, mpi::Comm comm ) :
54
//      DistData<STAR, CBLK, T>( m, n, comm )
55
//    {
56
//    };
57
//
58
//
59
//    /** ESSENTIAL: this is an abstract function  */
60
//    virtual T operator()( size_t i, size_t j, mpi::Comm comm )
61
//    {
62
//      T Kij = 0;
63
//
64
//      /** MPI */
65
//      int size, rank;
66
//      hmlp::mpi::Comm_size( comm, &size );
67
//      hmlp::mpi::Comm_rank( comm, &rank );
68
//
69
//      std::vector<std::vector<size_t>> sendrids( size );
70
//      std::vector<std::vector<size_t>> recvrids( size );
71
//      std::vector<std::vector<size_t>> sendcids( size );
72
//      std::vector<std::vector<size_t>> recvcids( size );
73
//
74
//      /** request Kij from rank ( j % size ) */
75
//      sendrids[ i % size ].push_back( i );
76
//      sendcids[ j % size ].push_back( j );
77
//
78
//      /** exchange ids */
79
//      mpi::AlltoallVector( sendrids, recvrids, comm );
80
//      mpi::AlltoallVector( sendcids, recvcids, comm );
81
//
82
//      /** allocate buffer for data */
83
//      std::vector<std::vector<T>> senddata( size );
84
//      std::vector<std::vector<T>> recvdata( size );
85
//
86
//      /** fetch subrows */
87
//      for ( size_t p = 0; p < size; p ++ )
88
//      {
89
//        assert( recvrids[ p ].size() == recvcids[ p ].size() );
90
//        for ( size_t j = 0; j < recvcids[ p ].size(); j ++ )
91
//        {
92
//          size_t rid = recvrids[ p ][ j ];
93
//          size_t cid = recvcids[ p ][ j ];
94
//          senddata[ p ].push_back( (*this)( rid, cid ) );
95
//        }
96
//      }
97
//
98
//      /** exchange data */
99
//      mpi::AlltoallVector( senddata, recvdata, comm );
100
//
101
//      for ( size_t p = 0; p < size; p ++ )
102
//      {
103
//        assert( recvdata[ p ].size() <= 1 );
104
//        if ( recvdata[ p ] ) Kij = recvdata[ p ][ 0 ];
105
//      }
106
//
107
//      return Kij;
108
//    };
109
//
110
//
111
//    /** ESSENTIAL: return a submatrix */
112
//    virtual hmlp::Data<T> operator()
113
//		( std::vector<size_t> &imap, std::vector<size_t> &jmap, hmlp::mpi::Comm comm )
114
//    {
115
//      hmlp::Data<T> KIJ( imap.size(), jmap.size() );
116
//
117
//      /** MPI */
118
//      int size, rank;
119
//      hmlp::mpi::Comm_size( comm, &size );
120
//      hmlp::mpi::Comm_rank( comm, &rank );
121
//
122
//
123
//
124
//      std::vector<std::vector<size_t>> jmapcids( size );
125
//
126
//      std::vector<std::vector<size_t>> sendrids( size );
127
//      std::vector<std::vector<size_t>> recvrids( size );
128
//      std::vector<std::vector<size_t>> sendcids( size );
129
//      std::vector<std::vector<size_t>> recvcids( size );
130
//
131
//      /** request KIJ from rank ( j % size ) */
132
//      for ( size_t j = 0; j < jmap.size(); j ++ )
133
//      {
134
//        size_t cid = jmap[ j ];
135
//        sendcids[ cid % size ].push_back( cid );
136
//        jmapcids[ cid % size ].push_back(   j );
137
//      }
138
//
139
//      for ( size_t p = 0; p < size; p ++ )
140
//      {
141
//        if ( sendcids[ p ].size() ) sendrids[ p ] = imap;
142
//      }
143
//
144
//      /** exchange ids */
145
//      mpi::AlltoallVector( sendrids, recvrids, comm );
146
//      mpi::AlltoallVector( sendcids, recvcids, comm );
147
//
148
//      /** allocate buffer for data */
149
//      std::vector<hmlp::Data<T>> senddata( size );
150
//      std::vector<hmlp::Data<T>> recvdata( size );
151
//
152
//      /** fetch submatrix */
153
//      for ( size_t p = 0; p < size; p ++ )
154
//      {
155
//        if ( recvcids[ p ].size() && recvrids[ p ].size() )
156
//        {
157
//          senddata[ p ] = (*this)( recvrids[ p ], recvcids[ p ] );
158
//        }
159
//      }
160
//
161
//      /** exchange data */
162
//      mpi::AlltoallVector( senddata, recvdata, comm );
163
//
164
//      /** merging data */
165
//      for ( size_t p = 0; j < size; p ++ )
166
//      {
167
//        assert( recvdata[ p ].size() == imap.size() * recvcids[ p ].size() );
168
//        recvdata[ p ].resize( imap.size(), recvcids[ p ].size() );
169
//        for ( size_t j = 0; j < recvcids[ p ]; i ++ )
170
//        {
171
//          for ( size_t i = 0; i < imap.size(); i ++ )
172
//          {
173
//            KIJ( i, jmapcids[ p ][ j ] ) = recvdata[ p ]( i, j );
174
//          }
175
//        }
176
//      };
177
//
178
//      return KIJ;
179
//    };
180
//
181
//
182
//
183
//
184
//
185
//    virtual hmlp::Data<T> operator()
186
//		( std::vector<int> &imap, std::vector<int> &jmap, hmlp::mpi::Comm comm )
187
//    {
188
//      printf( "operator() not implemented yet\n" );
189
//      exit( 1 );
190
//    };
191
//
192
//
193
//
194
//    /** overload operator */
195
//
196
//
197
//  private:
198
//
199
//}; /** end class DistSPDMatrix */
200
//
201
//
202
203
204
/**
205
 *  @brief These are data that shared by the whole local tree.
206
 *         Distributed setup inherits mpitree::Setup.
207
 */
208
template<typename SPDMATRIX, typename SPLITTER, typename T>
209
class Setup : public mpitree::Setup<SPLITTER, T>,
210
              public gofmm::Configuration<T>
211
{
212
  public:
213
214
    /** Shallow copy from the config. */
215
    void FromConfiguration( gofmm::Configuration<T> &config,
216
        SPDMATRIX &K, SPLITTER &splitter,
217
        DistData<STAR, CBLK, pair<T, size_t>>* NN_cblk )
218
    {
219
      this->CopyFrom( config );
220
      this->K = &K;
221
      this->splitter = splitter;
222
      this->NN_cblk = NN_cblk;
223
    };
224
225
    /** The SPDMATRIX (accessed with gids: dense, CSC or OOC) */
226
    SPDMATRIX *K = NULL;
227
228
    /** rhs-by-n, all weights and potentials. */
229
    Data<T> *w = NULL;
230
    Data<T> *u = NULL;
231
232
    /** buffer space, either dimension needs to be n  */
233
    Data<T> *input = NULL;
234
    Data<T> *output = NULL;
235
236
    /** regularization */
237
    T lambda = 0.0;
238
239
    /** whether the matrix is symmetric */
240
    //bool issymmetric = true;
241
242
    /** use ULV or Sherman-Morrison-Woodbury */
243
    bool do_ulv_factorization = true;
244
245
246
  private:
247
248
}; /** end class Setup */
249
250
251
252
253
254
/**
255
 *  @brief This task creates an hierarchical tree view for
256
 *         weights<RIDS> and potentials<RIDS>.
257
 */
258
template<typename NODE>
259
class DistTreeViewTask : public Task
260
{
261
  public:
262
263
    NODE *arg = NULL;
264
265
    void Set( NODE *user_arg )
266
    {
267
      arg = user_arg;
268
      name = string( "TreeView" );
269
      label = to_string( arg->treelist_id );
270
      cost = 1.0;
271
    };
272
273
    /** Preorder dependencies (with a single source node) */
274
    void DependencyAnalysis() { arg->DependOnParent( this ); };
275
276
    void Execute( Worker* user_worker )
277
    {
278
      auto *node   = arg;
279
280
      /** w and u can be Data<T> or DistData<RIDS,STAR,T> */
281
      auto &w = *(node->setup->w);
282
      auto &u = *(node->setup->u);
283
284
      /** get the matrix view of this tree node */
285
      auto &U = node->data.u_view;
286
      auto &W = node->data.w_view;
287
288
      /** Both w and u are column-majored, thus nontranspose. */
289
      U.Set( u );
290
      W.Set( w );
291
292
      /** Create sub matrix views for local nodes. */
293
      if ( !node->isleaf && !node->child )
294
      {
295
        assert( node->lchild && node->rchild );
296
        auto &UL = node->lchild->data.u_view;
297
        auto &UR = node->rchild->data.u_view;
298
        auto &WL = node->lchild->data.w_view;
299
        auto &WR = node->rchild->data.w_view;
300
        /**
301
         *  U = [ UL;    W = [ WL;
302
         *        UR; ]        WR; ]
303
         */
304
        U.Partition2x1( UL,
305
                        UR, node->lchild->n, TOP );
306
        W.Partition2x1( WL,
307
                        WR, node->lchild->n, TOP );
308
      }
309
    };
310
311
}; /** end class DistTreeViewTask */
312
313
314
315
316
317
318
319
320
321
/** @brief Split values into two halfs accroding to the median. */
322
template<typename T>
323
vector<vector<size_t>> DistMedianSplit( vector<T> &values, mpi::Comm comm )
324
{
325
  int n = 0;
326
  int num_points_owned = values.size();
327
  /** n = sum( num_points_owned ) over all MPI processes in comm */
328
  mpi::Allreduce( &num_points_owned, &n, 1, MPI_SUM, comm );
329
  T  median = combinatorics::Select( n / 2, values, comm );
330
331
  vector<vector<size_t>> split( 2 );
332
  vector<size_t> middle;
333
334
  if ( n == 0 ) return split;
335
336
  for ( size_t i = 0; i < values.size(); i ++ )
337
  {
338
    auto v = values[ i ];
339
    if ( std::fabs( v - median ) < 1E-6 ) middle.push_back( i );
340
    else if ( v < median ) split[ 0 ].push_back( i );
341
    else split[ 1 ].push_back( i );
342
  }
343
344
  int nmid = 0;
345
  int nlhs = 0;
346
  int nrhs = 0;
347
  int num_mid_owned = middle.size();
348
  int num_lhs_owned = split[ 0 ].size();
349
  int num_rhs_owned = split[ 1 ].size();
350
351
  /** nmid = sum( num_mid_owned ) over all MPI processes in comm. */
352
  mpi::Allreduce( &num_mid_owned, &nmid, 1, MPI_SUM, comm );
353
  mpi::Allreduce( &num_lhs_owned, &nlhs, 1, MPI_SUM, comm );
354
  mpi::Allreduce( &num_rhs_owned, &nrhs, 1, MPI_SUM, comm );
355
356
  /** Assign points in the middle to left or right. */
357
  if ( nmid )
358
  {
359
    int nlhs_required, nrhs_required;
360
361
    if ( nlhs > nrhs )
362
    {
363
      nlhs_required = ( n - 1 ) / 2 + 1 - nlhs;
364
      nrhs_required = nmid - nlhs_required;
365
    }
366
    else
367
    {
368
      nrhs_required = ( n - 1 ) / 2 + 1 - nrhs;
369
      nlhs_required = nmid - nrhs_required;
370
    }
371
372
    assert( nlhs_required >= 0 && nrhs_required >= 0 );
373
374
    /** Now decide the portion */
375
    double lhs_ratio = ( (double)nlhs_required ) / nmid;
376
    int nlhs_required_owned = num_mid_owned * lhs_ratio;
377
    int nrhs_required_owned = num_mid_owned - nlhs_required_owned;
378
379
    //printf( "rank %d [ %d %d ] [ %d %d ]\n",
380
    //  global_rank,
381
    //  nlhs_required_owned, nlhs_required,
382
    //  nrhs_required_owned, nrhs_required ); fflush( stdout );
383
384
    assert( nlhs_required_owned >= 0 && nrhs_required_owned >= 0 );
385
386
    for ( size_t i = 0; i < middle.size(); i ++ )
387
    {
388
      if ( i < nlhs_required_owned )
389
        split[ 0 ].push_back( middle[ i ] );
390
      else
391
        split[ 1 ].push_back( middle[ i ] );
392
    }
393
  }
394
395
  return split;
396
}; /** end MedianSplit() */
397
398
399
400
401
/**
402
 *  @brief This the main splitter used to build the Spd-Askit tree.
403
 *         First compute the approximate center using subsamples.
404
 *         Then find the two most far away points to do the
405
 *         projection.
406
 */
407
template<typename SPDMATRIX, int N_SPLIT, typename T>
408
struct centersplit : public gofmm::centersplit<SPDMATRIX, N_SPLIT, T>
409
{
410
411
  centersplit() : gofmm::centersplit<SPDMATRIX, N_SPLIT, T>() {};
412
413
  centersplit( SPDMATRIX& K ) : gofmm::centersplit<SPDMATRIX, N_SPLIT, T>( K ) {};
414
415
  /** Shared-memory operator. */
416
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
417
  {
418
    return gofmm::centersplit<SPDMATRIX, N_SPLIT, T>::operator() ( gids );
419
  };
420
421
  /** Distributed operator. */
422
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
423
  {
424
    /** All assertions */
425
    assert( N_SPLIT == 2 );
426
    assert( this->Kptr );
427
428
    /** MPI Support. */
429
    int size; mpi::Comm_size( comm, &size );
430
    int rank; mpi::Comm_rank( comm, &rank );
431
    auto &K = *(this->Kptr);
432
433
    /** */
434
    vector<T> temp( gids.size(), 0.0 );
435
436
    /** Collecting column samples of K. */
437
    auto column_samples = combinatorics::SampleWithoutReplacement(
438
        this->n_centroid_samples, gids );
439
440
    /** Bcast column_samples from rank 0. */
441
    mpi::Bcast( column_samples.data(), column_samples.size(), 0, comm );
442
    K.BcastIndices( column_samples, 0, comm );
443
444
    /** Compute all pairwise distances. */
445
    auto DIC = K.Distances( this->metric, gids, column_samples );
446
447
    /** Zero out the temporary buffer. */
448
    for ( auto & it : temp ) it = 0;
449
450
    /** Accumulate distances to the temporary buffer. */
451
    for ( size_t j = 0; j < DIC.col(); j ++ )
452
      for ( size_t i = 0; i < DIC.row(); i ++ )
453
        temp[ i ] += DIC( i, j );
454
455
    /** Find the f2c (far most to center) from points owned */
456
    auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );
457
458
    /** Create a pair for MPI Allreduce */
459
    mpi::NumberIntPair<T> local_max_pair, max_pair;
460
    local_max_pair.val = temp[ idf2c ];
461
    local_max_pair.key = rank;
462
463
    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
464
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
465
466
    /** Boardcast gidf2c from the MPI process which has the max_pair */
467
    int gidf2c = gids[ idf2c ];
468
    mpi::Bcast( &gidf2c, 1, MPI_INT, max_pair.key, comm );
469
470
471
    //printf( "rank %d val %E key %d; global val %E key %d\n",
472
    //    rank, local_max_pair.val, local_max_pair.key,
473
    //    max_pair.val, max_pair.key ); fflush( stdout );
474
    //printf( "rank %d gidf2c %d\n", rank, gidf2c  ); fflush( stdout );
475
476
    /** Collecting KIP and kpp */
477
    vector<size_t> P( 1, gidf2c );
478
    K.BcastIndices( P, max_pair.key, comm );
479
480
    /** Compute all pairwise distances. */
481
    auto DIP = K.Distances( this->metric, gids, P );
482
483
    /** Find f2f (far most to far most) from owned points */
484
    auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );
485
486
    /** Create a pair for MPI Allreduce */
487
    local_max_pair.val = DIP[ idf2f ];
488
    local_max_pair.key = rank;
489
490
    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
491
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
492
493
    /** boardcast gidf2f from the MPI process which has the max_pair */
494
    int gidf2f = gids[ idf2f ];
495
    mpi::Bcast( &gidf2f, 1, MPI_INT, max_pair.key, comm );
496
497
    //printf( "rank %d val %E key %d; global val %E key %d\n",
498
    //    rank, local_max_pair.val, local_max_pair.key,
499
    //    max_pair.val, max_pair.key ); fflush( stdout );
500
    //printf( "rank %d gidf2f %d\n", rank, gidf2f  ); fflush( stdout );
501
502
    /** Collecting KIQ and kqq */
503
    vector<size_t> Q( 1, gidf2f );
504
    K.BcastIndices( Q, max_pair.key, comm );
505
506
    /** Compute all pairwise distances. */
507
    auto DIQ = K.Distances( this->metric, gids, P );
508
509
    /** We use relative distances (dip - diq) for clustering. */
510
    for ( size_t i = 0; i < temp.size(); i ++ )
511
      temp[ i ] = DIP[ i ] - DIQ[ i ];
512
513
    /** Split gids into two clusters using median split. */
514
    auto split = DistMedianSplit( temp, comm );
515
516
    /** Perform P2P redistribution. */
517
    mpi::Status status;
518
    vector<size_t> sent_gids;
519
    int partner = ( rank + size / 2 ) % size;
520
    if ( rank < size / 2 )
521
    {
522
      for ( auto it : split[ 1 ] )
523
        sent_gids.push_back( gids[ it ] );
524
      K.SendIndices( sent_gids, partner, comm );
525
      K.RecvIndices( partner, comm, &status );
526
    }
527
    else
528
    {
529
      for ( auto it : split[ 0 ] )
530
        sent_gids.push_back( gids[ it ] );
531
      K.RecvIndices( partner, comm, &status );
532
      K.SendIndices( sent_gids, partner, comm );
533
    }
534
535
    return split;
536
  };
537
538
539
}; /** end struct centersplit */
540
541
542
543
544
545
template<typename SPDMATRIX, int N_SPLIT, typename T>
546
struct randomsplit : public gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>
547
{
548
549
  randomsplit() : gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>() {};
550
551
  randomsplit( SPDMATRIX& K ) : gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>( K ) {};
552
553
  /** Shared-memory operator. */
554
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
555
  {
556
    return gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>::operator() ( gids );
557
  };
558
559
  /** Distributed operator. */
560
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
561
  {
562
    /** All assertions */
563
    assert( N_SPLIT == 2 );
564
    assert( this->Kptr );
565
566
    /** Declaration */
567
    int size, rank, global_rank, global_size;
568
    mpi::Comm_size( comm, &size );
569
    mpi::Comm_rank( comm, &rank );
570
    mpi::Comm_rank( MPI_COMM_WORLD, &global_rank );
571
    mpi::Comm_size( MPI_COMM_WORLD, &global_size );
572
    SPDMATRIX &K = *(this->Kptr);
573
    //vector<vector<size_t>> split( N_SPLIT );
574
575
    if ( size == global_size )
576
    {
577
      for ( size_t i = 0; i < gids.size(); i ++ )
578
        assert( gids[ i ] == i * size + rank );
579
    }
580
581
582
583
584
    /** Reduce to get the total size of gids. */
585
    int n = 0;
586
    int num_points_owned = gids.size();
587
    vector<T> temp( gids.size(), 0.0 );
588
589
    /** n = sum( num_points_owned ) over all MPI processes in comm */
590
    mpi::Allreduce( &num_points_owned, &n, 1, MPI_INT, MPI_SUM, comm );
591
592
    /** Early return */
593
    //if ( n == 0 ) return split;
594
595
    /** Randomly select two points p and q */
596
    size_t gidf2c, gidf2f;
597
    if ( gids.size() )
598
    {
599
      gidf2c = gids[ std::rand() % gids.size() ];
600
      gidf2f = gids[ std::rand() % gids.size() ];
601
    }
602
603
    /** Create a pair <gids.size(), rank> for MPI Allreduce */
604
    mpi::NumberIntPair<T> local_max_pair, max_pair;
605
    local_max_pair.val = gids.size();
606
    local_max_pair.key = rank;
607
608
    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
609
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
610
611
    /** Bcast gidf2c from the rank that has the most gids */
612
    mpi::Bcast( &gidf2c, 1, max_pair.key, comm );
613
    vector<size_t> P( 1, gidf2c );
614
    K.BcastIndices( P, max_pair.key, comm );
615
616
    /** Choose the second MPI rank */
617
    if ( rank == max_pair.key ) local_max_pair.val = 0;
618
619
    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
620
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
621
622
    /** Bcast gidf2c from the rank that has the most gids */
623
    mpi::Bcast( &gidf2f, 1, max_pair.key, comm );
624
    vector<size_t> Q( 1, gidf2f );
625
    K.BcastIndices( Q, max_pair.key, comm );
626
627
628
    auto DIP = K.Distances( this->metric, gids, P );
629
    auto DIQ = K.Distances( this->metric, gids, Q );
630
631
    /** We use relative distances (dip - diq) for clustering. */
632
    for ( size_t i = 0; i < temp.size(); i ++ )
633
      temp[ i ] = DIP[ i ] - DIQ[ i ];
634
635
    /** Split gids into two clusters using median split. */
636
    auto split = DistMedianSplit( temp, comm );
637
638
    /** Perform P2P redistribution. */
639
    mpi::Status status;
640
    vector<size_t> sent_gids;
641
    int partner = ( rank + size / 2 ) % size;
642
    if ( rank < size / 2 )
643
    {
644
      for ( auto it : split[ 1 ] )
645
        sent_gids.push_back( gids[ it ] );
646
      K.SendIndices( sent_gids, partner, comm );
647
      K.RecvIndices( partner, comm, &status );
648
    }
649
    else
650
    {
651
      for ( auto it : split[ 0 ] )
652
        sent_gids.push_back( gids[ it ] );
653
      K.RecvIndices( partner, comm, &status );
654
      K.SendIndices( sent_gids, partner, comm );
655
    }
656
657
    return split;
658
  };
659
660
661
}; /** end struct randomsplit */
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
/**
684
 *  @brief Compute skeleton weights.
685
 *
686
 *
687
 */
688
template<typename NODE>
689
void DistUpdateWeights( NODE *node )
690
{
691
  /** Derive type T from NODE. */
692
  using T = typename NODE::T;
693
  /** MPI Support. */
694
  mpi::Status status;
695
  auto comm = node->GetComm();
696
  int  size = node->GetCommSize();
697
  int  rank = node->GetCommRank();
698
699
  /** Early return if this is the root or there is no skeleton. */
700
  if ( !node->parent || !node->data.isskel ) return;
701
702
  if ( size < 2 )
703
  {
704
    /** This is the root of the local tree. */
705
    gofmm::UpdateWeights( node );
706
  }
707
  else
708
  {
709
    /** Gather shared data and create reference. */
710
    auto &w = *node->setup->w;
711
    size_t nrhs = w.col();
712
713
    /** gather per node data and create reference */
714
    auto &data   = node->data;
715
    auto &proj   = data.proj;
716
    auto &w_skel = data.w_skel;
717
718
    /** This is the corresponding MPI rank. */
719
    if ( rank == 0 )
720
    {
721
      size_t s  = proj.row();
722
			size_t sl = node->child->data.skels.size();
723
			size_t sr = proj.col() - sl;
724
      /** w_skel is s-by-nrhs, initial values are not important. */
725
      w_skel.resize( s, nrhs );
726
      /** Create matrix views. */
727
      View<T> P( false,   proj ), PL, PR;
728
      View<T> W( false, w_skel ), WL( false, node->child->data.w_skel );
729
      /** P = [ PL, PR ] */
730
      P.Partition1x2( PL, PR, sl, LEFT );
731
      /** W  = PL * WL */
732
      gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );
733
734
      Data<T> w_skel_sib;
735
      mpi::ExchangeVector( w_skel, size / 2, 0, w_skel_sib, size / 2, 0, comm, &status );
736
			/** Reduce from my sibling. */
737
      #pragma omp parallel for
738
			for ( size_t i = 0; i < w_skel.size(); i ++ )
739
				w_skel[ i ] += w_skel_sib[ i ];
740
    }
741
742
    /** The rank that holds the skeleton weight of the right child. */
743
    if ( rank == size / 2 )
744
    {
745
      size_t s  = proj.row();
746
			size_t sr = node->child->data.skels.size();
747
			size_t sl = proj.col() - sr;
748
      /** w_skel is s-by-nrhs, initial values are not important. */
749
      w_skel.resize( s, nrhs );
750
      /** Create a transpose view proj_v */
751
      View<T> P( false,   proj ), PL, PR;
752
      View<T> W( false, w_skel ), WR( false, node->child->data.w_skel );
753
      /** P = [ PL, PR ] */
754
      P.Partition1x2( PL, PR, sl, LEFT );
755
      /** W += PR * WR */
756
      gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)0.0, W );
757
758
759
			Data<T> w_skel_sib;
760
			mpi::ExchangeVector( w_skel, 0, 0, w_skel_sib, 0, 0, comm, &status );
761
			w_skel.clear();
762
    }
763
  }
764
}; /** end DistUpdateWeights() */
765
766
767
768
769
/**
770
 *  @brief Notice that NODE here is MPITree::Node.
771
 */
772
template<typename NODE, typename T>
773
class DistUpdateWeightsTask : public Task
774
{
775
  public:
776
777
    NODE *arg = NULL;
778
779
    void Set( NODE *user_arg )
780
    {
781
      arg = user_arg;
782
      name = string( "DistN2S" );
783
      label = to_string( arg->treelist_id );
784
785
      /** Compute FLOPS and MOPS */
786
      double flops = 0.0, mops = 0.0;
787
      auto &gids = arg->gids;
788
      auto &skels = arg->data.skels;
789
      auto &w = *arg->setup->w;
790
791
			if ( !arg->child )
792
			{
793
        if ( arg->isleaf )
794
        {
795
          auto m = skels.size();
796
          auto n = w.col();
797
          auto k = gids.size();
798
          flops = 2.0 * m * n * k;
799
          mops = 2.0 * ( m * n + m * k + k * n );
800
        }
801
        else
802
        {
803
          auto &lskels = arg->lchild->data.skels;
804
          auto &rskels = arg->rchild->data.skels;
805
          auto m = skels.size();
806
          auto n = w.col();
807
          auto k = lskels.size() + rskels.size();
808
          flops = 2.0 * m * n * k;
809
          mops  = 2.0 * ( m * n + m * k + k * n );
810
        }
811
			}
812
			else
813
			{
814
				if ( arg->GetCommRank() == 0 )
815
				{
816
          auto &lskels = arg->child->data.skels;
817
          auto m = skels.size();
818
          auto n = w.col();
819
          auto k = lskels.size();
820
          flops = 2.0 * m * n * k;
821
          mops = 2.0 * ( m * n + m * k + k * n );
822
				}
823
				if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
824
				{
825
          auto &rskels = arg->child->data.skels;
826
          auto m = skels.size();
827
          auto n = w.col();
828
          auto k = rskels.size();
829
          flops = 2.0 * m * n * k;
830
          mops = 2.0 * ( m * n + m * k + k * n );
831
				}
832
			}
833
834
      /** Setup the event */
835
      event.Set( label + name, flops, mops );
836
      /** Assume computation bound */
837
      cost = flops / 1E+9;
838
      /** "HIGH" priority (critical path) */
839
      priority = true;
840
    };
841
842
    void DependencyAnalysis() { arg->DependOnChildren( this ); };
843
844
    void Execute( Worker* user_worker ) { DistUpdateWeights( arg ); };
845
846
}; /** end class DistUpdateWeightsTask */
847
848
849
850
851
/**
852
 *
853
 */
854
//template<bool NNPRUNE, typename NODE, typename T>
855
//class DistSkeletonsToSkeletonsTask : public Task
856
//{
857
//  public:
858
//
859
//    NODE *arg = NULL;
860
//
861
//    void Set( NODE *user_arg )
862
//    {
863
//      arg = user_arg;
864
//      name = string( "DistS2S" );
865
//      label = to_string( arg->treelist_id );
866
//      /** compute flops and mops */
867
//      double flops = 0.0, mops = 0.0;
868
//      auto &w = *arg->setup->w;
869
//      size_t m = arg->data.skels.size();
870
//      size_t n = w.col();
871
//
872
//      auto *FarNodes = &arg->FarNodes;
873
//      if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
874
//
875
//      for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
876
//      {
877
//        size_t k = (*it)->data.skels.size();
878
//        flops += 2.0 * m * n * k;
879
//        mops  += m * k; // cost of Kab
880
//        mops  += 2.0 * ( m * n + n * k + k * n );
881
//      }
882
//
883
//      /** setup the event */
884
//      event.Set( label + name, flops, mops );
885
//
886
//      /** assume computation bound */
887
//      cost = flops / 1E+9;
888
//
889
//      /** "LOW" priority */
890
//      priority = false;
891
//    };
892
//
893
//
894
//
895
//    void DependencyAnalysis()
896
//    {
897
//      for ( auto p : arg->data.FarDependents )
898
//        hmlp_msg_dependency_analysis( 306, p, R, this );
899
//
900
//      auto *FarNodes = &arg->FarNodes;
901
//      if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
902
//      for ( auto it : *FarNodes ) it->DependencyAnalysis( R, this );
903
//
904
//      arg->DependencyAnalysis( RW, this );
905
//      this->TryEnqueue();
906
//    };
907
//
908
//    /**
909
//     *  @brief Notice that S2S depends on all Far interactions, which
910
//     *         may include local tree nodes or let nodes.
911
//     *         For HSS case, the only Far interaction is the sibling.
912
//     *         Skeleton weight of the sibling will always be exchanged
913
//     *         by default in N2S. Thus, currently we do not need
914
//     *         a distributed S2S, because the skeleton weight is already
915
//     *         in place.
916
//     *
917
//     */
918
//    void Execute( Worker* user_worker )
919
//    {
920
//      auto *node = arg;
921
//      /** MPI Support. */
922
//      auto comm = node->GetComm();
923
//      auto size = node->GetCommSize();
924
//      auto rank = node->GetCommRank();
925
//
926
//      if ( size < 2 )
927
//      {
928
//        gofmm::SkeletonsToSkeletons<NNPRUNE, NODE, T>( node );
929
//      }
930
//      else
931
//      {
932
//        /** Only 0th rank (owner) will execute this task. */
933
//        if ( rank == 0 ) gofmm::SkeletonsToSkeletons<NNPRUNE, NODE, T>( node );
934
//      }
935
//    };
936
//
937
//}; /** end class DistSkeletonsToSkeletonsTask */
938
//
939
940
template<typename NODE, typename LETNODE, typename T>
941
class S2STask2 : public Task
942
{
943
  public:
944
945
    NODE *arg = NULL;
946
947
    vector<LETNODE*> Sources;
948
949
    int p = 0;
950
951
    Lock *lock = NULL;
952
953
    int *num_arrived_subtasks;
954
955
    void Set( NODE *user_arg, vector<LETNODE*> user_src, int user_p, Lock *user_lock,
956
        int *user_num_arrived_subtasks )
957
    {
958
      arg = user_arg;
959
      Sources = user_src;
960
      p = user_p;
961
      lock = user_lock;
962
      num_arrived_subtasks = user_num_arrived_subtasks;
963
      name = string( "S2S" );
964
      label = to_string( arg->treelist_id );
965
966
      /** Compute FLOPS and MOPS */
967
      double flops = 0.0, mops = 0.0;
968
      size_t nrhs = arg->setup->w->col();
969
      size_t m = arg->data.skels.size();
970
      for ( auto src : Sources )
971
      {
972
        size_t k = src->data.skels.size();
973
        flops += 2 * m * k * nrhs;
974
        mops  += 2 * ( m * k + ( m + k ) * nrhs );
975
        flops += 2 * m * nrhs;
976
        flops += m * k * ( 2 * 18 + 100 );
977
      }
978
      /** Setup the event */
979
      event.Set( label + name, flops, mops );
980
      /** Assume computation bound */
981
      cost = flops / 1E+9;
982
      /** Assume computation bound */
983
      if ( arg->treelist_id == 0 ) priority = true;
984
    };
985
986
    void DependencyAnalysis()
987
    {
988
      if ( p == hmlp_get_mpi_rank() )
989
      {
990
        for ( auto src : Sources ) src->DependencyAnalysis( R, this );
991
      }
992
      else hmlp_msg_dependency_analysis( 306, p, R, this );
993
      this->TryEnqueue();
994
    };
995
996
    void Execute( Worker* user_worker )
997
    {
998
      auto *node = arg;
999
      if ( !node->parent || !node->data.isskel ) return;
1000
      size_t nrhs = node->setup->w->col();
1001
      auto &K = *node->setup->K;
1002
      auto &I = node->data.skels;
1003
1004
      /** Temporary buffer */
1005
      Data<T> u( I.size(), nrhs, 0.0 );
1006
1007
      for ( auto src : Sources )
1008
      {
1009
        auto &J = src->data.skels;
1010
        auto &w = src->data.w_skel;
1011
        bool is_cached = true;
1012
1013
        auto &KIJ = node->DistFar[ p ][ src->morton ];
1014
        if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
1015
        {
1016
          //printf( "KIJ %lu %lu I %lu J %lu\n", KIJ.row(), KIJ.col(), I.size(), J.size() );
1017
          KIJ = K( I, J );
1018
          is_cached = false;
1019
        }
1020
1021
        assert( w.col() == nrhs );
1022
        assert( w.row() == J.size() );
1023
        //xgemm
1024
        //(
1025
        //  "N", "N", u.row(), u.col(), w.row(),
1026
        //  1.0, KIJ.data(), KIJ.row(),
1027
        //         w.data(),   w.row(),
1028
        //  1.0,   u.data(),   u.row()
1029
        //);
1030
        gemm::xgemm( (T)1.0, KIJ, w, (T)1.0, u );
1031
1032
        /** Free KIJ, if !is_cached. */
1033
        if ( !is_cached )
1034
        {
1035
          KIJ.resize( 0, 0 );
1036
          KIJ.shrink_to_fit();
1037
        }
1038
      }
1039
1040
      lock->Acquire();
1041
      {
1042
        auto &u_skel = node->data.u_skel;
1043
        for ( int i = 0; i < u.size(); i ++ )
1044
          u_skel[ i ] += u[ i ];
1045
      }
1046
      lock->Release();
1047
      #pragma omp atomic update
1048
      *num_arrived_subtasks += 1;
1049
    };
1050
};
1051
1052
template<typename NODE, typename LETNODE, typename T>
1053
class S2SReduceTask2 : public Task
1054
{
1055
  public:
1056
1057
    NODE *arg = NULL;
1058
1059
    vector<S2STask2<NODE, LETNODE, T>*> subtasks;
1060
1061
    Lock lock;
1062
1063
    int num_arrived_subtasks = 0;
1064
1065
    const size_t batch_size = 2;
1066
1067
    void Set( NODE *user_arg )
1068
    {
1069
      arg = user_arg;
1070
      name = string( "S2SR" );
1071
      label = to_string( arg->treelist_id );
1072
1073
      /** Reset u_skel */
1074
      if ( arg )
1075
      {
1076
        size_t nrhs = arg->setup->w->col();
1077
        auto &I = arg->data.skels;
1078
        arg->data.u_skel.resize( 0, 0 );
1079
        arg->data.u_skel.resize( I.size(), nrhs, 0 );
1080
      }
1081
1082
      /** Create subtasks */
1083
      for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
1084
      {
1085
        vector<LETNODE*> Sources;
1086
        for ( auto &it : arg->DistFar[ p ] )
1087
        {
1088
          Sources.push_back( (*arg->morton2node)[ it.first ] );
1089
          if ( Sources.size() == batch_size )
1090
          {
1091
            subtasks.push_back( new S2STask2<NODE, LETNODE, T>() );
1092
            subtasks.back()->Submit();
1093
            subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1094
            subtasks.back()->DependencyAnalysis();
1095
            Sources.clear();
1096
          }
1097
        }
1098
        if ( Sources.size() )
1099
        {
1100
          subtasks.push_back( new S2STask2<NODE, LETNODE, T>() );
1101
          subtasks.back()->Submit();
1102
          subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1103
          subtasks.back()->DependencyAnalysis();
1104
          Sources.clear();
1105
        }
1106
      }
1107
      /** Compute FLOPS and MOPS. */
1108
      double flops = 0, mops = 0;
1109
      /** Setup the event */
1110
      event.Set( label + name, flops, mops );
1111
      /** Assume computation bound */
1112
      priority = true;
1113
    };
1114
1115
    void DependencyAnalysis()
1116
    {
1117
      for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
1118
      arg->DependencyAnalysis( RW, this );
1119
      this->TryEnqueue();
1120
    };
1121
1122
    void Execute( Worker* user_worker )
1123
    {
1124
      /** Place holder */
1125
      assert( num_arrived_subtasks == subtasks.size() );
1126
    };
1127
};
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
template<bool NNPRUNE, typename NODE, typename T>
1149
void DistSkeletonsToNodes( NODE *node )
1150
{
1151
  /** MPI Support. */
1152
  auto comm = node->GetComm();
1153
  auto size = node->GetCommSize();
1154
  auto rank = node->GetCommRank();
1155
  mpi::Status status;
1156
1157
  /** gather shared data and create reference */
1158
  auto &K = *node->setup->K;
1159
  auto &w = *node->setup->w;
1160
1161
1162
  size_t nrhs = w.col();
1163
1164
1165
  /** Early return if this is the root or has no skeleton. */
1166
  if ( !node->parent || !node->data.isskel ) return;
1167
1168
  if ( size < 2 )
1169
  {
1170
    /** Call the shared-memory implementation. */
1171
    gofmm::SkeletonsToNodes( node );
1172
  }
1173
  else
1174
  {
1175
    auto &data = node->data;
1176
    auto &proj = data.proj;
1177
    auto &u_skel = data.u_skel;
1178
1179
    if ( rank == 0 )
1180
    {
1181
			size_t sl = node->child->data.skels.size();
1182
			size_t sr = proj.col() - sl;
1183
      /** Send u_skel to my sibling. */
1184
      mpi::SendVector( u_skel, size / 2, 0, comm );
1185
      /** Create a transpose matrix view for proj. */
1186
      View<T> P(  true,   proj ), PL, PR;
1187
      View<T> U( false, u_skel ), UL( false, node->child->data.u_skel );
1188
      /** P' = [ PL, PR ]' */
1189
      P.Partition2x1( PL,
1190
                      PR, sl, TOP );
1191
      /** UL += PL' * U */
1192
      gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
1193
    }
1194
1195
    /**  */
1196
    if ( rank == size / 2 )
1197
    {
1198
      size_t s  = proj.row();
1199
			size_t sr = node->child->data.skels.size();
1200
      size_t sl = proj.col() - sr;
1201
      /** Receive u_skel from my sibling. */
1202
      mpi::RecvVector( u_skel, 0, 0, comm, &status );
1203
			u_skel.resize( s, nrhs );
1204
      /** create a transpose view proj_v */
1205
      View<T> P(  true,   proj ), PL, PR;
1206
      View<T> U( false, u_skel ), UR( false, node->child->data.u_skel );
1207
      /** P' = [ PL, PR ]' */
1208
      P.Partition2x1( PL,
1209
                      PR, sl, TOP );
1210
      /** UR += PR' * U */
1211
      gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );
1212
    }
1213
  }
1214
}; /** end DistSkeletonsToNodes() */
1215
1216
1217
1218
1219
1220
template<bool NNPRUNE, typename NODE, typename T>
1221
class DistSkeletonsToNodesTask : public Task
1222
{
1223
  public:
1224
1225
    NODE *arg;
1226
1227
    void Set( NODE *user_arg )
1228
    {
1229
      arg = user_arg;
1230
      name = string( "PS2N" );
1231
      label = to_string( arg->l );
1232
1233
      double flops = 0.0, mops = 0.0;
1234
      auto &gids = arg->gids;
1235
      auto &skels = arg->data.skels;
1236
      auto &w = *arg->setup->w;
1237
1238
			if ( !arg->child )
1239
			{
1240
        if ( arg->isleaf )
1241
        {
1242
          auto m = skels.size();
1243
          auto n = w.col();
1244
          auto k = gids.size();
1245
          flops = 2.0 * m * n * k;
1246
          mops = 2.0 * ( m * n + m * k + k * n );
1247
        }
1248
        else
1249
        {
1250
          auto &lskels = arg->lchild->data.skels;
1251
          auto &rskels = arg->rchild->data.skels;
1252
          auto m = skels.size();
1253
          auto n = w.col();
1254
          auto k = lskels.size() + rskels.size();
1255
          flops = 2.0 * m * n * k;
1256
          mops  = 2.0 * ( m * n + m * k + k * n );
1257
        }
1258
			}
1259
			else
1260
			{
1261
				if ( arg->GetCommRank() == 0 )
1262
				{
1263
          auto &lskels = arg->child->data.skels;
1264
          auto m = skels.size();
1265
          auto n = w.col();
1266
          auto k = lskels.size();
1267
          flops = 2.0 * m * n * k;
1268
          mops = 2.0 * ( m * n + m * k + k * n );
1269
				}
1270
				if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
1271
				{
1272
          auto &rskels = arg->child->data.skels;
1273
          auto m = skels.size();
1274
          auto n = w.col();
1275
          auto k = rskels.size();
1276
          flops = 2.0 * m * n * k;
1277
          mops = 2.0 * ( m * n + m * k + k * n );
1278
				}
1279
			}
1280
1281
      /** Setup the event */
1282
      event.Set( label + name, flops, mops );
1283
      /** Asuume computation bound */
1284
      cost = flops / 1E+9;
1285
      /** "HIGH" priority (critical path) */
1286
      priority = true;
1287
    };
1288
1289
    void DependencyAnalysis() { arg->DependOnParent( this ); };
1290
1291
    void Execute( Worker* user_worker ) { DistSkeletonsToNodes<NNPRUNE, NODE, T>( arg ); };
1292
1293
}; /** end class DistSkeletonsToNodesTask */
1294
1295
1296
1297
template<typename NODE, typename T>
1298
class L2LTask2 : public Task
1299
{
1300
  public:
1301
1302
    NODE *arg = NULL;
1303
1304
    /** A list of source node pointers. */
1305
    vector<NODE*> Sources;
1306
1307
    int p = 0;
1308
1309
    /** Write lock */
1310
    Lock *lock = NULL;
1311
1312
    int *num_arrived_subtasks;
1313
1314
    void Set( NODE *user_arg, vector<NODE*> user_src, int user_p, Lock *user_lock,
1315
        int* user_num_arrived_subtasks )
1316
    {
1317
      arg = user_arg;
1318
      Sources = user_src;
1319
      p = user_p;
1320
      lock = user_lock;
1321
      num_arrived_subtasks = user_num_arrived_subtasks;
1322
      name = string( "L2L" );
1323
      label = to_string( arg->treelist_id );
1324
1325
      /** Compute FLOPS and MOPS. */
1326
      double flops = 0.0, mops = 0.0;
1327
      size_t nrhs = arg->setup->w->col();
1328
      size_t m = arg->gids.size();
1329
      for ( auto src : Sources )
1330
      {
1331
        size_t k = src->gids.size();
1332
        flops += 2 * m * k * nrhs;
1333
        mops  += 2 * ( m * k + ( m + k ) * nrhs );
1334
        flops += 2 * m * nrhs;
1335
        flops += m * k * ( 2 * 18 + 100 );
1336
      }
1337
      /** Setup the event */
1338
      event.Set( label + name, flops, mops );
1339
      /** Assume computation bound */
1340
      cost = flops / 1E+9;
1341
      /** "LOW" priority */
1342
      priority = false;
1343
    };
1344
1345
    void DependencyAnalysis()
1346
    {
1347
      /** If p is a distributed process, then depends on the message. */
1348
      if ( p != hmlp_get_mpi_rank() )
1349
        hmlp_msg_dependency_analysis( 300, p, R, this );
1350
      this->TryEnqueue();
1351
    };
1352
1353
    void Execute( Worker* user_worker )
1354
    {
1355
      auto *node = arg;
1356
      size_t nrhs = node->setup->w->col();
1357
      auto &K = *node->setup->K;
1358
      auto &I = node->gids;
1359
1360
      double beg = omp_get_wtime();
1361
      /** Temporary buffer */
1362
      Data<T> u( I.size(), nrhs, 0.0 );
1363
      size_t k;
1364
1365
      for ( auto src : Sources )
1366
      {
1367
        /** Get W view of this treenode. (available for non-LET nodes) */
1368
        View<T> &W = src->data.w_view;
1369
        Data<T> &w = src->data.w_leaf;
1370
1371
        bool is_cached = true;
1372
        auto &J = src->gids;
1373
        auto &KIJ = node->DistNear[ p ][ src->morton ];
1374
        if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
1375
        {
1376
          KIJ = K( I, J );
1377
          is_cached = false;
1378
        }
1379
1380
        if ( W.col() == nrhs && W.row() == J.size() )
1381
        {
1382
          k += W.row();
1383
          xgemm
1384
          (
1385
            "N", "N", u.row(), u.col(), W.row(),
1386
            1.0, KIJ.data(), KIJ.row(),
1387
                   W.data(),   W.ld(),
1388
            1.0,   u.data(),   u.row()
1389
          );
1390
        }
1391
        else
1392
        {
1393
          k += w.row();
1394
          xgemm
1395
          (
1396
            "N", "N", u.row(), u.col(), w.row(),
1397
            1.0, KIJ.data(), KIJ.row(),
1398
                   w.data(),   w.row(),
1399
            1.0,   u.data(),   u.row()
1400
          );
1401
        }
1402
1403
        /** Free KIJ, if !is_cached. */
1404
        if ( !is_cached )
1405
        {
1406
          KIJ.resize( 0, 0 );
1407
          KIJ.shrink_to_fit();
1408
        }
1409
      }
1410
1411
      double lock_beg = omp_get_wtime();
1412
      lock->Acquire();
1413
      {
1414
        /** Get U view of this treenode. */
1415
        View<T> &U = node->data.u_view;
1416
        for ( int j = 0; j < u.col(); j ++ )
1417
          for ( int i = 0; i < u.row(); i ++ )
1418
            U( i, j ) += u( i, j );
1419
      }
1420
      lock->Release();
1421
      double lock_time = omp_get_wtime() - lock_beg;
1422
1423
      double gemm_time = omp_get_wtime() - beg;
1424
      double GFLOPS = 2.0 * u.row() * u.col() * k / ( 1E+9 * gemm_time );
1425
      //printf( "GEMM %4lu %4lu %4lu %lf GFLOPS, lock(%lf/%lf)\n",
1426
      //    u.row(), u.col(), k, GFLOPS, lock_time, gemm_time ); fflush( stdout );
1427
      #pragma omp atomic update
1428
      *num_arrived_subtasks += 1;
1429
    };
1430
};
1431
1432
1433
1434
1435
template<typename NODE, typename T>
1436
class L2LReduceTask2 : public Task
1437
{
1438
  public:
1439
1440
    NODE *arg = NULL;
1441
1442
    vector<L2LTask2<NODE, T>*> subtasks;
1443
1444
    Lock lock;
1445
1446
    int num_arrived_subtasks = 0;
1447
1448
    const size_t batch_size = 2;
1449
1450
    void Set( NODE *user_arg )
1451
    {
1452
      arg = user_arg;
1453
      name = string( "L2LR" );
1454
      label = to_string( arg->treelist_id );
1455
      /** Create subtasks */
1456
      for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
1457
      {
1458
        vector<NODE*> Sources;
1459
        for ( auto &it : arg->DistNear[ p ] )
1460
        {
1461
          Sources.push_back( (*arg->morton2node)[ it.first ] );
1462
          if ( Sources.size() == batch_size )
1463
          {
1464
            subtasks.push_back( new L2LTask2<NODE, T>() );
1465
            subtasks.back()->Submit();
1466
            subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1467
            subtasks.back()->DependencyAnalysis();
1468
            Sources.clear();
1469
          }
1470
        }
1471
        if ( Sources.size() )
1472
        {
1473
          subtasks.push_back( new L2LTask2<NODE, T>() );
1474
          subtasks.back()->Submit();
1475
          subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
1476
          subtasks.back()->DependencyAnalysis();
1477
          Sources.clear();
1478
        }
1479
      }
1480
1481
1482
1483
1484
      /** Compute FLOPS and MOPS */
1485
      double flops = 0, mops = 0;
1486
      /** Setup the event */
1487
      event.Set( label + name, flops, mops );
1488
      /** "LOW" priority (critical path) */
1489
      priority = false;
1490
    };
1491
1492
    void DependencyAnalysis()
1493
    {
1494
      for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
1495
      arg->DependencyAnalysis( RW, this );
1496
      this->TryEnqueue();
1497
    };
1498
1499
    void Execute( Worker* user_worker )
1500
    {
1501
      assert( num_arrived_subtasks == subtasks.size() );
1502
    };
1503
};
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
/**
1522
 *  @brief (FMM specific) Compute Near( leaf nodes ). This is just like
1523
 *         the neighbor list but the granularity is in nodes but not points.
1524
 *         The algorithm is to compute the node morton ids of neighbor points.
1525
 *         Get the pointers of these nodes and insert them into a std::set.
1526
 *         std::set will automatic remove duplication. Here the insertion
1527
 *         will be performed twice each time to get a symmetric one. That is
1528
 *         if alpha has beta in its list, then beta will also have alpha in
1529
 *         its list.
1530
 *
1531
 *         Only leaf nodes will have the list `` NearNodes''.
1532
 *
1533
 *         This list will later be used to get the FarNodes using a recursive
1534
 *         node traversal scheme.
1535
 *
1536
 */
1537
template<typename TREE>
1538
void FindNearInteractions( TREE &tree )
1539
{
1540
  mpi::PrintProgress( "[BEG] Finish FindNearInteractions ...", tree.GetComm() );
1541
  /** Derive type NODE from TREE. */
1542
  using NODE = typename TREE::NODE;
1543
  auto &setup = tree.setup;
1544
  auto &NN = *setup.NN;
1545
  double budget = setup.Budget();
1546
  size_t n_leafs = ( 1 << tree.depth );
1547
  /**
1548
   *  The type here is tree::Node but not mpitree::Node.
1549
   *  NearNodes and NNNearNodes also take tree::Node.
1550
   *  This is ok, because they will only contain leaf nodes,
1551
   *  which will never be distributed.
1552
   *  However, FarNodes and NNFarNodes may contain distributed
1553
   *  tree nodes. In this case, we have to do type casting.
1554
   */
1555
  auto level_beg = tree.treelist.begin() + n_leafs - 1;
1556
1557
  /** Traverse all leaf nodes. **/
1558
  #pragma omp parallel for
1559
  for ( size_t node_ind = 0; node_ind < n_leafs; node_ind ++ )
1560
  {
1561
    auto *node = *(level_beg + node_ind);
1562
    auto &data = node->data;
1563
    size_t n_nodes = ( 1 << node->l );
1564
1565
    /** Add myself to the near interaction list.  */
1566
    node->NNNearNodes.insert( node );
1567
    node->NNNearNodeMortonIDs.insert( node->morton );
1568
1569
    /** Compute ballots for all near interactions */
1570
    multimap<size_t, size_t> sorted_ballot = gofmm::NearNodeBallots( node );
1571
1572
    /** Insert near node cadidates until reaching the budget limit. */
1573
    for ( auto it  = sorted_ballot.rbegin();
1574
               it != sorted_ballot.rend(); it ++ )
1575
    {
1576
      /** Exit if we have enough near interactions. */
1577
      if ( node->NNNearNodes.size() >= n_nodes * budget ) break;
1578
1579
      /**
1580
       *  Get the node pointer from MortonID.
1581
       *
1582
       *  Two situations:
1583
       *  1. the pointer doesn't exist, then creates a lettreenode
1584
       */
1585
      #pragma omp critical
1586
      {
1587
        if ( !(*node->morton2node).count( (*it).second ) )
1588
        {
1589
          /** Create a LET node. */
1590
          (*node->morton2node)[ (*it).second ] = new NODE( (*it).second );
1591
        }
1592
        /** Insert */
1593
        auto *target = (*node->morton2node)[ (*it).second ];
1594
        node->NNNearNodeMortonIDs.insert( (*it).second );
1595
        node->NNNearNodes.insert( target );
1596
      } /** end pragma omp critical */
1597
    }
1598
  } /** end for each leaf owned leaf node in the local tree */
1599
  mpi::PrintProgress( "[END] Finish FindNearInteractions ...", tree.GetComm() );
1600
}; /** end FindNearInteractions() */
1601
1602
1603
1604
1605
template<typename NODE>
1606
void FindFarNodes( MortonHelper::Recursor r, NODE *target )
1607
{
1608
  /** Return while reaching the leaf level (recursion base case). */
1609
  if ( r.second > target->l ) return;
1610
  /** Compute the MortonID of the visiting node. */
1611
  size_t node_morton = MortonHelper::MortonID( r );
1612
1613
  //bool prunable = true;
1614
  auto & NearMortonIDs = target->NNNearNodeMortonIDs;
1615
1616
  /** Recur to children if the current node contains near interactions. */
1617
  if ( MortonHelper::ContainAny( node_morton, NearMortonIDs ) )
1618
  {
1619
    FindFarNodes( MortonHelper::RecurLeft( r ), target );
1620
    FindFarNodes( MortonHelper::RecurRight( r ), target );
1621
  }
1622
  else
1623
  {
1624
    if ( node_morton >= target->morton )
1625
      target->NNFarNodeMortonIDs.insert( node_morton );
1626
  }
1627
}; /** end FindFarNodes() */
1628
1629
1630
1631
1632
1633
1634
template<typename TREE>
1635
void SymmetrizeNearInteractions( TREE & tree )
1636
{
1637
  mpi::PrintProgress( "[BEG] SymmetrizeNearInteractions ...", tree.GetComm() );
1638
1639
  /** Derive type NODE from TREE. */
1640
  using NODE = typename TREE::NODE;
1641
  /** MPI Support */
1642
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
1643
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
1644
1645
  vector<vector<pair<size_t, size_t>>> sendlist( comm_size );
1646
  vector<vector<pair<size_t, size_t>>> recvlist( comm_size );
1647
1648
1649
  /**
1650
   *  Traverse local leaf nodes:
1651
   *
1652
   *  Loop over all near node MortonIDs, create
1653
   *
1654
   */
1655
  int n_nodes = 1 << tree.depth;
1656
  auto level_beg = tree.treelist.begin() + n_nodes - 1;
1657
1658
  #pragma omp parallel
1659
  {
1660
    /** Create a per thread list. Merge them into sendlist afterward. */
1661
    vector<vector<pair<size_t, size_t>>> list( comm_size );
1662
1663
    #pragma omp for
1664
    for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1665
    {
1666
      auto *node = *(level_beg + node_ind);
1667
      //auto & NearMortonIDs = node->NNNearNodeMortonIDs;
1668
      for ( auto it : node->NNNearNodeMortonIDs )
1669
      {
1670
        int dest = tree.Morton2Rank( it );
1671
        if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest );
1672
        list[ dest ].push_back( make_pair( it, node->morton ) );
1673
      }
1674
    } /** end pramga omp for */
1675
1676
    #pragma omp critical
1677
    {
1678
      for ( int p = 0; p < comm_size; p ++ )
1679
      {
1680
        sendlist[ p ].insert( sendlist[ p ].end(),
1681
            list[ p ].begin(), list[ p ].end() );
1682
      }
1683
    } /** end pragma omp critical*/
1684
  }; /** end pargma omp parallel */
1685
1686
1687
  /** Alltoallv */
1688
  mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() );
1689
1690
1691
  /** Loop over queries. */
1692
  for ( int p = 0; p < comm_size; p ++ )
1693
  {
1694
    for ( auto & query : recvlist[ p ]  )
1695
    {
1696
      /** Check if query node is allocated? */
1697
      #pragma omp critical
1698
      {
1699
        auto* node = tree.morton2node[ query.first ];
1700
        if ( !tree.morton2node.count( query.second ) )
1701
        {
1702
          tree.morton2node[ query.second ] = new NODE( query.second );
1703
        }
1704
        node->data.lock.Acquire();
1705
        {
1706
          node->NNNearNodes.insert( tree.morton2node[ query.second ] );
1707
          node->NNNearNodeMortonIDs.insert( query.second );
1708
        }
1709
        node->data.lock.Release();
1710
      }
1711
    }; /** end pargma omp parallel for */
1712
  }
1713
  mpi::Barrier( tree.GetComm() );
1714
  mpi::PrintProgress( "[END] SymmetrizeNearInteractions ...", tree.GetComm() );
1715
}; /** end SymmetrizeNearInteractions() */
1716
1717
1718
template<typename TREE>
1719
void SymmetrizeFarInteractions( TREE & tree )
1720
{
1721
  mpi::PrintProgress( "[BEG] SymmetrizeFarInteractions ...", tree.GetComm() );
1722
1723
  /** Derive type NODE from TREE. */
1724
  using NODE = typename TREE::NODE;
1725
  ///** MPI Support. */
1726
  //int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
1727
  //int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
1728
1729
  vector<vector<pair<size_t, size_t>>> sendlist( tree.GetCommSize() );
1730
  vector<vector<pair<size_t, size_t>>> recvlist( tree.GetCommSize() );
1731
1732
  /** Local traversal */
1733
  #pragma omp parallel
1734
  {
1735
    /** Create a per thread list. Merge them into sendlist afterward. */
1736
    vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() );
1737
1738
    #pragma omp for
1739
    for ( size_t i = 1; i < tree.treelist.size(); i ++ )
1740
    {
1741
      auto *node = tree.treelist[ i ];
1742
      for ( auto it  = node->NNFarNodeMortonIDs.begin();
1743
                 it != node->NNFarNodeMortonIDs.end(); it ++ )
1744
      {
1745
        /** Allocate if not exist */
1746
        #pragma omp critical
1747
        {
1748
          if ( !tree.morton2node.count( *it ) )
1749
          {
1750
            tree.morton2node[ *it ] = new NODE( *it );
1751
          }
1752
          node->NNFarNodes.insert( tree.morton2node[ *it ] );
1753
        }
1754
        int dest = tree.Morton2Rank( *it );
1755
        if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest );
1756
        list[ dest ].push_back( make_pair( *it, node->morton ) );
1757
      }
1758
    }
1759
1760
    #pragma omp critical
1761
    {
1762
      for ( int p = 0; p < tree.GetCommSize(); p ++ )
1763
      {
1764
        sendlist[ p ].insert( sendlist[ p ].end(),
1765
            list[ p ].begin(), list[ p ].end() );
1766
      }
1767
    } /** end pragma omp critical*/
1768
  }
1769
1770
1771
  /** Distributed traversal */
1772
  #pragma omp parallel
1773
  {
1774
    /** Create a per thread list. Merge them into sendlist afterward. */
1775
    vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() );
1776
1777
    #pragma omp for
1778
    for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ )
1779
    {
1780
      auto *node = tree.mpitreelists[ i ];
1781
      for ( auto it  = node->NNFarNodeMortonIDs.begin();
1782
                 it != node->NNFarNodeMortonIDs.end(); it ++ )
1783
      {
1784
        /** Allocate if not exist */
1785
        #pragma omp critical
1786
        {
1787
          if ( !tree.morton2node.count( *it ) )
1788
          {
1789
            tree.morton2node[ *it ] = new NODE( *it );
1790
          }
1791
          node->NNFarNodes.insert( tree.morton2node[ *it ] );
1792
        }
1793
        int dest = tree.Morton2Rank( *it );
1794
        if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest ); fflush( stdout );
1795
        list[ dest ].push_back( make_pair( *it, node->morton ) );
1796
      }
1797
    }
1798
1799
    #pragma omp critical
1800
    {
1801
      for ( int p = 0; p < tree.GetCommSize(); p ++ )
1802
      {
1803
        sendlist[ p ].insert( sendlist[ p ].end(),
1804
            list[ p ].begin(), list[ p ].end() );
1805
      }
1806
    } /** end pragma omp critical*/
1807
  }
1808
1809
  /** Alltoallv */
1810
  mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() );
1811
1812
  /** Loop over queries */
1813
  for ( int p = 0; p < tree.GetCommSize(); p ++ )
1814
  {
1815
    //#pragma omp parallel for
1816
    for ( auto & query : recvlist[ p ] )
1817
    {
1818
      /** Check if query node is allocated?  */
1819
      #pragma omp critical
1820
      {
1821
        if ( !tree.morton2node.count( query.second ) )
1822
        {
1823
          tree.morton2node[ query.second ] = new NODE( query.second );
1824
          //printf( "rank %d, %8lu level %lu creates far LET %8lu (symmetrize)\n",
1825
          //    comm_rank, node->morton, node->l, query.second );
1826
        }
1827
        auto* node = tree.morton2node[ query.first ];
1828
        node->data.lock.Acquire();
1829
        {
1830
          node->NNFarNodes.insert( tree.morton2node[ query.second ] );
1831
          node->NNFarNodeMortonIDs.insert( query.second );
1832
        }
1833
        node->data.lock.Release();
1834
        assert( tree.Morton2Rank( node->morton ) == tree.GetCommRank() );
1835
      } /** end pragma omp critical */
1836
    } /** end pargma omp parallel for */
1837
  }
1838
1839
  mpi::Barrier( tree.GetComm() );
1840
  mpi::PrintProgress( "[END] SymmetrizeFarInteractions ...", tree.GetComm() );
1841
}; /** end SymmetrizeFarInteractions() */
1842
1843
1844
1845
/**
1846
 *  TODO: need send and recv interaction lists for each rank
1847
 *
1848
 *  SendNNNear[ rank ][ local  morton ]
1849
 *  RecvNNNear[ rank ][ remote morton ]
1850
 *
1851
 *  for each leaf alpha and beta in Near(alpha)
1852
 *    SendNNNear[ rank(beta) ] += Morton(alpha)
1853
 *
1854
 *  Alltoallv( SendNNNear, rbuff );
1855
 *
1856
 *  for each rank
1857
 *    RecvNNNear[ rank ][ remote morton ] = offset in rbuff
1858
 *
1859
 */
1860
template<typename TREE>
1861
void BuildInteractionListPerRank( TREE &tree, bool is_near )
1862
{
1863
  /** Derive type T from TREE. */
1864
  using T = typename TREE::T;
1865
  /** MPI Support. */
1866
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
1867
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
1868
1869
  /** Interaction set per rank in MortonID. */
1870
  vector<set<size_t>> lists( comm_size );
1871
1872
  if ( is_near )
1873
  {
1874
    /** Traverse leaf nodes (near interation lists) */
1875
    int n_nodes = 1 << tree.depth;
1876
    auto level_beg = tree.treelist.begin() + n_nodes - 1;
1877
1878
    #pragma omp parallel
1879
    {
1880
      /** Create a per thread list. Merge them into sendlist afterward. */
1881
      vector<set<size_t>> list( comm_size );
1882
1883
      #pragma omp for
1884
      for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
1885
      {
1886
        auto *node = *(level_beg + node_ind);
1887
        auto & NearMortonIDs = node->NNNearNodeMortonIDs;
1888
        node->DistNear.resize( comm_size );
1889
        for ( auto it : NearMortonIDs )
1890
        {
1891
          int dest = tree.Morton2Rank( it );
1892
          if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest );
1893
          if ( dest != comm_rank ) list[ dest ].insert( node->morton );
1894
          node->DistNear[ dest ][ it ] = Data<T>();
1895
        }
1896
      } /** end pramga omp for */
1897
1898
      #pragma omp critical
1899
      {
1900
        for ( int p = 0; p < comm_size; p ++ )
1901
          lists[ p ].insert( list[ p ].begin(), list[ p ].end() );
1902
      } /** end pragma omp critical*/
1903
    }; /** end pargma omp parallel */
1904
1905
1906
    /** Cast set to vector. */
1907
    vector<vector<size_t>> recvlist( comm_size );
1908
    if ( !tree.NearSentToRank.size() ) tree.NearSentToRank.resize( comm_size );
1909
    if ( !tree.NearRecvFromRank.size() ) tree.NearRecvFromRank.resize( comm_size );
1910
    #pragma omp parallel for
1911
    for ( int p = 0; p < comm_size; p ++ )
1912
    {
1913
      tree.NearSentToRank[ p ].insert( tree.NearSentToRank[ p ].end(),
1914
          lists[ p ].begin(), lists[ p ].end() );
1915
    }
1916
1917
    /** Use buffer recvlist to catch Alltoallv results. */
1918
    mpi::AlltoallVector( tree.NearSentToRank, recvlist, tree.GetComm() );
1919
1920
    /** Cast vector of vectors to vector of maps */
1921
    #pragma omp parallel for
1922
    for ( int p = 0; p < comm_size; p ++ )
1923
      for ( int i = 0; i < recvlist[ p ].size(); i ++ )
1924
        tree.NearRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i;
1925
  }
1926
  else
1927
  {
1928
    #pragma omp parallel
1929
    {
1930
      /** Create a per thread list. Merge them into sendlist afterward. */
1931
      vector<set<size_t>> list( comm_size );
1932
1933
      /** Local traversal */
1934
      #pragma omp for
1935
      for ( size_t i = 1; i < tree.treelist.size(); i ++ )
1936
      {
1937
        auto *node = tree.treelist[ i ];
1938
        node->DistFar.resize( comm_size );
1939
        for ( auto it  = node->NNFarNodeMortonIDs.begin();
1940
                   it != node->NNFarNodeMortonIDs.end(); it ++ )
1941
        {
1942
          int dest = tree.Morton2Rank( *it );
1943
          if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest );
1944
          if ( dest != comm_rank )
1945
          {
1946
            list[ dest ].insert( node->morton );
1947
            //node->data.FarDependents.insert( dest );
1948
          }
1949
          node->DistFar[ dest ][ *it ] = Data<T>();
1950
        }
1951
      }
1952
1953
      /** Distributed traversal */
1954
      #pragma omp for
1955
      for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ )
1956
      {
1957
        auto *node = tree.mpitreelists[ i ];
1958
        node->DistFar.resize( comm_size );
1959
        /** Add to the list iff this MPI rank owns the distributed node */
1960
        if ( tree.Morton2Rank( node->morton ) == comm_rank )
1961
        {
1962
          for ( auto it  = node->NNFarNodeMortonIDs.begin();
1963
                     it != node->NNFarNodeMortonIDs.end(); it ++ )
1964
          {
1965
            int dest = tree.Morton2Rank( *it );
1966
            if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest );
1967
            if ( dest != comm_rank )
1968
            {
1969
              list[ dest ].insert( node->morton );
1970
              //node->data.FarDependents.insert( dest );
1971
            }
1972
            node->DistFar[ dest ][ *it ]  = Data<T>();
1973
          }
1974
        }
1975
      }
1976
      /** Merge lists from all threads */
1977
      #pragma omp critical
1978
      {
1979
        for ( int p = 0; p < comm_size; p ++ )
1980
          lists[ p ].insert( list[ p ].begin(), list[ p ].end() );
1981
      } /** end pragma omp critical*/
1982
1983
    }; /** end pargma omp parallel */
1984
1985
    /** Cast set to vector */
1986
    vector<vector<size_t>> recvlist( comm_size );
1987
    if ( !tree.FarSentToRank.size() ) tree.FarSentToRank.resize( comm_size );
1988
    if ( !tree.FarRecvFromRank.size() ) tree.FarRecvFromRank.resize( comm_size );
1989
    #pragma omp parallel for
1990
    for ( int p = 0; p < comm_size; p ++ )
1991
    {
1992
      tree.FarSentToRank[ p ].insert( tree.FarSentToRank[ p ].end(),
1993
          lists[ p ].begin(), lists[ p ].end() );
1994
    }
1995
1996
1997
    /** Use buffer recvlist to catch Alltoallv results. */
1998
    mpi::AlltoallVector( tree.FarSentToRank, recvlist, tree.GetComm() );
1999
2000
    /** Cast vector of vectors to vector of maps */
2001
    #pragma omp parallel for
2002
    for ( int p = 0; p < comm_size; p ++ )
2003
      for ( int i = 0; i < recvlist[ p ].size(); i ++ )
2004
        tree.FarRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i;
2005
  }
2006
2007
  mpi::Barrier( tree.GetComm() );
2008
}; /** end BuildInteractionListPerRank() */
2009
2010
2011
template<typename TREE>
2012
pair<double, double> NonCompressedRatio( TREE &tree )
2013
{
2014
  /** Tree MPI communicator */
2015
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2016
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2017
2018
  /** Use double for accumulation. */
2019
  double ratio_n = 0.0;
2020
  double ratio_f = 0.0;
2021
2022
2023
  /** Traverse all nodes in the local tree. */
2024
  for ( auto &tar : tree.treelist )
2025
  {
2026
    if ( tar->isleaf )
2027
    {
2028
      for ( auto nearID : tar->NNNearNodeMortonIDs )
2029
      {
2030
        auto *src = tree.morton2node[ nearID ];
2031
        assert( src );
2032
        double m = tar->gids.size();
2033
        double n = src->gids.size();
2034
        double N = tree.n;
2035
        ratio_n += ( m / N ) * ( n / N );
2036
      }
2037
    }
2038
2039
    for ( auto farID : tar->NNFarNodeMortonIDs )
2040
    {
2041
      auto *src = tree.morton2node[ farID ];
2042
      assert( src );
2043
      double m = tar->data.skels.size();
2044
      double n = src->data.skels.size();
2045
      double N = tree.n;
2046
      ratio_f += ( m / N ) * ( n / N );
2047
    }
2048
  }
2049
2050
  /** Traverse all nodes in the distributed tree. */
2051
  for ( auto &tar : tree.mpitreelists )
2052
  {
2053
    if ( !tar->child || tar->GetCommRank() ) continue;
2054
    for ( auto farID : tar->NNFarNodeMortonIDs )
2055
    {
2056
      auto *src = tree.morton2node[ farID ];
2057
      assert( src );
2058
      double m = tar->data.skels.size();
2059
      double n = src->data.skels.size();
2060
      double N = tree.n;
2061
      ratio_f += ( m / N ) * ( n / N );
2062
    }
2063
  }
2064
2065
  /** Allreduce total evaluations from all MPI processes. */
2066
  pair<double, double> ret( 0, 0 );
2067
  mpi::Allreduce( &ratio_n, &(ret.first),  1, MPI_SUM, tree.GetComm() );
2068
  mpi::Allreduce( &ratio_f, &(ret.second), 1, MPI_SUM, tree.GetComm() );
2069
2070
  return ret;
2071
};
2072
2073
2074
2075
template<typename T, typename TREE>
2076
void PackNear( TREE &tree, string option, int p,
2077
    vector<size_t> &sendsizes,
2078
    vector<size_t> &sendskels,
2079
    vector<T> &sendbuffs )
2080
{
2081
  vector<size_t> offsets( 1, 0 );
2082
2083
  for ( auto it : tree.NearSentToRank[ p ] )
2084
  {
2085
    auto *node = tree.morton2node[ it ];
2086
    auto &gids = node->gids;
2087
    if ( !option.compare( string( "leafgids" ) ) )
2088
    {
2089
      sendsizes.push_back( gids.size() );
2090
      sendskels.insert( sendskels.end(), gids.begin(), gids.end() );
2091
    }
2092
    else
2093
    {
2094
      auto &w_view = node->data.w_view;
2095
      sendsizes.push_back( gids.size() * w_view.col() );
2096
      offsets.push_back( sendsizes.back() + offsets.back() );
2097
    }
2098
  }
2099
2100
  if ( offsets.size() ) sendbuffs.resize( offsets.back() );
2101
2102
  if ( !option.compare( string( "leafweights" ) ) )
2103
  {
2104
    #pragma omp parallel for
2105
    for ( size_t i = 0; i < tree.NearSentToRank[ p ].size(); i ++ )
2106
    {
2107
      auto *node = tree.morton2node[ tree.NearSentToRank[ p ][ i ] ];
2108
      auto &gids = node->gids;
2109
      auto &w_view = node->data.w_view;
2110
      auto  w_leaf = w_view.toData();
2111
      size_t offset = offsets[ i ];
2112
      for ( size_t j = 0; j < w_leaf.size(); j ++ )
2113
        sendbuffs[ offset + j ] = w_leaf[ j ];
2114
    }
2115
  }
2116
};
2117
2118
2119
template<typename T, typename TREE>
2120
void UnpackLeaf( TREE &tree, string option, int p,
2121
    const vector<size_t> &recvsizes,
2122
    const vector<size_t> &recvskels,
2123
    const vector<T> &recvbuffs )
2124
{
2125
  vector<size_t> offsets( 1, 0 );
2126
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2127
2128
  for ( auto it : tree.NearRecvFromRank[ p ] )
2129
  {
2130
    auto *node = tree.morton2node[ it.first ];
2131
    if ( !option.compare( string( "leafgids" ) ) )
2132
    {
2133
      auto &gids = node->gids;
2134
      size_t i = it.second;
2135
      gids.reserve( recvsizes[ i ] );
2136
      for ( uint64_t j  = offsets[ i + 0 ];
2137
                     j  < offsets[ i + 1 ];
2138
                     j ++ )
2139
      {
2140
        gids.push_back( recvskels[ j ] );
2141
      }
2142
    }
2143
    else
2144
    {
2145
      /** Number of right hand sides */
2146
      size_t nrhs = tree.setup.w->col();
2147
      auto &w_leaf = node->data.w_leaf;
2148
      size_t i = it.second;
2149
      w_leaf.resize( recvsizes[ i ] / nrhs, nrhs );
2150
      //printf( "%d recv w_leaf from %d [%lu %lu]\n",
2151
      //    comm_rank, p, w_leaf.row(), w_leaf.col() ); fflush( stdout );
2152
      for ( uint64_t j  = offsets[ i + 0 ], jj = 0;
2153
                     j  < offsets[ i + 1 ];
2154
                     j ++,                 jj ++ )
2155
      {
2156
        w_leaf[ jj ] = recvbuffs[ j ];
2157
      }
2158
    }
2159
  }
2160
};
2161
2162
2163
template<typename T, typename TREE>
2164
void PackFar( TREE &tree, string option, int p,
2165
    vector<size_t> &sendsizes,
2166
    vector<size_t> &sendskels,
2167
    vector<T> &sendbuffs )
2168
{
2169
  for ( auto it : tree.FarSentToRank[ p ] )
2170
  {
2171
    auto *node = tree.morton2node[ it ];
2172
    auto &skels = node->data.skels;
2173
    if ( !option.compare( string( "skelgids" ) ) )
2174
    {
2175
      sendsizes.push_back( skels.size() );
2176
      sendskels.insert( sendskels.end(), skels.begin(), skels.end() );
2177
    }
2178
    else
2179
    {
2180
      auto &w_skel = node->data.w_skel;
2181
      sendsizes.push_back( w_skel.size() );
2182
      sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() );
2183
    }
2184
  }
2185
}; /** end PackFar() */
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
/** @brief Pack a list of weights and their sizes to two messages. */
2199
template<typename TREE, typename T>
2200
void PackWeights( TREE &tree, int p,
2201
    vector<T> &sendbuffs, vector<size_t> &sendsizes )
2202
{
2203
  for ( auto it : tree.NearSentToRank[ p ] )
2204
  {
2205
    auto *node = tree.morton2node[ it ];
2206
    auto w_leaf = node->data.w_view.toData();
2207
    sendbuffs.insert( sendbuffs.end(), w_leaf.begin(), w_leaf.end() );
2208
    sendsizes.push_back( w_leaf.size() );
2209
  }
2210
}; /** end PackWeights() */
2211
2212
2213
/** @brief Unpack a list of weights and their sizes. */
2214
template<typename TREE, typename T>
2215
void UnpackWeights( TREE &tree, int p,
2216
    const vector<T> recvbuffs, const vector<size_t> &recvsizes )
2217
{
2218
  vector<size_t> offsets( 1, 0 );
2219
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2220
2221
  for ( auto it : tree.NearRecvFromRank[ p ] )
2222
  {
2223
    /** Get LET node pointer. */
2224
    auto *node = tree.morton2node[ it.first ];
2225
    /** Number of right hand sides */
2226
    size_t nrhs = tree.setup.w->col();
2227
    auto &w_leaf = node->data.w_leaf;
2228
    size_t i = it.second;
2229
    w_leaf.resize( recvsizes[ i ] / nrhs, nrhs );
2230
    for ( uint64_t j  = offsets[ i + 0 ], jj = 0;
2231
                   j  < offsets[ i + 1 ];
2232
                   j ++,                  jj ++ )
2233
    {
2234
      w_leaf[ jj ] = recvbuffs[ j ];
2235
    }
2236
  }
2237
}; /** end UnpackWeights() */
2238
2239
2240
2241
/** @brief Pack a list of skeletons and their sizes to two messages. */
2242
template<typename TREE>
2243
void PackSkeletons( TREE &tree, int p,
2244
    vector<size_t> &sendbuffs, vector<size_t> &sendsizes )
2245
{
2246
  for ( auto it : tree.FarSentToRank[ p ] )
2247
  {
2248
    /** Get LET node pointer. */
2249
    auto *node = tree.morton2node[ it ];
2250
    auto &skels = node->data.skels;
2251
    sendbuffs.insert( sendbuffs.end(), skels.begin(), skels.end() );
2252
    sendsizes.push_back( skels.size() );
2253
  }
2254
}; /** end PackSkeletons() */
2255
2256
2257
/** @brief Unpack a list of skeletons and their sizes. */
2258
template<typename TREE>
2259
void UnpackSkeletons( TREE &tree, int p,
2260
    const vector<size_t> recvbuffs, const vector<size_t> &recvsizes )
2261
{
2262
  vector<size_t> offsets( 1, 0 );
2263
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2264
2265
  for ( auto it : tree.FarRecvFromRank[ p ] )
2266
  {
2267
    /** Get LET node pointer. */
2268
    auto *node = tree.morton2node[ it.first ];
2269
    auto &skels = node->data.skels;
2270
    size_t i = it.second;
2271
    skels.clear();
2272
    skels.reserve( recvsizes[ i ] );
2273
    for ( uint64_t j  = offsets[ i + 0 ];
2274
                   j  < offsets[ i + 1 ];
2275
                   j ++ )
2276
    {
2277
      skels.push_back( recvbuffs[ j ] );
2278
    }
2279
  }
2280
}; /** end UnpackSkeletons() */
2281
2282
2283
2284
/** @brief Pack a list of skeleton weights and their sizes to two messages. */
2285
template<typename TREE, typename T>
2286
void PackSkeletonWeights( TREE &tree, int p,
2287
    vector<T> &sendbuffs, vector<size_t> &sendsizes )
2288
{
2289
  for ( auto it : tree.FarSentToRank[ p ] )
2290
  {
2291
    auto *node = tree.morton2node[ it ];
2292
    auto &w_skel = node->data.w_skel;
2293
    sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() );
2294
    sendsizes.push_back( w_skel.size() );
2295
  }
2296
}; /** end PackSkeletonWeights() */
2297
2298
2299
/** @brief Unpack a list of skeletons and their sizes. */
2300
template<typename TREE, typename T>
2301
void UnpackSkeletonWeights( TREE &tree, int p,
2302
    const vector<T> recvbuffs, const vector<size_t> &recvsizes )
2303
{
2304
  vector<size_t> offsets( 1, 0 );
2305
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2306
2307
  for ( auto it : tree.FarRecvFromRank[ p ] )
2308
  {
2309
    /** Get LET node pointer. */
2310
    auto *node = tree.morton2node[ it.first ];
2311
    /** Number of right hand sides */
2312
    size_t nrhs = tree.setup.w->col();
2313
    auto &w_skel = node->data.w_skel;
2314
    size_t i = it.second;
2315
    w_skel.resize( recvsizes[ i ] / nrhs, nrhs );
2316
    for ( uint64_t j  = offsets[ i + 0 ], jj = 0;
2317
                   j  < offsets[ i + 1 ];
2318
                   j ++,                  jj ++ )
2319
    {
2320
      w_skel[ jj ] = recvbuffs[ j ];
2321
    }
2322
  }
2323
}; /** end UnpackSkeletonWeights() */
2324
2325
2326
2327
2328
2329
2330
template<typename T, typename TREE>
2331
void UnpackFar( TREE &tree, string option, int p,
2332
    const vector<size_t> &recvsizes,
2333
    const vector<size_t> &recvskels,
2334
    const vector<T> &recvbuffs )
2335
{
2336
  vector<size_t> offsets( 1, 0 );
2337
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
2338
2339
  for ( auto it : tree.FarRecvFromRank[ p ] )
2340
  {
2341
    /** Get LET node pointer */
2342
    auto *node = tree.morton2node[ it.first ];
2343
    if ( !option.compare( string( "skelgids" ) ) )
2344
    {
2345
      auto &skels = node->data.skels;
2346
      size_t i = it.second;
2347
      skels.clear();
2348
      skels.reserve( recvsizes[ i ] );
2349
      for ( uint64_t j  = offsets[ i + 0 ];
2350
                     j  < offsets[ i + 1 ];
2351
                     j ++ )
2352
      {
2353
        skels.push_back( recvskels[ j ] );
2354
      }
2355
    }
2356
    else
2357
    {
2358
      /** Number of right hand sides */
2359
      size_t nrhs = tree.setup.w->col();
2360
      auto &w_skel = node->data.w_skel;
2361
      size_t i = it.second;
2362
      w_skel.resize( recvsizes[ i ] / nrhs, nrhs );
2363
      //printf( "%d recv w_skel (%8lu) from %d [%lu %lu], i %lu, offset[%lu %lu] \n",
2364
      //    comm_rank, (*it).first, p, w_skel.row(), w_skel.col(), i,
2365
      //    offsets[ p ][ i + 0 ], offsets[ p ][ i + 1 ] ); fflush( stdout );
2366
      for ( uint64_t j  = offsets[ i + 0 ], jj = 0;
2367
                     j  < offsets[ i + 1 ];
2368
                     j ++,                  jj ++ )
2369
      {
2370
        w_skel[ jj ] = recvbuffs[ j ];
2371
        //if ( jj < 5 ) printf( "%E ", w_skel[ jj ] ); fflush( stdout );
2372
      }
2373
      //printf( "\n" ); fflush( stdout );
2374
    }
2375
  }
2376
};
2377
2378
2379
template<typename T, typename TREE>
2380
class PackNearTask : public SendTask<T, TREE>
2381
{
2382
  public:
2383
2384
    PackNearTask( TREE *tree, int src, int tar, int key )
2385
      : SendTask<T, TREE>( tree, src, tar, key )
2386
    {
2387
      /** Submit and perform dependency analysis automaticallu. */
2388
      this->Submit();
2389
      this->DependencyAnalysis();
2390
    };
2391
2392
    void DependencyAnalysis()
2393
    {
2394
      TREE &tree = *(this->arg);
2395
      tree.DependOnNearInteractions( this->tar, this );
2396
    };
2397
2398
    /** Instansiate Pack() for SendTask. */
2399
    void Pack()
2400
    {
2401
      PackWeights( *this->arg, this->tar,
2402
          this->send_buffs, this->send_sizes );
2403
    };
2404
2405
}; /** end class PackNearTask */
2406
2407
2408
2409
2410
/**
2411
 *  AlltoallvTask is used perform MPI_Alltoallv in asynchronous.
2412
 *  Overall there will be (p - 1) tasks per MPI rank. Each task
2413
 *  performs Isend while the dependencies toward the destination
2414
 *  is fullfilled.
2415
 *
2416
 *  To receive the results, each MPI rank also actively runs a
2417
 *  ListenerTask. Listener will keep pulling for incioming message
2418
 *  that matches. Once the received results are secured, it will
2419
 *  release dependent tasks.
2420
 */
2421
template<typename T, typename TREE>
2422
class UnpackLeafTask : public RecvTask<T, TREE>
2423
{
2424
  public:
2425
2426
    UnpackLeafTask( TREE *tree, int src, int tar, int key )
2427
      : RecvTask<T, TREE>( tree, src, tar, key )
2428
    {
2429
      /** Submit and perform dependency analysis automaticallu. */
2430
      this->Submit();
2431
      this->DependencyAnalysis();
2432
    };
2433
2434
    void Unpack()
2435
    {
2436
      UnpackWeights( *this->arg, this->src,
2437
          this->recv_buffs, this->recv_sizes );
2438
    };
2439
2440
}; /** end class UnpackLeafTask */
2441
2442
2443
/** @brief */
2444
template<typename T, typename TREE>
2445
class PackFarTask : public SendTask<T, TREE>
2446
{
2447
  public:
2448
2449
    PackFarTask( TREE *tree, int src, int tar, int key )
2450
      : SendTask<T, TREE>( tree, src, tar, key )
2451
    {
2452
      /** Submit and perform dependency analysis automaticallu. */
2453
      this->Submit();
2454
      this->DependencyAnalysis();
2455
    };
2456
2457
    void DependencyAnalysis()
2458
    {
2459
      TREE &tree = *(this->arg);
2460
      tree.DependOnFarInteractions( this->tar, this );
2461
    };
2462
2463
    /** Instansiate Pack() for SendTask. */
2464
    void Pack()
2465
    {
2466
      PackSkeletonWeights( *this->arg, this->tar,
2467
          this->send_buffs, this->send_sizes );
2468
    };
2469
2470
}; /** end class PackFarTask */
2471
2472
2473
/** @brief */
2474
template<typename T, typename TREE>
2475
class UnpackFarTask : public RecvTask<T, TREE>
2476
{
2477
  public:
2478
2479
    UnpackFarTask( TREE *tree, int src, int tar, int key )
2480
      : RecvTask<T, TREE>( tree, src, tar, key )
2481
    {
2482
      /** Submit and perform dependency analysis automaticallu. */
2483
      this->Submit();
2484
      this->DependencyAnalysis();
2485
    };
2486
2487
    void Unpack()
2488
    {
2489
      UnpackSkeletonWeights( *this->arg, this->src,
2490
          this->recv_buffs, this->recv_sizes );
2491
    };
2492
2493
}; /** end class UnpackFarTask */
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
/**
2506
 *  Send my skeletons (in gids and params) to other ranks
2507
 *  using FarSentToRank[:].
2508
 *
2509
 *  Recv skeletons from other ranks
2510
 *  using FarRecvFromRank[:].
2511
 */
2512
template<typename TREE>
2513
void ExchangeLET( TREE &tree, string option )
2514
{
2515
  /** Derive type T from TREE. */
2516
  using T = typename TREE::T;
2517
  /** MPI Support. */
2518
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2519
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2520
2521
  /** Buffers for sizes and skeletons */
2522
  vector<vector<size_t>> sendsizes( comm_size );
2523
  vector<vector<size_t>> recvsizes( comm_size );
2524
  vector<vector<size_t>> sendskels( comm_size );
2525
  vector<vector<size_t>> recvskels( comm_size );
2526
  vector<vector<T>>      sendbuffs( comm_size );
2527
  vector<vector<T>>      recvbuffs( comm_size );
2528
2529
  /** Pack */
2530
  #pragma omp parallel for
2531
  for ( int p = 0; p < comm_size; p ++ )
2532
  {
2533
    if ( !option.compare( 0, 4, "leaf" ) )
2534
    {
2535
      PackNear( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] );
2536
    }
2537
    else if ( !option.compare( 0, 4, "skel" ) )
2538
    {
2539
      PackFar( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] );
2540
    }
2541
    else
2542
    {
2543
      printf( "ExchangeLET: option <%s> not available.\n", option.data() );
2544
      exit( 1 );
2545
    }
2546
  }
2547
2548
  /** Alltoallv */
2549
  mpi::AlltoallVector( sendsizes, recvsizes, tree.GetComm() );
2550
  if ( !option.compare( string( "skelgids" ) ) ||
2551
       !option.compare( string( "leafgids" ) ) )
2552
  {
2553
    auto &K = *tree.setup.K;
2554
    mpi::AlltoallVector( sendskels, recvskels, tree.GetComm() );
2555
    K.RequestIndices( recvskels );
2556
  }
2557
  else
2558
  {
2559
    double beg = omp_get_wtime();
2560
    mpi::AlltoallVector( sendbuffs, recvbuffs, tree.GetComm() );
2561
    double a2av_time = omp_get_wtime() - beg;
2562
    if ( comm_rank == 0 ) printf( "a2av_time %lfs\n", a2av_time );
2563
  }
2564
2565
2566
  /** Uppack */
2567
  #pragma omp parallel for
2568
  for ( int p = 0; p < comm_size; p ++ )
2569
  {
2570
    if ( !option.compare( 0, 4, "leaf" ) )
2571
    {
2572
      UnpackLeaf( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] );
2573
    }
2574
    else if ( !option.compare( 0, 4, "skel" ) )
2575
    {
2576
      UnpackFar( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] );
2577
    }
2578
    else
2579
    {
2580
      printf( "ExchangeLET: option <%s> not available.\n", option.data() );
2581
      exit( 1 );
2582
    }
2583
  }
2584
2585
2586
}; /** end ExchangeLET() */
2587
2588
2589
2590
template<typename T, typename TREE>
2591
void AsyncExchangeLET( TREE &tree, string option )
2592
{
2593
  /** MPI */
2594
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2595
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2596
2597
  /** Create sending tasks. */
2598
  for ( int p = 0; p < comm_size; p ++ )
2599
  {
2600
    if ( !option.compare( 0, 4, "leaf" ) )
2601
    {
2602
      auto *task = new PackNearTask<T, TREE>( &tree, comm_rank, p, 300 );
2603
      /** Set src, tar, and key (tags). */
2604
      //task->Set( &tree, comm_rank, p, 300 );
2605
      //task->Submit();
2606
      //task->DependencyAnalysis();
2607
    }
2608
    else if ( !option.compare( 0, 4, "skel" ) )
2609
    {
2610
      auto *task = new PackFarTask<T, TREE>( &tree, comm_rank, p, 306 );
2611
      /** Set src, tar, and key (tags). */
2612
      //task->Set( &tree, comm_rank, p, 306 );
2613
      //task->Submit();
2614
      //task->DependencyAnalysis();
2615
    }
2616
    else
2617
    {
2618
      printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() );
2619
      exit( 1 );
2620
    }
2621
  }
2622
2623
  /** Create receiving tasks */
2624
  for ( int p = 0; p < comm_size; p ++ )
2625
  {
2626
    if ( !option.compare( 0, 4, "leaf" ) )
2627
    {
2628
      auto *task = new UnpackLeafTask<T, TREE>( &tree, p, comm_rank, 300 );
2629
      /** Set src, tar, and key (tags). */
2630
      //task->Set( &tree, p, comm_rank, 300 );
2631
      //task->Submit();
2632
      //task->DependencyAnalysis();
2633
    }
2634
    else if ( !option.compare( 0, 4, "skel" ) )
2635
    {
2636
      auto *task = new UnpackFarTask<T, TREE>( &tree, p, comm_rank, 306 );
2637
      /** Set src, tar, and key (tags). */
2638
      //task->Set( &tree, p, comm_rank, 306 );
2639
      //task->Submit();
2640
      //task->DependencyAnalysis();
2641
    }
2642
    else
2643
    {
2644
      printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() );
2645
      exit( 1 );
2646
    }
2647
  }
2648
2649
}; /** AsyncExchangeLET() */
2650
2651
2652
2653
2654
template<typename T, typename TREE>
2655
void ExchangeNeighbors( TREE &tree )
2656
{
2657
  mpi::PrintProgress( "[BEG] ExchangeNeighbors ...", tree.GetComm() );
2658
2659
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
2660
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
2661
2662
  /** Alltoallv buffers */
2663
  vector<vector<size_t>> send_buff( comm_size );
2664
  vector<vector<size_t>> recv_buff( comm_size );
2665
2666
  /** NN<STAR, CIDS, pair<T, size_t>> */
2667
  unordered_set<size_t> requested_gids;
2668
  auto &NN = *tree.setup.NN;
2669
2670
  /** Remove duplication. */
2671
  for ( auto & it : NN )
2672
  {
2673
    if ( it.second >= 0 && it.second < tree.n )
2674
      requested_gids.insert( it.second );
2675
  }
2676
2677
  /** Remove owned gids. */
2678
  for ( auto it : tree.treelist[ 0 ]->gids ) requested_gids.erase( it );
2679
2680
  /** Assume gid is owned by (gid % size) */
2681
  for ( auto it :requested_gids )
2682
  {
2683
    int p = it % comm_size;
2684
    if ( p != comm_rank ) send_buff[ p ].push_back( it );
2685
  }
2686
2687
  /** Redistribute K. */
2688
  auto &K = *tree.setup.K;
2689
  K.RequestIndices( send_buff );
2690
2691
  mpi::PrintProgress( "[END] ExchangeNeighbors ...", tree.GetComm() );
2692
}; /** end ExchangeNeighbors() */
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
template<bool SYMMETRIC, typename NODE, typename T>
2705
void MergeFarNodes( NODE *node )
2706
{
2707
  /** if I don't have any skeleton, then I'm nobody's far field */
2708
  //if ( !node->data.isskel ) return;
2709
2710
  /**
2711
   *  Examine "Near" interaction list
2712
   */
2713
  //if ( node->isleaf )
2714
  //{
2715
  //   auto & NearMortonIDs = node->NNNearNodeMortonIDs;
2716
  //   #pragma omp critical
2717
  //   {
2718
  //     int rank;
2719
  //     mpi::Comm_rank( MPI_COMM_WORLD, &rank );
2720
  //     string outfile = to_string( rank );
2721
  //     FILE * pFile = fopen( outfile.data(), "a+" );
2722
  //     fprintf( pFile, "(%8lu) ", node->morton );
2723
  //     for ( auto it = NearMortonIDs.begin(); it != NearMortonIDs.end(); it ++ )
2724
  //       fprintf( pFile, "%8lu, ", (*it) );
2725
  //     fprintf( pFile, "\n" ); //fflush( stdout );
2726
  //   }
2727
2728
  //   //auto & NearNodes = node->NNNearNodes;
2729
  //   //for ( auto it = NearNodes.begin(); it != NearNodes.end(); it ++ )
2730
  //   //{
2731
  //   //  if ( !(*it)->NNNearNodes.count( node ) )
2732
  //   //  {
2733
  //   //    printf( "(%8lu) misses %lu\n", (*it)->morton, node->morton ); fflush( stdout );
2734
  //   //  }
2735
  //   //}
2736
  //};
2737
2738
2739
  /** Add my sibling (in the same level) to far interaction lists */
2740
  assert( !node->FarNodeMortonIDs.size() );
2741
  assert( !node->FarNodes.size() );
2742
  node->FarNodeMortonIDs.insert( node->sibling->morton );
2743
  node->FarNodes.insert( node->sibling );
2744
2745
  /** Construct NN far interaction lists */
2746
  if ( node->isleaf )
2747
  {
2748
    FindFarNodes( MortonHelper::Root(), node );
2749
  }
2750
  else
2751
  {
2752
    /** Merge Far( lchild ) and Far( rchild ) from children */
2753
    auto *lchild = node->lchild;
2754
    auto *rchild = node->rchild;
2755
2756
    /** case: NNPRUNE (FMM specific) */
2757
    auto &pNNFarNodes =   node->NNFarNodeMortonIDs;
2758
    auto &lNNFarNodes = lchild->NNFarNodeMortonIDs;
2759
    auto &rNNFarNodes = rchild->NNFarNodeMortonIDs;
2760
2761
    /** Far( parent ) = Far( lchild ) intersects Far( rchild ) */
2762
    for ( auto it  = lNNFarNodes.begin();
2763
               it != lNNFarNodes.end(); it ++ )
2764
    {
2765
      if ( rNNFarNodes.count( *it ) )
2766
      {
2767
        pNNFarNodes.insert( *it );
2768
      }
2769
    }
2770
    /** Far( lchild ) \= Far( parent ); Far( rchild ) \= Far( parent ) */
2771
    for ( auto it  = pNNFarNodes.begin();
2772
               it != pNNFarNodes.end(); it ++ )
2773
    {
2774
      lNNFarNodes.erase( *it );
2775
      rNNFarNodes.erase( *it );
2776
    }
2777
  }
2778
2779
}; /** end MergeFarNodes() */
2780
2781
2782
2783
template<bool SYMMETRIC, typename NODE, typename T>
2784
class MergeFarNodesTask : public Task
2785
{
2786
  public:
2787
2788
    NODE *arg;
2789
2790
    void Set( NODE *user_arg )
2791
    {
2792
      arg = user_arg;
2793
      name = string( "merge" );
2794
      label = to_string( arg->treelist_id );
2795
      /** we don't know the exact cost here */
2796
      cost = 5.0;
2797
      /** high priority */
2798
      priority = true;
2799
    };
2800
2801
		/** read this node and write to children */
2802
    void DependencyAnalysis()
2803
    {
2804
      arg->DependencyAnalysis( RW, this );
2805
      if ( !arg->isleaf )
2806
      {
2807
        arg->lchild->DependencyAnalysis( RW, this );
2808
        arg->rchild->DependencyAnalysis( RW, this );
2809
      }
2810
      this->TryEnqueue();
2811
    };
2812
2813
    void Execute( Worker* user_worker )
2814
    {
2815
      MergeFarNodes<SYMMETRIC, NODE, T>( arg );
2816
    };
2817
2818
}; /** end class MergeFarNodesTask */
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
template<bool SYMMETRIC, typename NODE, typename T>
2832
void DistMergeFarNodes( NODE *node )
2833
{
2834
  /** MPI */
2835
  mpi::Status status;
2836
  mpi::Comm comm = node->GetComm();
2837
  int comm_size = node->GetCommSize();
2838
  int comm_rank = node->GetCommRank();
2839
2840
  /** if I don't have any skeleton, then I'm nobody's far field */
2841
  //if ( !node->data.isskel ) return;
2842
2843
2844
  /** Early return if this is the root node. */
2845
  if ( !node->parent ) return;
2846
2847
  /** Distributed treenode */
2848
  if ( node->GetCommSize() < 2 )
2849
  {
2850
    MergeFarNodes<SYMMETRIC, NODE, T>( node );
2851
  }
2852
  else
2853
  {
2854
    /** merge Far( lchild ) and Far( rchild ) from children */
2855
    auto *child = node->child;
2856
2857
    if ( comm_rank == 0 )
2858
    {
2859
      auto &pNNFarNodes =  node->NNFarNodeMortonIDs;
2860
      auto &lNNFarNodes = child->NNFarNodeMortonIDs;
2861
      vector<size_t> recvFarNodes;
2862
2863
      /** Recv rNNFarNodes */
2864
      mpi::RecvVector( recvFarNodes, comm_size / 2, 0, comm, &status );
2865
2866
      /** Far( parent ) = Far( lchild ) intersects Far( rchild ). */
2867
      for ( auto it : recvFarNodes )
2868
      {
2869
        if ( lNNFarNodes.count( it ) )
2870
        {
2871
          pNNFarNodes.insert( it );
2872
        }
2873
      }
2874
2875
      /** Reuse space to send pNNFarNodes. */
2876
      recvFarNodes.clear();
2877
      recvFarNodes.reserve( pNNFarNodes.size() );
2878
2879
      /** Far( lchild ) \= Far( parent ); Far( rchild ) \= Far( parent ). */
2880
      for ( auto it : pNNFarNodes )
2881
      {
2882
        lNNFarNodes.erase( it );
2883
        recvFarNodes.push_back( it );
2884
      }
2885
2886
      /** Send pNNFarNodes. */
2887
      mpi::SendVector( recvFarNodes, comm_size / 2, 0, comm );
2888
    }
2889
2890
2891
    if ( comm_rank == comm_size / 2 )
2892
    {
2893
      auto &rNNFarNodes = child->NNFarNodeMortonIDs;
2894
      vector<size_t> sendFarNodes( rNNFarNodes.begin(), rNNFarNodes.end() );
2895
2896
      /** Send rNNFarNodes. */
2897
      mpi::SendVector( sendFarNodes, 0, 0, comm );
2898
      /** Recuse sendFarNodes to receive pNNFarNodes. */
2899
      mpi::RecvVector( sendFarNodes, 0, 0, comm, &status );
2900
      /** Far( lchild ) \= Far( parent ); Far( rchild ) \= Far( parent ) */
2901
      for ( auto it : sendFarNodes ) rNNFarNodes.erase( it );
2902
    }
2903
  }
2904
2905
}; /** end DistMergeFarNodes() */
2906
2907
2908
2909
template<bool SYMMETRIC, typename NODE, typename T>
2910
class DistMergeFarNodesTask : public Task
2911
{
2912
  public:
2913
2914
    NODE *arg = NULL;
2915
2916
    void Set( NODE *user_arg )
2917
    {
2918
      arg = user_arg;
2919
      name = string( "dist-merge" );
2920
      label = to_string( arg->treelist_id );
2921
      /** we don't know the exact cost here */
2922
      cost = 5.0;
2923
      /** high priority */
2924
      priority = true;
2925
    };
2926
2927
		/** read this node and write to children */
2928
    void DependencyAnalysis()
2929
    {
2930
      arg->DependencyAnalysis( RW, this );
2931
      if ( !arg->isleaf )
2932
      {
2933
        if ( arg->GetCommSize() > 1 )
2934
        {
2935
          arg->child->DependencyAnalysis( RW, this );
2936
        }
2937
        else
2938
        {
2939
          arg->lchild->DependencyAnalysis( RW, this );
2940
          arg->rchild->DependencyAnalysis( RW, this );
2941
        }
2942
      }
2943
      this->TryEnqueue();
2944
    };
2945
2946
    void Execute( Worker* user_worker )
2947
    {
2948
      DistMergeFarNodes<SYMMETRIC, NODE, T>( arg );
2949
    };
2950
2951
}; /** end class DistMergeFarNodesTask */
2952
2953
2954
2955
2956
2957
2958
/**
2959
 *
2960
 */
2961
template<bool NNPRUNE, typename NODE>
2962
class CacheFarNodesTask : public Task
2963
{
2964
  public:
2965
2966
    NODE *arg = NULL;
2967
2968
    void Set( NODE *user_arg )
2969
    {
2970
      arg = user_arg;
2971
      name = string( "FKIJ" );
2972
      label = to_string( arg->treelist_id );
2973
      /** Compute FLOPS and MOPS. */
2974
      double flops = 0, mops = 0;
2975
      /** We don't know the exact cost here. */
2976
      cost = 5.0;
2977
    };
2978
2979
    void DependencyAnalysis()
2980
    {
2981
      arg->DependencyAnalysis( RW, this );
2982
      this->TryEnqueue();
2983
    };
2984
2985
    void Execute( Worker* user_worker )
2986
    {
2987
      auto *node = arg;
2988
      auto &K = *node->setup->K;
2989
2990
      for ( int p = 0; p < node->DistFar.size(); p ++ )
2991
      {
2992
        for ( auto &it : node->DistFar[ p ] )
2993
        {
2994
          auto *src = (*node->morton2node)[ it.first ];
2995
          auto &I = node->data.skels;
2996
          auto &J = src->data.skels;
2997
          it.second = K( I, J );
2998
          //printf( "Cache I %lu J %lu\n", I.size(), J.size() ); fflush( stdout );
2999
        }
3000
      }
3001
    };
3002
3003
}; /** end class CacheFarNodesTask */
3004
3005
3006
3007
3008
3009
3010
/**
3011
 *
3012
 */
3013
template<bool NNPRUNE, typename NODE>
3014
class CacheNearNodesTask : public Task
3015
{
3016
  public:
3017
3018
    NODE *arg = NULL;
3019
3020
    void Set( NODE *user_arg )
3021
    {
3022
      arg = user_arg;
3023
      name = string( "NKIJ" );
3024
      label = to_string( arg->treelist_id );
3025
      /** We don't know the exact cost here */
3026
      cost = 5.0;
3027
    };
3028
3029
    void DependencyAnalysis()
3030
    {
3031
      arg->DependencyAnalysis( RW, this );
3032
      this->TryEnqueue();
3033
    };
3034
3035
    void Execute( Worker* user_worker )
3036
    {
3037
      auto *node = arg;
3038
      auto &K = *node->setup->K;
3039
3040
      for ( int p = 0; p < node->DistNear.size(); p ++ )
3041
      {
3042
        for ( auto &it : node->DistNear[ p ] )
3043
        {
3044
          auto *src = (*node->morton2node)[ it.first ];
3045
          auto &I = node->gids;
3046
          auto &J = src->gids;
3047
          it.second = K( I, J );
3048
          //printf( "Cache I %lu J %lu\n", I.size(), J.size() ); fflush( stdout );
3049
        }
3050
      }
3051
    };
3052
3053
}; /** end class CacheNearNodesTask */
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
template<typename NODE, typename T>
3067
void DistRowSamples( NODE *node, size_t nsamples )
3068
{
3069
  /** MPI */
3070
  mpi::Comm comm = node->GetComm();
3071
  int size = node->GetCommSize();
3072
  int rank = node->GetCommRank();
3073
3074
  /** gather shared data and create reference */
3075
  auto &K = *node->setup->K;
3076
3077
  /** amap contains nsamples of row gids of K */
3078
  vector<size_t> &I = node->data.candidate_rows;
3079
3080
  /** Clean up candidates from previous iteration */
3081
	I.clear();
3082
3083
  /** Fill-on snids first */
3084
	if ( rank == 0 )
3085
  {
3086
    /** reserve space */
3087
    I.reserve( nsamples );
3088
3089
    auto &snids = node->data.snids;
3090
    multimap<T, size_t> ordered_snids = gofmm::flip_map( snids );
3091
3092
    for ( auto it  = ordered_snids.begin();
3093
               it != ordered_snids.end(); it++ )
3094
    {
3095
      /** (*it) has type pair<T, size_t> */
3096
      I.push_back( (*it).second );
3097
      if ( I.size() >= nsamples ) break;
3098
    }
3099
  }
3100
3101
	/** buffer space */
3102
	vector<size_t> candidates( nsamples );
3103
3104
	size_t n_required = nsamples - I.size();
3105
3106
	/** bcast the termination criteria */
3107
	mpi::Bcast( &n_required, 1, 0, comm );
3108
3109
	while ( n_required )
3110
	{
3111
		if ( rank == 0 )
3112
		{
3113
  	  for ( size_t i = 0; i < nsamples; i ++ )
3114
      {
3115
        auto important_sample = K.ImportantSample( 0 );
3116
        candidates[ i ] =  important_sample.second;
3117
      }
3118
		}
3119
3120
		/** Bcast candidates */
3121
		mpi::Bcast( candidates.data(), candidates.size(), 0, comm );
3122
3123
		/** validation */
3124
		vector<size_t> vconsensus( nsamples, 0 );
3125
	  vector<size_t> validation = node->setup->ContainAny( candidates, node->morton );
3126
3127
		/** reduce validation */
3128
		mpi::Reduce( validation.data(), vconsensus.data(), nsamples, MPI_SUM, 0, comm );
3129
3130
	  if ( rank == 0 )
3131
		{
3132
  	  for ( size_t i = 0; i < nsamples; i ++ )
3133
			{
3134
				/** exit is there is enough samples */
3135
				if ( I.size() >= nsamples )
3136
				{
3137
					I.resize( nsamples );
3138
					break;
3139
				}
3140
				/** Push the candidate to I after validation */
3141
				if ( !vconsensus[ i ] )
3142
				{
3143
					if ( find( I.begin(), I.end(), candidates[ i ] ) == I.end() )
3144
						I.push_back( candidates[ i ] );
3145
				}
3146
			};
3147
3148
			/** Update n_required */
3149
	    n_required = nsamples - I.size();
3150
		}
3151
3152
	  /** Bcast the termination criteria */
3153
	  mpi::Bcast( &n_required, 1, 0, comm );
3154
	}
3155
3156
}; /** end DistRowSamples() */
3157
3158
3159
3160
3161
3162
3163
/** @brief Involve MPI routins */
3164
template<bool NNPRUNE, typename NODE>
3165
void DistSkeletonKIJ( NODE *node )
3166
{
3167
  /** Derive type T from NODE. */
3168
  using T = typename NODE::T;
3169
	/** Early return */
3170
	if ( !node->parent ) return;
3171
  /** Gather shared data and create reference. */
3172
  auto &K = *(node->setup->K);
3173
  /** Gather per node data and create reference. */
3174
  auto &data = node->data;
3175
  auto &candidate_rows = data.candidate_rows;
3176
  auto &candidate_cols = data.candidate_cols;
3177
  auto &KIJ = data.KIJ;
3178
3179
  /** MPI Support. */
3180
  auto comm = node->GetComm();
3181
  auto size = node->GetCommSize();
3182
  auto rank = node->GetCommRank();
3183
  mpi::Status status;
3184
3185
  if ( size < 2 )
3186
  {
3187
    /** This node is the root of the local tree. */
3188
    gofmm::SkeletonKIJ<NNPRUNE>( node );
3189
  }
3190
  else
3191
  {
3192
    /** This node (mpitree::Node) belongs to the distributed tree
3193
     *  only executed by 0th and size/2 th rank of
3194
     *  the node communicator. At this moment, children have been
3195
     *  skeletonized. Thus, we should first update (isskel) to
3196
     *  all MPI processes. Then we gather information for the
3197
     *  skeletonization.
3198
     */
3199
    NODE *child = node->child;
3200
		size_t nsamples = 0;
3201
3202
    /** Bcast (isskel) to all MPI processes using children's communicator. */
3203
    int child_isskel = child->data.isskel;
3204
    mpi::Bcast( &child_isskel, 1, 0, child->GetComm() );
3205
    child->data.isskel = child_isskel;
3206
3207
3208
    /** rank-0 owns data of this node, and it also owns the left child. */
3209
    if ( rank == 0 )
3210
    {
3211
      candidate_cols = child->data.skels;
3212
      vector<size_t> rskel;
3213
      /** Receive rskel from my sibling. */
3214
      mpi::RecvVector( rskel, size / 2, 10, comm, &status );
3215
      /** Correspondingly, we need to redistribute the matrix K. */
3216
      K.RecvIndices( size / 2, comm, &status );
3217
      /** Concatinate [ lskels, rskels ]. */
3218
      candidate_cols.insert( candidate_cols.end(), rskel.begin(), rskel.end() );
3219
			/** Use two times of skeletons */
3220
      nsamples = 2 * candidate_cols.size();
3221
      /** Make sure we at least m samples */
3222
      if ( nsamples < 2 * node->setup->LeafNodeSize() )
3223
        nsamples = 2 * node->setup->LeafNodeSize();
3224
3225
      /** Gather rsnids. */
3226
      auto &lsnids = node->child->data.snids;
3227
      vector<T>      recv_rsdist;
3228
      vector<size_t> recv_rsnids;
3229
3230
      /** Receive rsnids from size / 2. */
3231
      mpi::RecvVector( recv_rsdist, size / 2, 20, comm, &status );
3232
      mpi::RecvVector( recv_rsnids, size / 2, 30, comm, &status );
3233
      /** Correspondingly, we need to redistribute the matrix K. */
3234
      K.RecvIndices( size / 2, comm, &status );
3235
3236
3237
      /** Merge snids and update the smallest distance. */
3238
      auto &snids = node->data.snids;
3239
      snids = lsnids;
3240
3241
      for ( size_t i = 0; i < recv_rsdist.size(); i ++ )
3242
      {
3243
        pair<size_t, T> query( recv_rsnids[ i ], recv_rsdist[ i ] );
3244
        auto ret = snids.insert( query );
3245
        if ( !ret.second )
3246
        {
3247
          if ( ret.first->second > recv_rsdist[ i ] )
3248
            ret.first->second = recv_rsdist[ i ];
3249
        }
3250
      }
3251
3252
      /** Remove gids from snids */
3253
      for ( auto gid : node->gids ) snids.erase( gid );
3254
    }
3255
3256
    if ( rank == size / 2 )
3257
    {
3258
      /** Send rskel to rank 0. */
3259
      mpi::SendVector( child->data.skels, 0, 10, comm );
3260
      /** Correspondingly, we need to redistribute the matrix K. */
3261
      K.SendIndices( child->data.skels, 0, comm );
3262
3263
      /** Gather rsnids */
3264
      auto &rsnids = node->child->data.snids;
3265
      vector<T>      send_rsdist;
3266
      vector<size_t> send_rsnids;
3267
3268
      /** reserve space and push in from map */
3269
      send_rsdist.reserve( rsnids.size() );
3270
      send_rsnids.reserve( rsnids.size() );
3271
3272
      for ( auto it = rsnids.begin(); it != rsnids.end(); it ++ )
3273
      {
3274
        /** (*it) has type std::pair<size_t, T>  */
3275
        send_rsnids.push_back( (*it).first  );
3276
        send_rsdist.push_back( (*it).second );
3277
      }
3278
3279
      /** send rsnids to rank-0 */
3280
      mpi::SendVector( send_rsdist, 0, 20, comm );
3281
      mpi::SendVector( send_rsnids, 0, 30, comm );
3282
3283
      /** Correspondingly, we need to redistribute the matrix K. */
3284
      K.SendIndices( send_rsnids, 0, comm );
3285
    }
3286
3287
		/** Bcast nsamples. */
3288
		mpi::Bcast( &nsamples, 1, 0, comm );
3289
		/** Distributed row samples. */
3290
		DistRowSamples<NODE, T>( node, nsamples );
3291
    /** only rank-0 has non-empty I and J sets */
3292
    if ( rank != 0 )
3293
    {
3294
      assert( !candidate_rows.size() );
3295
      assert( !candidate_cols.size() );
3296
    }
3297
    /**
3298
     *  Now rank-0 has the correct ( I, J ). All other ranks in the
3299
     *  communicator must flush their I and J sets before evaluation.
3300
     *  all MPI process must participate in operator ()
3301
     */
3302
    KIJ = K( candidate_rows, candidate_cols );
3303
  }
3304
}; /** end DistSkeletonKIJ() */
3305
3306
3307
/**
3308
 *
3309
 */
3310
template<bool NNPRUNE, typename NODE, typename T>
3311
class DistSkeletonKIJTask : public Task
3312
{
3313
  public:
3314
3315
    NODE *arg = NULL;
3316
3317
    void Set( NODE *user_arg )
3318
    {
3319
      arg = user_arg;
3320
      name = string( "par-gskm" );
3321
      label = to_string( arg->treelist_id );
3322
      /** We don't know the exact cost here */
3323
      cost = 5.0;
3324
      /** "High" priority */
3325
      priority = true;
3326
    };
3327
3328
    void DependencyAnalysis() { arg->DependOnChildren( this ); };
3329
3330
    void Execute( Worker* user_worker ) { DistSkeletonKIJ<NNPRUNE>( arg ); };
3331
3332
}; /** end class DistSkeletonKIJTask */
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
/**
3346
 *  @brief Skeletonization with interpolative decomposition.
3347
 */
3348
template<typename NODE, typename T>
3349
void DistSkeletonize( NODE *node )
3350
{
3351
  /** early return if we do not need to skeletonize */
3352
  if ( !node->parent ) return;
3353
3354
  /** gather shared data and create reference */
3355
  auto &K   = *(node->setup->K);
3356
  auto &NN  = *(node->setup->NN);
3357
  auto maxs = node->setup->MaximumRank();
3358
  auto stol = node->setup->Tolerance();
3359
  bool secure_accuracy = node->setup->SecureAccuracy();
3360
  bool use_adaptive_ranks = node->setup->UseAdaptiveRanks();
3361
3362
  /** gather per node data and create reference */
3363
  auto &data  = node->data;
3364
  auto &skels = data.skels;
3365
  auto &proj  = data.proj;
3366
  auto &jpvt  = data.jpvt;
3367
  auto &KIJ   = data.KIJ;
3368
  auto &candidate_cols = data.candidate_cols;
3369
3370
  /** interpolative decomposition */
3371
  size_t N = K.col();
3372
  size_t m = KIJ.row();
3373
  size_t n = KIJ.col();
3374
  size_t q = node->n;
3375
3376
  if ( secure_accuracy )
3377
  {
3378
    /** TODO: need to check of both children's isskel to preceed */
3379
  }
3380
3381
3382
  /** Bill's l2 norm scaling factor */
3383
  T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol;
3384
3385
  /** account for uniform sampling */
3386
  scaled_stol *= std::sqrt( (T)q / N );
3387
3388
  lowrank::id
3389
  (
3390
    use_adaptive_ranks, secure_accuracy,
3391
    KIJ.row(), KIJ.col(), maxs, scaled_stol,
3392
    KIJ, skels, proj, jpvt
3393
  );
3394
3395
  /** Free KIJ for spaces */
3396
  KIJ.resize( 0, 0 );
3397
  KIJ.shrink_to_fit();
3398
3399
  /** depending on the flag, decide isskel or not */
3400
  if ( secure_accuracy )
3401
  {
3402
    /** TODO: this needs to be bcast to other nodes */
3403
    data.isskel = (skels.size() != 0);
3404
  }
3405
  else
3406
  {
3407
    assert( skels.size() );
3408
    assert( proj.size() );
3409
    assert( jpvt.size() );
3410
    data.isskel = true;
3411
  }
3412
3413
  /** Relabel skeletions with the real gids */
3414
  for ( size_t i = 0; i < skels.size(); i ++ )
3415
  {
3416
    skels[ i ] = candidate_cols[ skels[ i ] ];
3417
  }
3418
3419
3420
}; /** end DistSkeletonize() */
3421
3422
3423
3424
3425
template<typename NODE, typename T>
3426
class SkeletonizeTask : public hmlp::Task
3427
{
3428
  public:
3429
3430
    NODE *arg;
3431
3432
    void Set( NODE *user_arg )
3433
    {
3434
      arg = user_arg;
3435
      name = string( "SK" );
3436
      label = to_string( arg->treelist_id );
3437
      /** We don't know the exact cost here */
3438
      cost = 5.0;
3439
      /** "High" priority */
3440
      priority = true;
3441
    };
3442
3443
    void GetEventRecord()
3444
    {
3445
      double flops = 0.0, mops = 0.0;
3446
3447
      auto &K = *arg->setup->K;
3448
      size_t n = arg->data.proj.col();
3449
      size_t m = 2 * n;
3450
      size_t k = arg->data.proj.row();
3451
3452
      /** GEQP3 */
3453
      flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
3454
      mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
3455
3456
      /* TRSM */
3457
      flops += k * ( k - 1 ) * ( n + 1 );
3458
      mops  += 2.0 * ( k * k + k * n );
3459
3460
      event.Set( label + name, flops, mops );
3461
      arg->data.skeletonize = event;
3462
    };
3463
3464
    void DependencyAnalysis()
3465
    {
3466
      arg->DependencyAnalysis( RW, this );
3467
      this->TryEnqueue();
3468
    };
3469
3470
    void Execute( Worker* user_worker )
3471
    {
3472
      //printf( "%d Par-Skel beg\n", global_rank );
3473
3474
      DistSkeletonize<NODE, T>( arg );
3475
3476
      //printf( "%d Par-Skel end\n", global_rank );
3477
    };
3478
3479
}; /** end class SkeletonTask */
3480
3481
3482
3483
3484
/**
3485
 *
3486
 */
3487
template<typename NODE, typename T>
3488
class DistSkeletonizeTask : public hmlp::Task
3489
{
3490
  public:
3491
3492
    NODE *arg;
3493
3494
    void Set( NODE *user_arg )
3495
    {
3496
      arg = user_arg;
3497
      name = string( "PSK" );
3498
      label = to_string( arg->treelist_id );
3499
3500
      /** We don't know the exact cost here */
3501
      cost = 5.0;
3502
      /** "High" priority */
3503
      priority = true;
3504
    };
3505
3506
    void GetEventRecord()
3507
    {
3508
      double flops = 0.0, mops = 0.0;
3509
3510
      auto &K = *arg->setup->K;
3511
      size_t n = arg->data.proj.col();
3512
      size_t m = 2 * n;
3513
      size_t k = arg->data.proj.row();
3514
3515
			if ( arg->GetCommRank() == 0 )
3516
			{
3517
        /** GEQP3 */
3518
        flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
3519
        mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
3520
3521
        /* TRSM */
3522
        flops += k * ( k - 1 ) * ( n + 1 );
3523
        mops  += 2.0 * ( k * k + k * n );
3524
			}
3525
3526
      event.Set( label + name, flops, mops );
3527
      arg->data.skeletonize = event;
3528
    };
3529
3530
    void DependencyAnalysis()
3531
    {
3532
      arg->DependencyAnalysis( RW, this );
3533
      this->TryEnqueue();
3534
    };
3535
3536
    void Execute( Worker* user_worker )
3537
    {
3538
      mpi::Comm comm = arg->GetComm();
3539
3540
      double beg = omp_get_wtime();
3541
      if ( arg->GetCommRank() == 0 )
3542
      {
3543
        DistSkeletonize<NODE, T>( arg );
3544
      }
3545
      double skel_t = omp_get_wtime() - beg;
3546
3547
			/** Bcast isskel to every MPI processes in the same comm */
3548
			int isskel = arg->data.isskel;
3549
			mpi::Bcast( &isskel, 1, 0, comm );
3550
			arg->data.isskel = isskel;
3551
3552
      /** Bcast skels and proj to every MPI processes in the same comm */
3553
      auto &skels = arg->data.skels;
3554
      size_t nskels = skels.size();
3555
      mpi::Bcast( &nskels, 1, 0, comm );
3556
      if ( skels.size() != nskels ) skels.resize( nskels );
3557
      mpi::Bcast( skels.data(), skels.size(), 0, comm );
3558
3559
    };
3560
3561
}; /** end class DistSkeletonTask */
3562
3563
3564
3565
3566
/**
3567
 *  @brief
3568
 */
3569
template<typename NODE>
3570
class InterpolateTask : public Task
3571
{
3572
  public:
3573
3574
    NODE *arg = NULL;
3575
3576
    void Set( NODE *user_arg )
3577
    {
3578
      arg = user_arg;
3579
      name = string( "PROJ" );
3580
      label = to_string( arg->treelist_id );
3581
      // Need an accurate cost model.
3582
      cost = 1.0;
3583
    };
3584
3585
    void DependencyAnalysis() { arg->DependOnNoOne( this ); };
3586
3587
    void Execute( Worker* user_worker )
3588
    {
3589
      /** MPI Support. */
3590
      auto comm = arg->GetComm();
3591
      /** Only executed by rank 0. */
3592
      if ( arg->GetCommRank() == 0 ) gofmm::Interpolate( arg );
3593
3594
      auto &proj  = arg->data.proj;
3595
      size_t nrow  = proj.row();
3596
      size_t ncol  = proj.col();
3597
      mpi::Bcast( &nrow, 1, 0, comm );
3598
      mpi::Bcast( &ncol, 1, 0, comm );
3599
      if ( proj.row() != nrow || proj.col() != ncol ) proj.resize( nrow, ncol );
3600
      mpi::Bcast( proj.data(), proj.size(), 0, comm );
3601
    };
3602
3603
}; /** end class InterpolateTask */
3604
3605
3606
3607
3608
3609
3610
3611
3612
3613
3614
3615
3616
3617
3618
3619
3620
3621
3622
3623
3624
3625
3626
3627
3628
3629
3630
3631
3632
3633
3634
/**
3635
 *  @brief ComputeAll
3636
 */
3637
template<bool NNPRUNE = true, typename TREE, typename T>
3638
DistData<RIDS, STAR, T> Evaluate( TREE &tree, DistData<RIDS, STAR, T> &weights )
3639
{
3640
  try
3641
  {
3642
    /** MPI Support. */
3643
    int size; mpi::Comm_size( tree.GetComm(), &size );
3644
    int rank; mpi::Comm_rank( tree.GetComm(), &rank );
3645
    /** Derive type NODE and MPINODE from TREE. */
3646
    using NODE    = typename TREE::NODE;
3647
    using MPINODE = typename TREE::MPINODE;
3648
3649
    /** All timers */
3650
    double beg, time_ratio, evaluation_time = 0.0;
3651
    double direct_evaluation_time = 0.0, computeall_time, telescope_time, let_exchange_time, async_time;
3652
    double overhead_time;
3653
    double forward_permute_time, backward_permute_time;
3654
3655
    /** Clean up all r/w dependencies left on tree nodes. */
3656
    tree.DependencyCleanUp();
3657
3658
    /** n-by-nrhs, initialize potentials. */
3659
    size_t n    = weights.row();
3660
    size_t nrhs = weights.col();
3661
3662
    /** Potentials must be in [RIDS,STAR] distribution */
3663
    auto &gids_owned = tree.treelist[ 0 ]->gids;
3664
    DistData<RIDS, STAR, T> potentials( n, nrhs, gids_owned, tree.GetComm() );
3665
    potentials.setvalue( 0.0 );
3666
3667
    /** Provide pointers. */
3668
    tree.setup.w = &weights;
3669
    tree.setup.u = &potentials;
3670
3671
    /** TreeView (downward traversal) */
3672
    gofmm::TreeViewTask<NODE>           seqVIEWtask;
3673
    mpigofmm::DistTreeViewTask<MPINODE> mpiVIEWtask;
3674
    /** Telescope (upward traversal) */
3675
    gofmm::UpdateWeightsTask<NODE, T>           seqN2Stask;
3676
    mpigofmm::DistUpdateWeightsTask<MPINODE, T> mpiN2Stask;
3677
    /** L2L (sum of direct evaluations) */
3678
    //mpigofmm::DistLeavesToLeavesTask<NNPRUNE, NODE, T> seqL2Ltask;
3679
    //mpigofmm::L2LReduceTask<NODE, T> seqL2LReducetask;
3680
    mpigofmm::L2LReduceTask2<NODE, T> seqL2LReducetask2;
3681
    /** S2S (sum of low-rank approximation) */
3682
    //gofmm::SkeletonsToSkeletonsTask<NNPRUNE, NODE, T>           seqS2Stask;
3683
    //mpigofmm::DistSkeletonsToSkeletonsTask<NNPRUNE, MPINODE, T> mpiS2Stask;
3684
    //mpigofmm::S2SReduceTask<NODE, T>    seqS2SReducetask;
3685
    //mpigofmm::S2SReduceTask<MPINODE, T> mpiS2SReducetask;
3686
    mpigofmm::S2SReduceTask2<NODE, NODE, T>    seqS2SReducetask2;
3687
    mpigofmm::S2SReduceTask2<MPINODE, NODE, T> mpiS2SReducetask2;
3688
    /** Telescope (downward traversal) */
3689
    gofmm::SkeletonsToNodesTask<NNPRUNE, NODE, T>           seqS2Ntask;
3690
    mpigofmm::DistSkeletonsToNodesTask<NNPRUNE, MPINODE, T> mpiS2Ntask;
3691
3692
      /** Global barrier and timer */
3693
      mpi::Barrier( tree.GetComm() );
3694
3695
      //{
3696
      //  /** Stage 1: TreeView and upward telescoping */
3697
      //  beg = omp_get_wtime();
3698
      //  tree.DependencyCleanUp();
3699
      //  tree.DistTraverseDown( mpiVIEWtask );
3700
      //  tree.LocaTraverseDown( seqVIEWtask );
3701
      //  tree.LocaTraverseUp( seqN2Stask );
3702
      //  tree.DistTraverseUp( mpiN2Stask );
3703
      //  hmlp_run();
3704
      //  mpi::Barrier( tree.GetComm() );
3705
      //  telescope_time = omp_get_wtime() - beg;
3706
3707
      //  /** Stage 2: LET exchange */
3708
      //  beg = omp_get_wtime();
3709
      //  ExchangeLET<T>( tree, string( "skelweights" ) );
3710
      //  mpi::Barrier( tree.GetComm() );
3711
      //  ExchangeLET<T>( tree, string( "leafweights" ) );
3712
      //  mpi::Barrier( tree.GetComm() );
3713
      //  let_exchange_time = omp_get_wtime() - beg;
3714
3715
      //  /** Stage 3: L2L */
3716
      //  beg = omp_get_wtime();
3717
      //  tree.DependencyCleanUp();
3718
      //  tree.LocaTraverseLeafs( seqL2LReducetask2 );
3719
      //  hmlp_run();
3720
      //  mpi::Barrier( tree.GetComm() );
3721
      //  direct_evaluation_time = omp_get_wtime() - beg;
3722
3723
      //  /** Stage 4: S2S and downward telescoping */
3724
      //  beg = omp_get_wtime();
3725
      //  tree.DependencyCleanUp();
3726
      //  tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
3727
      //  tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
3728
      //  tree.DistTraverseDown( mpiS2Ntask );
3729
      //  tree.LocaTraverseDown( seqS2Ntask );
3730
      //  hmlp_run();
3731
      //  mpi::Barrier( tree.GetComm() );
3732
      //  computeall_time = omp_get_wtime() - beg;
3733
      //}
3734
3735
3736
    /** Global barrier and timer */
3737
    potentials.setvalue( 0.0 );
3738
    mpi::Barrier( tree.GetComm() );
3739
3740
    /** Stage 1: TreeView and upward telescoping */
3741
    beg = omp_get_wtime();
3742
    tree.DependencyCleanUp();
3743
    tree.DistTraverseDown( mpiVIEWtask );
3744
    tree.LocaTraverseDown( seqVIEWtask );
3745
    tree.ExecuteAllTasks();
3746
    /** Stage 2: redistribute weights from IDS to LET. */
3747
    AsyncExchangeLET<T>( tree, string( "leafweights" ) );
3748
    /** Stage 3: N2S. */
3749
    tree.LocaTraverseUp( seqN2Stask );
3750
    tree.DistTraverseUp( mpiN2Stask );
3751
    /** Stage 4: redistribute skeleton weights from IDS to LET. */
3752
    AsyncExchangeLET<T>( tree, string( "skelweights" ) );
3753
    /** Stage 5: L2L */
3754
    tree.LocaTraverseLeafs( seqL2LReducetask2 );
3755
    /** Stage 6: S2S */
3756
    tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
3757
    tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
3758
    /** Stage 7: S2N */
3759
    tree.DistTraverseDown( mpiS2Ntask );
3760
    tree.LocaTraverseDown( seqS2Ntask );
3761
    overhead_time = omp_get_wtime() - beg;
3762
    tree.ExecuteAllTasks();
3763
    async_time = omp_get_wtime() - beg;
3764
3765
3766
3767
    /** Compute the breakdown cost */
3768
    evaluation_time += direct_evaluation_time;
3769
    evaluation_time += telescope_time;
3770
    evaluation_time += let_exchange_time;
3771
    evaluation_time += computeall_time;
3772
    time_ratio = 100 / evaluation_time;
3773
3774
    if ( rank == 0 && REPORT_EVALUATE_STATUS )
3775
    {
3776
      printf( "========================================================\n");
3777
      printf( "GOFMM evaluation phase\n" );
3778
      printf( "========================================================\n");
3779
      //printf( "Allocate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3780
      //    allocate_time, allocate_time * time_ratio );
3781
      //printf( "Forward permute ----------------------- %5.2lfs (%5.1lf%%)\n",
3782
      //    forward_permute_time, forward_permute_time * time_ratio );
3783
      printf( "Upward telescope ---------------------- %5.2lfs (%5.1lf%%)\n",
3784
          telescope_time, telescope_time * time_ratio );
3785
      printf( "LET exchange -------------------------- %5.2lfs (%5.1lf%%)\n",
3786
          let_exchange_time, let_exchange_time * time_ratio );
3787
      printf( "L2L ----------------------------------- %5.2lfs (%5.1lf%%)\n",
3788
          direct_evaluation_time, direct_evaluation_time * time_ratio );
3789
      printf( "S2S, S2N ------------------------------ %5.2lfs (%5.1lf%%)\n",
3790
          computeall_time, computeall_time * time_ratio );
3791
      //printf( "Backward permute ---------------------- %5.2lfs (%5.1lf%%)\n",
3792
      //    backward_permute_time, backward_permute_time * time_ratio );
3793
      printf( "========================================================\n");
3794
      printf( "Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3795
          evaluation_time, evaluation_time * time_ratio );
3796
      printf( "Evaluate (Async) ---------------------- %5.2lfs (%5.2lfs)\n",
3797
          async_time, overhead_time );
3798
      printf( "========================================================\n\n");
3799
    }
3800
3801
    return potentials;
3802
  }
3803
  catch ( const exception & e )
3804
  {
3805
    cout << e.what() << endl;
3806
    exit( 1 );
3807
  }
3808
}; /** end Evaluate() */
3809
3810
3811
3812
3813
template<bool NNPRUNE = true, typename TREE, typename T>
3814
DistData<RBLK, STAR, T> Evaluate( TREE &tree, DistData<RBLK, STAR, T> &w_rblk )
3815
{
3816
  size_t n    = w_rblk.row();
3817
  size_t nrhs = w_rblk.col();
3818
  /** Redistribute weights from RBLK to RIDS. */
3819
  DistData<RIDS, STAR, T> w_rids( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
3820
  w_rids = w_rblk;
3821
  /** Evaluation with RIDS distribution. */
3822
  auto u_rids = Evaluate<NNPRUNE>( tree, w_rids );
3823
  mpi::Barrier( tree.GetComm() );
3824
  /** Redistribute potentials from RIDS to RBLK. */
3825
  DistData<RBLK, STAR, T> u_rblk( n, nrhs, tree.GetComm() );
3826
  u_rblk = u_rids;
3827
  /** Return potentials in RBLK distribution. */
3828
  return u_rblk;
3829
}; /** end Evaluate() */
3830
3831
3832
3833
template<typename SPLITTER, typename T, typename SPDMATRIX>
3834
DistData<STAR, CBLK, pair<T, size_t>> FindNeighbors
3835
(
3836
  SPDMATRIX &K,
3837
  SPLITTER splitter,
3838
	gofmm::Configuration<T> &config,
3839
  mpi::Comm CommGOFMM,
3840
  size_t n_iter = 10
3841
)
3842
{
3843
  /** Instantiation for the randomized metric tree. */
3844
  using DATA  = gofmm::NodeData<T>;
3845
  using SETUP = mpigofmm::Setup<SPDMATRIX, SPLITTER, T>;
3846
  using TREE  = mpitree::Tree<SETUP, DATA>;
3847
  /** Derive type NODE from TREE. */
3848
  using NODE  = typename TREE::NODE;
3849
  /** Get all user-defined parameters. */
3850
  DistanceMetric metric = config.MetricType();
3851
  size_t n = config.ProblemSize();
3852
	size_t k = config.NeighborSize();
3853
  /** Iterative all nearnest-neighbor (ANN). */
3854
  pair<T, size_t> init( numeric_limits<T>::max(), n );
3855
  gofmm::NeighborsTask<NODE, T> NEIGHBORStask;
3856
  TREE rkdt( CommGOFMM );
3857
  rkdt.setup.FromConfiguration( config, K, splitter, NULL );
3858
  return rkdt.AllNearestNeighbor( n_iter, n, k, init, NEIGHBORStask );
3859
}; /** end FindNeighbors() */
3860
3861
3862
3863
3864
3865
3866
3867
3868
3869
/**
3870
 *  @brief template of the compress routine
3871
 */
3872
template<typename SPLITTER, typename RKDTSPLITTER, typename T, typename SPDMATRIX>
3873
mpitree::Tree<mpigofmm::Setup<SPDMATRIX, SPLITTER, T>, gofmm::NodeData<T>>
3874
*Compress
3875
(
3876
  SPDMATRIX &K,
3877
  DistData<STAR, CBLK, pair<T, size_t>> &NN_cblk,
3878
  SPLITTER splitter,
3879
  RKDTSPLITTER rkdtsplitter,
3880
	gofmm::Configuration<T> &config,
3881
  mpi::Comm CommGOFMM
3882
)
3883
{
3884
  try
3885
  {
3886
    /** MPI size ane rank. */
3887
    int size; mpi::Comm_size( CommGOFMM, &size );
3888
    int rank; mpi::Comm_rank( CommGOFMM, &rank );
3889
3890
    /** Get all user-defined parameters. */
3891
    DistanceMetric metric = config.MetricType();
3892
    size_t n = config.ProblemSize();
3893
	  size_t m = config.LeafNodeSize();
3894
	  size_t k = config.NeighborSize();
3895
	  size_t s = config.MaximumRank();
3896
3897
    /** options */
3898
    const bool SYMMETRIC = true;
3899
    const bool NNPRUNE   = true;
3900
    const bool CACHE     = true;
3901
3902
    /** Instantiation for the GOFMM metric tree. */
3903
    using SETUP   = mpigofmm::Setup<SPDMATRIX, SPLITTER, T>;
3904
    using DATA    = gofmm::NodeData<T>;
3905
    using TREE    = mpitree::Tree<SETUP, DATA>;
3906
    /** Derive type NODE and MPINODE from TREE. */
3907
    using NODE    = typename TREE::NODE;
3908
    using MPINODE = typename TREE::MPINODE;
3909
3910
    /** All timers. */
3911
    double beg, omptask45_time, omptask_time, ref_time;
3912
    double time_ratio, compress_time = 0.0, other_time = 0.0;
3913
    double ann_time, tree_time, skel_time, mpi_skel_time, mergefarnodes_time, cachefarnodes_time;
3914
    double local_skel_time, dist_skel_time, let_time;
3915
    double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
3916
    double exchange_neighbor_time, symmetrize_time;
3917
3918
    /** Iterative all nearnest-neighbor (ANN). */
3919
    beg = omp_get_wtime();
3920
    if ( k && NN_cblk.row() * NN_cblk.col() != k * n )
3921
    {
3922
      NN_cblk = mpigofmm::FindNeighbors( K, rkdtsplitter,
3923
          config, CommGOFMM );
3924
    }
3925
    ann_time = omp_get_wtime() - beg;
3926
3927
    /** Initialize metric ball tree using approximate center split. */
3928
    auto *tree_ptr = new TREE( CommGOFMM );
3929
	  auto &tree = *tree_ptr;
3930
3931
	  /** Global configuration for the metric tree. */
3932
    tree.setup.FromConfiguration( config, K, splitter, &NN_cblk );
3933
3934
	  /** Metric ball tree partitioning. */
3935
    beg = omp_get_wtime();
3936
    tree.TreePartition();
3937
    tree_time = omp_get_wtime() - beg;
3938
3939
    /** Get tree permutataion. */
3940
    vector<size_t> perm = tree.GetPermutation();
3941
    if ( rank == 0 )
3942
    {
3943
      ofstream perm_file( "perm.txt" );
3944
      for ( auto &id : perm ) perm_file << id << " ";
3945
      perm_file.close();
3946
    }
3947
3948
3949
    /** Redistribute neighbors i.e. NN[ *, CIDS ] = NN[ *, CBLK ]; */
3950
    DistData<STAR, CIDS, pair<T, size_t>> NN( k, n, tree.treelist[ 0 ]->gids, tree.GetComm() );
3951
    NN = NN_cblk;
3952
    tree.setup.NN = &NN;
3953
    beg = omp_get_wtime();
3954
    ExchangeNeighbors<T>( tree );
3955
    exchange_neighbor_time = omp_get_wtime() - beg;
3956
3957
3958
    beg = omp_get_wtime();
3959
    /** Construct near interaction lists. */
3960
    FindNearInteractions( tree );
3961
    /** Symmetrize interaction pairs by Alltoallv. */
3962
    mpigofmm::SymmetrizeNearInteractions( tree );
3963
    /** Split node interaction lists per MPI rank. */
3964
    BuildInteractionListPerRank( tree, true );
3965
    /** Exchange {leafs} and {paramsleafs)}.  */
3966
    ExchangeLET( tree, string( "leafgids" ) );
3967
    symmetrize_time = omp_get_wtime() - beg;
3968
3969
3970
    /** Find and merge far interactions. */
3971
    mpi::PrintProgress( "[BEG] MergeFarNodes ...", tree.GetComm() );
3972
    beg = omp_get_wtime();
3973
    tree.DependencyCleanUp();
3974
    MergeFarNodesTask<true, NODE, T> seqMERGEtask;
3975
    DistMergeFarNodesTask<true, MPINODE, T> mpiMERGEtask;
3976
    tree.LocaTraverseUp( seqMERGEtask );
3977
    tree.DistTraverseUp( mpiMERGEtask );
3978
    tree.ExecuteAllTasks();
3979
    mergefarnodes_time += omp_get_wtime() - beg;
3980
    mpi::PrintProgress( "[END] MergeFarNodes ...", tree.GetComm() );
3981
3982
    /** Symmetrize interaction pairs by Alltoallv. */
3983
    beg = omp_get_wtime();
3984
    mpigofmm::SymmetrizeFarInteractions( tree );
3985
    /** Split node interaction lists per MPI rank. */
3986
    BuildInteractionListPerRank( tree, false );
3987
    symmetrize_time += omp_get_wtime() - beg;
3988
3989
    mpi::PrintProgress( "[BEG] Skeletonization ...", tree.GetComm() );
3990
    /** Skeletonization */
3991
	  beg = omp_get_wtime();
3992
    tree.DependencyCleanUp();
3993
    /** Gather sample rows and skeleton columns, then ID */
3994
    gofmm::SkeletonKIJTask<NNPRUNE, NODE, T> seqGETMTXtask;
3995
    mpigofmm::DistSkeletonKIJTask<NNPRUNE, MPINODE, T> mpiGETMTXtask;
3996
    mpigofmm::SkeletonizeTask<NODE, T> seqSKELtask;
3997
    mpigofmm::DistSkeletonizeTask<MPINODE, T> mpiSKELtask;
3998
    tree.LocaTraverseUp( seqGETMTXtask, seqSKELtask );
3999
    //tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
4000
    /** Compute the coefficient matrix of ID */
4001
    gofmm::InterpolateTask<NODE> seqPROJtask;
4002
    mpigofmm::InterpolateTask<MPINODE> mpiPROJtask;
4003
    tree.LocaTraverseUnOrdered( seqPROJtask );
4004
    //tree.DistTraverseUnOrdered( mpiPROJtask );
4005
4006
    /** Cache near KIJ interactions */
4007
    mpigofmm::CacheNearNodesTask<NNPRUNE, NODE> seqNEARKIJtask;
4008
    //tree.LocaTraverseLeafs( seqNEARKIJtask );
4009
4010
    tree.ExecuteAllTasks();
4011
    skel_time = omp_get_wtime() - beg;
4012
4013
	  beg = omp_get_wtime();
4014
    tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
4015
    tree.DistTraverseUnOrdered( mpiPROJtask );
4016
    tree.ExecuteAllTasks();
4017
    mpi_skel_time = omp_get_wtime() - beg;
4018
    mpi::PrintProgress( "[END] Skeletonization ...", tree.GetComm() );
4019
4020
4021
4022
    /** Exchange {skels} and {params(skels)}.  */
4023
    ExchangeLET( tree, string( "skelgids" ) );
4024
4025
    beg = omp_get_wtime();
4026
    /** Cache near KIJ interactions */
4027
    //mpigofmm::CacheNearNodesTask<NNPRUNE, NODE> seqNEARKIJtask;
4028
    //tree.LocaTraverseLeafs( seqNEARKIJtask );
4029
    /** Cache far KIJ interactions */
4030
    mpigofmm::CacheFarNodesTask<NNPRUNE,    NODE> seqFARKIJtask;
4031
    mpigofmm::CacheFarNodesTask<NNPRUNE, MPINODE> mpiFARKIJtask;
4032
    //tree.LocaTraverseUnOrdered( seqFARKIJtask );
4033
    //tree.DistTraverseUnOrdered( mpiFARKIJtask );
4034
    cachefarnodes_time = omp_get_wtime() - beg;
4035
    tree.ExecuteAllTasks();
4036
    cachefarnodes_time = omp_get_wtime() - beg;
4037
4038
4039
4040
    /** Compute the ratio of exact evaluation. */
4041
    auto ratio = NonCompressedRatio( tree );
4042
4043
    double exact_ratio = (double) m / n;
4044
4045
    if ( rank == 0 && REPORT_COMPRESS_STATUS )
4046
    {
4047
      compress_time += ann_time;
4048
      compress_time += tree_time;
4049
      compress_time += exchange_neighbor_time;
4050
      compress_time += symmetrize_time;
4051
      compress_time += skel_time;
4052
      compress_time += mpi_skel_time;
4053
      compress_time += mergefarnodes_time;
4054
      compress_time += cachefarnodes_time;
4055
      time_ratio = 100.0 / compress_time;
4056
      printf( "========================================================\n");
4057
      printf( "GOFMM compression phase\n" );
4058
      printf( "========================================================\n");
4059
      printf( "NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
4060
      printf( "TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
4061
      printf( "ExchangeNeighbors --------------------- %5.2lfs (%5.1lf%%)\n", exchange_neighbor_time, exchange_neighbor_time * time_ratio );
4062
      printf( "MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
4063
      printf( "Symmetrize ---------------------------- %5.2lfs (%5.1lf%%)\n", symmetrize_time, symmetrize_time * time_ratio );
4064
      printf( "Skeletonization (HMLP Runtime   ) ----- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
4065
      printf( "Skeletonization (MPI            ) ----- %5.2lfs (%5.1lf%%)\n", mpi_skel_time, mpi_skel_time * time_ratio );
4066
      printf( "Cache KIJ ----------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
4067
      printf( "========================================================\n");
4068
      printf( "%5.3lf%% and %5.3lf%% uncompressed--------- %5.2lfs (%5.1lf%%)\n",
4069
          100 * ratio.first, 100 * ratio.second, compress_time, compress_time * time_ratio );
4070
      printf( "========================================================\n\n");
4071
    }
4072
4073
    /** Cleanup all w/r dependencies on tree nodes */
4074
    tree_ptr->DependencyCleanUp();
4075
    /** Global barrier to make sure all processes have completed */
4076
    mpi::Barrier( tree.GetComm() );
4077
4078
    return tree_ptr;
4079
  }
4080
  catch ( const exception & e )
4081
  {
4082
    cout << e.what() << endl;
4083
    exit( 1 );
4084
  }
4085
}; /** end Compress() */
4086
4087
4088
4089
template<typename TREE, typename T>
4090
pair<T, T> ComputeError( TREE &tree, size_t gid, Data<T> potentials )
4091
{
4092
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
4093
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
4094
4095
  /** ( sum of square errors, square 2-norm of true values ) */
4096
  pair<T, T> ret( 0, 0 );
4097
4098
  auto &K = *tree.setup.K;
4099
  auto &w = *tree.setup.w;
4100
4101
  auto  I = vector<size_t>( 1, gid );
4102
  auto &J = tree.treelist[ 0 ]->gids;
4103
4104
  /** Bcast gid and its parameter to all MPI processes. */
4105
  K.BcastIndices( I, gid % comm_size, tree.GetComm() );
4106
4107
	Data<T> Kab = K( I, J );
4108
4109
  auto loc_exact = potentials;
4110
  auto glb_exact = potentials;
4111
4112
  xgemm( "N", "N", Kab.row(), w.col(), w.row(),
4113
    1.0,       Kab.data(),       Kab.row(),
4114
                 w.data(),         w.row(),
4115
    0.0, loc_exact.data(), loc_exact.row() );
4116
  //gemm::xgemm( (T)1.0, Kab, w, (T)0.0, loc_exact );
4117
4118
4119
4120
4121
  /** Allreduce u( gid, : ) = K( gid, CBLK ) * w( RBLK, : ) */
4122
  mpi::Allreduce( loc_exact.data(), glb_exact.data(),
4123
      loc_exact.size(), MPI_SUM, tree.GetComm() );
4124
4125
  for ( uint64_t j = 0; j < w.col(); j ++ )
4126
  {
4127
    T exac = glb_exact[ j ];
4128
    T pred = potentials[ j ];
4129
    /** Accumulate SSE and sqaure 2-norm. */
4130
    ret.first  += ( pred - exac ) * ( pred - exac );
4131
    ret.second += exac * exac;
4132
  }
4133
4134
  return ret;
4135
}; /** end ComputeError() */
4136
4137
4138
4139
4140
4141
4142
4143
4144
4145
4146
template<typename TREE>
4147
void SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
4148
{
4149
  /** Derive type T from TREE. */
4150
  using T = typename TREE::T;
4151
  /** MPI Support. */
4152
  int rank; mpi::Comm_rank( tree.GetComm(), &rank );
4153
  int size; mpi::Comm_size( tree.GetComm(), &size );
4154
  /** Size of right hand sides. */
4155
  size_t n = tree.n;
4156
  /** Shrink ntest if ntest > n. */
4157
  if ( ntest > n ) ntest = n;
4158
  /** all_rhs = [ 0, 1, ..., nrhs - 1 ]. */
4159
  vector<size_t> all_rhs( nrhs );
4160
  for ( size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;
4161
4162
  //auto A = tree.CheckAllInteractions();
4163
4164
  /** Input and output in RIDS and RBLK. */
4165
  DistData<RIDS, STAR, T> w_rids( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
4166
  DistData<RBLK, STAR, T> u_rblk( n, nrhs, tree.GetComm() );
4167
  /** Initialize with random N( 0, 1 ). */
4168
  w_rids.randn();
4169
  /** Evaluate u ~ K * w. */
4170
  auto u_rids = mpigofmm::Evaluate<true>( tree, w_rids );
4171
  /** Sanity check for INF and NAN. */
4172
  assert( !u_rids.HasIllegalValue() );
4173
  /** Redistribute potentials from RIDS to RBLK. */
4174
  u_rblk = u_rids;
4175
  /** Report elementwise and F-norm accuracy. */
4176
  if ( rank == 0 )
4177
  {
4178
    printf( "========================================================\n");
4179
    printf( "Accuracy report\n" );
4180
    printf( "========================================================\n");
4181
  }
4182
  /** All statistics. */
4183
  T nnerr_avg = 0.0, nonnerr_avg = 0.0, fmmerr_avg = 0.0;
4184
  T sse_2norm = 0.0, ssv_2norm = 0.0;
4185
  /** Loop over all testing gids and right hand sides. */
4186
  for ( size_t i = 0; i < ntest; i ++ )
4187
  {
4188
    size_t tar = i * n / ntest;
4189
    Data<T> potentials( (size_t)1, nrhs );
4190
    if ( rank == ( tar % size ) ) potentials = u_rblk( vector<size_t>( 1, tar ), all_rhs );
4191
    /** Bcast potentials to all MPI processes. */
4192
    mpi::Bcast( potentials.data(), nrhs, tar % size, tree.GetComm() );
4193
    /** Compare potentials with exact MATVEC. */
4194
    auto sse_ssv = mpigofmm::ComputeError( tree, tar, potentials );
4195
    /** Compute element-wise 2-norm error. */
4196
    auto fmmerr  = sqrt( sse_ssv.first / sse_ssv.second );
4197
    /** Accumulate element-wise 2-norm error. */
4198
    fmmerr_avg += fmmerr;
4199
    /** Accumulate SSE and SSV. */
4200
    sse_2norm += sse_ssv.first;
4201
    ssv_2norm += sse_ssv.second;
4202
    /** Only print 10 values. */
4203
    if ( i < 10 && rank == 0 )
4204
    {
4205
      printf( "gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
4206
          tar, 0.0, 0.0, fmmerr );
4207
    }
4208
  }
4209
  if ( rank == 0 )
4210
  {
4211
    printf( "========================================================\n");
4212
    printf( "Elementwise ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
4213
        nnerr_avg / ntest , nonnerr_avg / ntest, fmmerr_avg / ntest );
4214
    printf( "F-norm      ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
4215
        0.0, 0.0, sqrt( sse_2norm / ssv_2norm ) );
4216
    printf( "========================================================\n");
4217
  }
4218
4219
  /** Factorization */
4220
  T lambda = 10.0;
4221
  mpigofmm::DistFactorize( tree, lambda );
4222
  mpigofmm::ComputeError( tree, lambda, w_rids, u_rids );
4223
}; /** end SelfTesting() */
4224
4225
4226
/** @brief Instantiate the splitters here. */
4227
template<typename SPDMATRIX>
4228
void LaunchHelper( SPDMATRIX &K, gofmm::CommandLineHelper &cmd, mpi::Comm CommGOFMM )
4229
{
4230
  using T = typename SPDMATRIX::T;
4231
  const int N_CHILDREN = 2;
4232
  /** Use geometric-oblivious splitters. */
4233
  using SPLITTER     = mpigofmm::centersplit<SPDMATRIX, N_CHILDREN, T>;
4234
  using RKDTSPLITTER = mpigofmm::randomsplit<SPDMATRIX, N_CHILDREN, T>;
4235
  /** GOFMM tree splitter. */
4236
  SPLITTER splitter( K );
4237
  splitter.Kptr = &K;
4238
  splitter.metric = cmd.metric;
4239
  /** Randomized tree splitter. */
4240
  RKDTSPLITTER rkdtsplitter( K );
4241
  rkdtsplitter.Kptr = &K;
4242
  rkdtsplitter.metric = cmd.metric;
4243
	/** Create configuration for all user-define arguments. */
4244
  gofmm::Configuration<T> config( cmd.metric,
4245
      cmd.n, cmd.m, cmd.k, cmd.s, cmd.stol, cmd.budget );
4246
  /** (Optional) provide neighbors, leave uninitialized otherwise. */
4247
  DistData<STAR, CBLK, pair<T, size_t>> NN( 0, cmd.n, CommGOFMM );
4248
  /** Compress matrix K. */
4249
  auto *tree_ptr = mpigofmm::Compress( K, NN, splitter, rkdtsplitter, config, CommGOFMM );
4250
  auto &tree = *tree_ptr;
4251
4252
  /** Examine accuracies. */
4253
  mpigofmm::SelfTesting( tree, 100, cmd.nrhs );
4254
4255
	/** Delete tree_ptr. */
4256
  delete tree_ptr;
4257
}; /** end test_gofmm_setup() */
4258
4259
4260
}; /** end namespace gofmm */
4261
}; /** end namespace hmlp */
4262
4263
#endif /** define GOFMM_MPI_HPP */