GCC Code Coverage Report
Directory: . Exec Total Coverage
File: gofmm/gofmm.hpp Lines: 0 880 0.0 %
Date: 2019-01-14 Branches: 0 5665 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
#ifndef GOFMM_HPP
22
#define GOFMM_HPP
23
24
/** Use STL future, thread, chrono */
25
#include <future>
26
#include <thread>
27
#include <chrono>
28
#include <set>
29
#include <vector>
30
#include <map>
31
#include <unordered_set>
32
#include <deque>
33
#include <assert.h>
34
#include <typeinfo>
35
#include <algorithm>
36
#include <functional>
37
#include <array>
38
#include <random>
39
#include <numeric>
40
#include <sstream>
41
#include <iostream>
42
#include <string>
43
#include <stdio.h>
44
#include <omp.h>
45
#include <time.h>
46
47
/** Use HMLP related support. */
48
#include <hmlp.h>
49
#include <hmlp_base.hpp>
50
/** Use HMLP primitives. */
51
#include <primitives/lowrank.hpp>
52
#include <primitives/combinatorics.hpp>
53
#include <primitives/gemm.hpp>
54
/** Use HMLP containers. */
55
#include <containers/VirtualMatrix.hpp>
56
#include <containers/SPDMatrix.hpp>
57
/** GOFMM templates. */
58
#include <tree.hpp>
59
#include <igofmm.hpp>
60
/** gpu related */
61
#ifdef HMLP_USE_CUDA
62
#include <cuda_runtime.h>
63
#include <gofmm_gpu.hpp>
64
#endif
65
/** Use STL and HMLP namespaces. */
66
using namespace std;
67
using namespace hmlp;
68
69
70
71
/** this parameter is used to reserve space for std::vector */
72
#define MAX_NRHS 1024
73
/** the block size we use for partitioning GEMM tasks */
74
#define GEMM_NB 256
75
76
77
//#define DEBUG_SPDASKIT 1
78
#define REPORT_ANN_ACCURACY 1
79
#define REPORT_COMPRESS_STATUS 1
80
#define REPORT_EVALUATE_STATUS 1
81
82
83
namespace hmlp
84
{
85
namespace gofmm
86
{
87
88
/** @brief This is a helper class that parses the arguments from command lines. */
89
class CommandLineHelper
90
{
91
  public:
92
93
    /** (Default) constructor. */
94
    CommandLineHelper( int argc, char *argv[] )
95
    {
96
      /** Number of columns and rows, i.e. problem size. */
97
      sscanf( argv[ 1 ], "%lu", &n );
98
      /** On-diagonal block size, such that the tree has log(n/m) levels. */
99
      sscanf( argv[ 2 ], "%lu", &m );
100
      /** Number of neighbors to use. */
101
      sscanf( argv[ 3 ], "%lu", &k );
102
      /** Maximum off-diagonal ranks. */
103
      sscanf( argv[ 4 ], "%lu", &s );
104
      /** Number of right-hand sides. */
105
      sscanf( argv[ 5 ], "%lu", &nrhs );
106
      /** Desired approximation accuracy. */
107
      sscanf( argv[ 6 ], "%lf", &stol );
108
      /** The maximum percentage of direct matrix-multiplication. */
109
      sscanf( argv[ 7 ], "%lf", &budget );
110
      /** Specify distance type. */
111
      distance_type = argv[ 8 ];
112
      if ( !distance_type.compare( "geometry" ) )
113
      {
114
        metric = GEOMETRY_DISTANCE;
115
      }
116
      else if ( !distance_type.compare( "kernel" ) )
117
      {
118
        metric = KERNEL_DISTANCE;
119
      }
120
      else if ( !distance_type.compare( "angle" ) )
121
      {
122
        metric = ANGLE_DISTANCE;
123
      }
124
      else
125
      {
126
        printf( "%s is not supported\n", argv[ 8 ] );
127
        exit( 1 );
128
      }
129
      /** Specify what kind of spdmatrix is used. */
130
      spdmatrix_type = argv[ 9 ];
131
      if ( !spdmatrix_type.compare( "testsuit" ) )
132
      {
133
        /** NOP */
134
      }
135
      else if ( !spdmatrix_type.compare( "userdefine" ) )
136
      {
137
        /** NOP */
138
      }
139
      else if ( !spdmatrix_type.compare( "pvfmm" ) )
140
      {
141
        /** NOP */
142
      }
143
      else if ( !spdmatrix_type.compare( "dense" ) || !spdmatrix_type.compare( "ooc" ) )
144
      {
145
        /** (Optional) provide the path to the matrix file. */
146
        user_matrix_filename = argv[ 10 ];
147
        if ( argc > 11 )
148
        {
149
          /** (Optional) provide the path to the data file. */
150
          user_points_filename = argv[ 11 ];
151
          /** Dimension of the data set. */
152
          sscanf( argv[ 12 ], "%lu", &d );
153
        }
154
      }
155
      else if ( !spdmatrix_type.compare( "mlp" ) )
156
      {
157
        hidden_layers = argv[ 10 ];
158
        user_points_filename = argv[ 11 ];
159
        /** Number of attributes (dimensions). */
160
        sscanf( argv[ 12 ], "%lu", &d );
161
      }
162
      else if ( !spdmatrix_type.compare( "cov" ) )
163
      {
164
        kernelmatrix_type = argv[ 10 ];
165
        user_points_filename = argv[ 11 ];
166
        /** Number of attributes (dimensions) */
167
        sscanf( argv[ 12 ], "%lu", &d );
168
        /** Block size (in dimensions) per file */
169
        sscanf( argv[ 13 ], "%lu", &nb );
170
      }
171
      else if ( !spdmatrix_type.compare( "kernel" ) )
172
      {
173
        kernelmatrix_type = argv[ 10 ];
174
        user_points_filename = argv[ 11 ];
175
        /** Number of attributes (dimensions) */
176
        sscanf( argv[ 12 ], "%lu", &d );
177
        /** (Optional) provide Gaussian kernel bandwidth */
178
        if ( argc > 13 ) sscanf( argv[ 13 ], "%lf", &h );
179
      }
180
      else
181
      {
182
        printf( "%s is not supported\n", argv[ 9 ] );
183
        exit( 1 );
184
      }
185
    }; /** end CommentLineSupport() */
186
187
    /** Basic GOFMM parameters. */
188
    size_t n, m, k, s, nrhs;
189
    /** (Default) user-defined approximation toleratnce and budget. */
190
    double stol = 1E-3;
191
    double budget = 0.0;
192
    /** (Default) geometric-oblivious scheme. */
193
    DistanceMetric metric = ANGLE_DISTANCE;
194
195
    /** (Optional) */
196
    size_t d, nb;
197
    /** (Optional) set the default Gaussian kernel bandwidth. */
198
    double h = 1.0;
199
    string distance_type;
200
    string spdmatrix_type;
201
    string kernelmatrix_type;
202
    string hidden_layers;
203
    string user_matrix_filename;
204
    string user_points_filename;
205
}; /** end class CommandLineHelper */
206
207
208
209
210
/** @brief Configuration contains all user-defined parameters. */
211
template<typename T>
212
class Configuration
213
{
214
	public:
215
216
    Configuration() {};
217
218
		Configuration( DistanceMetric metric_type,
219
		  size_t problem_size, size_t leaf_node_size,
220
      size_t neighbor_size, size_t maximum_rank,
221
			T tolerance, T budget )
222
		{
223
      Set( metric_type, problem_size, leaf_node_size,
224
          neighbor_size, maximum_rank, tolerance, budget );
225
		};
226
227
		void Set( DistanceMetric metric_type,
228
		  size_t problem_size, size_t leaf_node_size,
229
      size_t neighbor_size, size_t maximum_rank,
230
			T tolerance, T budget )
231
		{
232
			this->metric_type = metric_type;
233
			this->problem_size = problem_size;
234
			this->leaf_node_size = leaf_node_size;
235
			this->neighbor_size = neighbor_size;
236
			this->maximum_rank = maximum_rank;
237
			this->tolerance = tolerance;
238
			this->budget = budget;
239
		};
240
241
    void CopyFrom( Configuration<T> &config ) { *this = config; };
242
243
		DistanceMetric MetricType() { return metric_type; };
244
245
		size_t ProblemSize() { return problem_size; };
246
247
		size_t LeafNodeSize() { return leaf_node_size; };
248
249
		size_t NeighborSize() { return neighbor_size; };
250
251
		size_t MaximumRank() { return maximum_rank; };
252
253
		T Tolerance() { return tolerance; };
254
255
		T Budget() { return budget; };
256
257
    bool IsSymmetric() { return is_symmetric; };
258
259
    bool UseAdaptiveRanks() { return use_adaptive_ranks; };
260
261
    bool SecureAccuracy() { return secure_accuracy; };
262
263
	private:
264
265
		/** (Default) metric type. */
266
		DistanceMetric metric_type = ANGLE_DISTANCE;
267
268
		/** (Default) problem size. */
269
		size_t problem_size = 0;
270
271
		/** (Default) maximum leaf node size. */
272
		size_t leaf_node_size = 64;
273
274
		/** (Default) number of neighbors. */
275
		size_t neighbor_size = 32;
276
277
		/** (Default) maximum off-diagonal ranks. */
278
		size_t maximum_rank = 64;
279
280
		/** (Default) user error tolerance. */
281
		T tolerance = 1E-3;
282
283
		/** (Default) user computation budget. */
284
		T budget = 0.03;
285
286
    /** (Default, Advanced) whether the matrix is symmetric. */
287
    bool is_symmetric = true;
288
289
		/** (Default, Advanced) whether or not using adaptive ranks. */
290
		bool use_adaptive_ranks = true;
291
292
    /** (Default, Advanced) whether or not securing the accuracy. */
293
    bool secure_accuracy = false;
294
295
}; /** end class Configuration */
296
297
298
299
/** @brief These are data that shared by the whole local tree. */
300
template<typename SPDMATRIX, typename SPLITTER, typename T>
301
class Setup : public tree::Setup<SPLITTER, T>,
302
              public Configuration<T>
303
{
304
  public:
305
306
    Setup() {};
307
308
    /** Shallow copy from the config. */
309
    void FromConfiguration( Configuration<T> &config,
310
        SPDMATRIX &K, SPLITTER &splitter, Data<pair<T, size_t>> *NN )
311
    {
312
      this->CopyFrom( config );
313
      this->K = &K;
314
      this->splitter = splitter;
315
      this->NN = NN;
316
    };
317
318
    /** The SPDMATRIX (accessed with gids: dense, CSC or OOC). */
319
    SPDMATRIX *K = NULL;
320
321
    /** rhs-by-n, weights and potentials. */
322
    Data<T> *w = NULL;
323
    Data<T> *u = NULL;
324
325
    /** Buffer space, either dimension needs to be n. */
326
    Data<T> *input = NULL;
327
    Data<T> *output = NULL;
328
329
    /** Regularization for factorization. */
330
    T lambda = 0.0;
331
332
    /** Use ULV or Sherman-Morrison-Woodbury */
333
    bool do_ulv_factorization = true;
334
335
  private:
336
337
338
339
340
}; /** end class Setup */
341
342
343
/** @brief This class contains all GOFMM related data carried by a tree node. */
344
template<typename T>
345
class NodeData : public Factor<T>
346
{
347
  public:
348
349
    /** (Default) constructor. */
350
    NodeData() {};
351
352
    /** The OpenMP (or pthread) lock that grants exclusive right. */
353
    Lock lock;
354
355
    /** Whether the node can be compressed (with skel and proj). */
356
    bool isskel = false;
357
358
    /** Skeleton gids (subset of gids). */
359
    vector<size_t> skels;
360
361
    /** 2s, pivoting order of GEQP3 (or GEQP4). */
362
    vector<int> jpvt;
363
364
    /** s-by-2s, interpolative coefficients. */
365
    Data<T> proj;
366
367
    /** Sampling neighbors gids. */
368
    map<size_t, T> snids;
369
370
    /** (Buffer) nsamples row gids, and sl + sr skeleton columns of children. */
371
    vector<size_t> candidate_rows;
372
    vector<size_t> candidate_cols;
373
374
    /** (Buffer) nsamples-by-(sl+sr) submatrix of K. */
375
    Data<T> KIJ;
376
377
    /** (Buffer) skeleton weights and potentials. */
378
    Data<T> w_skel;
379
    Data<T> u_skel;
380
381
    /** (Buffer) permuted weights and potentials. */
382
    Data<T> w_leaf;
383
    Data<T> u_leaf[ 20 ];
384
385
    /** Hierarchical tree view of w<RIDS, STAR> and u<RIDS, STAR>. */
386
    View<T> w_view;
387
    View<T> u_view;
388
389
    /** Cached Kab */
390
    Data<size_t> Nearbmap;
391
    Data<T> NearKab;
392
    Data<T> FarKab;
393
394
395
    /** recorded events (for HMLP Runtime) */
396
    Event skeletonize;
397
    Event updateweight;
398
    Event skeltoskel;
399
    Event skeltonode;
400
    Event s2s;
401
    Event s2n;
402
403
    /** knn accuracy */
404
    double knn_acc = 0.0;
405
    size_t num_acc = 0;
406
407
}; /** end class Data */
408
409
410
411
412
/** @brief This task creates an hierarchical tree view for w<RIDS> and u<RIDS>. */
413
template<typename NODE>
414
class TreeViewTask : public Task
415
{
416
  public:
417
418
    NODE *arg = NULL;
419
420
    void Set( NODE *user_arg )
421
    {
422
      arg = user_arg;
423
      name = string( "TreeView" );
424
      label = to_string( arg->treelist_id );
425
      cost = 1.0;
426
    };
427
428
    /** Preorder dependencies (with a single source node). */
429
    void DependencyAnalysis() { arg->DependOnParent( this ); };
430
431
    void Execute( Worker* user_worker )
432
    {
433
      //printf( "TreeView %lu\n", node->treelist_id );
434
      auto *node   = arg;
435
      auto &data   = node->data;
436
      auto *setup  = node->setup;
437
438
      /** w and u can be Data<T> or DistData<RIDS,STAR,T> */
439
      auto &w = *(setup->u);
440
      auto &u = *(setup->w);
441
442
      /** get the matrix view of this tree node */
443
      auto &U  = data.u_view;
444
      auto &W  = data.w_view;
445
446
      /** create contigious view for u and w at the root level */
447
      if ( !node->parent )
448
      {
449
        /** both w and u are column-majored, thus nontranspose */
450
        U.Set( u );
451
        W.Set( w );
452
      }
453
454
      /** partition u and w using the hierarchical tree view */
455
      if ( !node->isleaf )
456
      {
457
        auto &UL = node->lchild->data.u_view;
458
        auto &UR = node->rchild->data.u_view;
459
        auto &WL = node->lchild->data.w_view;
460
        auto &WR = node->rchild->data.w_view;
461
        /**
462
         *  U = [ UL;    W = [ WL;
463
         *        UR; ]        WR; ]
464
         */
465
        U.Partition2x1( UL,
466
                        UR, node->lchild->n, TOP );
467
        W.Partition2x1( WL,
468
                        WR, node->lchild->n, TOP );
469
      }
470
    };
471
}; /** end class TreeViewTask */
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
/** @brief Provide statistics summary for the execution section.  */
492
template<typename NODE>
493
class Summary
494
{
495
496
  public:
497
498
    Summary() {};
499
500
    deque<Statistic> rank;
501
502
    deque<Statistic> skeletonize;
503
504
    /** n2s */
505
    deque<Statistic> updateweight;
506
507
    /** s2s */
508
    deque<Statistic> s2s_kij_t;
509
    deque<Statistic> s2s_t;
510
    deque<Statistic> s2s_gfp;
511
512
    /** s2n */
513
    deque<Statistic> s2n_kij_t;
514
    deque<Statistic> s2n_t;
515
    deque<Statistic> s2n_gfp;
516
517
518
    void operator() ( NODE *node )
519
    {
520
      if ( rank.size() <= node->l )
521
      {
522
        rank.push_back( hmlp::Statistic() );
523
        skeletonize.push_back( hmlp::Statistic() );
524
        updateweight.push_back( hmlp::Statistic() );
525
      }
526
527
      rank[ node->l ].Update( (double)node->data.skels.size() );
528
      skeletonize[ node->l ].Update( node->data.skeletonize.GetDuration() );
529
      updateweight[ node->l ].Update( node->data.updateweight.GetDuration() );
530
531
#ifdef DUMP_ANALYSIS_DATA
532
      if ( node->parent )
533
      {
534
        auto *parent = node->parent;
535
        printf( "@TREE\n" );
536
        printf( "#%lu (s%lu), #%lu (s%lu), %lu, %lu\n",
537
            node->treelist_id, node->data.skels.size(),
538
            parent->treelist_id, parent->data.skels.size(),
539
            node->data.skels.size(), node->l );
540
      }
541
      else
542
      {
543
        printf( "@TREE\n" );
544
        printf( "#%lu (s%lu), , %lu, %lu\n",
545
            node->treelist_id, node->data.skels.size(),
546
            node->data.skels.size(), node->l );
547
      }
548
#endif
549
    };
550
551
    void Print()
552
    {
553
      for ( size_t l = 1; l < rank.size(); l ++ )
554
      {
555
        printf( "@SUMMARY\n" );
556
        printf( "level %2lu, ", l ); rank[ l ].Print();
557
        //printf( "skel_t:     " ); skeletonize[ l ].Print();
558
        //printf( "... ... ...\n" );
559
        //printf( "n2s_t:      " ); updateweight[ l ].Print();
560
        ////printf( "s2s_kij_t:  " ); s2s_kij_t[ l ].Print();
561
        //printf( "s2s_t:      " ); s2s_t[ l ].Print();
562
        //printf( "s2s_gfp:    " ); s2s_gfp[ l ].Print();
563
        ////printf( "s2n_kij_t:  " ); s2n_kij_t[ l ].Print();
564
        //printf( "s2n_t:      " ); s2n_t[ l ].Print();
565
        //printf( "s2n_gfp:    " ); s2n_gfp[ l ].Print();
566
      }
567
    };
568
569
}; /** end class Summary */
570
571
572
573
/**
574
 *  @brief This the main splitter used to build the Spd-Askit tree.
575
 *         First compute the approximate center using subsamples.
576
 *         Then find the two most far away points to do the
577
 *         projection.
578
 */
579
template<typename SPDMATRIX, int N_SPLIT, typename T>
580
struct centersplit
581
{
582
  /** Closure */
583
  SPDMATRIX *Kptr = NULL;
584
  /** (Default) use angle distance from the Gram vector space. */
585
  DistanceMetric metric = ANGLE_DISTANCE;
586
  /** Number samples to approximate centroid. */
587
  size_t n_centroid_samples = 5;
588
589
  centersplit() {};
590
591
  centersplit( SPDMATRIX& K ) { this->Kptr = &K; };
592
593
	/** Overload the operator (). */
594
  vector<vector<size_t>> operator() ( vector<size_t>& gids ) const
595
  {
596
    /** all assertions */
597
    assert( N_SPLIT == 2 );
598
    assert( Kptr );
599
600
601
    SPDMATRIX &K = *Kptr;
602
    vector<vector<size_t>> split( N_SPLIT );
603
    size_t n = gids.size();
604
    vector<T> temp( n, 0.0 );
605
606
    /** Collecting column samples of K. */
607
    auto column_samples = combinatorics::SampleWithoutReplacement(
608
        n_centroid_samples, gids );
609
610
611
    /** Compute all pairwise distances. */
612
    auto DIC = K.Distances( this->metric, gids, column_samples );
613
614
    /** Zero out the temporary buffer. */
615
    for ( auto & it : temp ) it = 0;
616
617
    /** Accumulate distances to the temporary buffer. */
618
    for ( size_t j = 0; j < DIC.col(); j ++ )
619
      for ( size_t i = 0; i < DIC.row(); i ++ )
620
        temp[ i ] += DIC( i, j );
621
622
    /** Find the f2c (far most to center) from points owned */
623
    auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );
624
625
    /** Collecting KIP */
626
    vector<size_t> P( 1, gids[ idf2c ] );
627
628
    /** Compute all pairwise distances. */
629
    auto DIP = K.Distances( this->metric, gids, P );
630
631
    /** Find f2f (far most to far most) from owned points */
632
    auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );
633
634
    /** collecting KIQ */
635
    vector<size_t> Q( 1, gids[ idf2f ] );
636
637
    /** Compute all pairwise distances. */
638
    auto DIQ = K.Distances( this->metric, gids, P );
639
640
    for ( size_t i = 0; i < temp.size(); i ++ )
641
      temp[ i ] = DIP[ i ] - DIQ[ i ];
642
643
    return combinatorics::MedianSplit( temp );
644
  };
645
}; /** end struct centersplit */
646
647
648
649
650
651
652
653
654
655
656
/** @brief This the splitter used in the randomized tree.  */
657
template<typename SPDMATRIX, int N_SPLIT, typename T>
658
struct randomsplit
659
{
660
  /** closure */
661
  SPDMATRIX *Kptr = NULL;
662
663
	/** (default) using angle distance from the Gram vector space */
664
  DistanceMetric metric = ANGLE_DISTANCE;
665
666
  randomsplit() {};
667
668
  randomsplit( SPDMATRIX& K ) { this->Kptr = &K; };
669
670
	/** overload with the operator */
671
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
672
  {
673
    assert( Kptr && ( N_SPLIT == 2 ) );
674
675
    SPDMATRIX &K = *Kptr;
676
    size_t n = gids.size();
677
    vector<vector<size_t> > split( N_SPLIT );
678
    vector<T> temp( n, 0.0 );
679
680
    /** Randomly select two points p and q. */
681
    size_t idf2c = std::rand() % n;
682
    size_t idf2f = std::rand() % n;
683
    while ( idf2c == idf2f ) idf2f = std::rand() % n;
684
685
686
    vector<size_t> P( 1, gids[ idf2c ] );
687
    vector<size_t> Q( 1, gids[ idf2f ] );
688
689
690
    /** Compute all pairwise distances. */
691
    auto DIP = K.Distances( this->metric, gids, P );
692
    auto DIQ = K.Distances( this->metric, gids, Q );
693
694
    for ( size_t i = 0; i < temp.size(); i ++ )
695
      temp[ i ] = DIP[ i ] - DIQ[ i ];
696
697
    return combinatorics::MedianSplit( temp );
698
699
  };
700
}; /** end struct randomsplit */
701
702
703
704
705
template<typename NODE>
706
void FindNeighbors( NODE *node, DistanceMetric metric )
707
{
708
  /** Derive type T from NODE. */
709
  using T = typename NODE::T;
710
  auto & setup = *(node->setup);
711
  auto & K = *(setup.K);
712
  auto & NN = *(setup.NN);
713
  auto & I = node->gids;
714
  /** Number of neighbors to search for. */
715
  size_t kappa = NN.row();
716
  /** Initial value for the neighbor select. */
717
  pair<T, size_t> init( numeric_limits<T>::max(), NN.col() );
718
  /** k-nearest neighbor search kernel. */
719
  auto candidates = K.NeighborSearch( metric, kappa, I, I, init );
720
  /** Merge and update neighbors. */
721
	#pragma omp parallel
722
  {
723
    vector<pair<T, size_t> > aux( 2 * kappa );
724
    #pragma omp for
725
    for ( size_t j = 0; j < I.size(); j ++ )
726
    {
727
      MergeNeighbors( kappa, NN.columndata( I[ j ] ),
728
          candidates.columndata( j ), aux );
729
    }
730
  }
731
}; /** end FindNeighbors() */
732
733
734
735
736
737
template<class NODE, typename T>
738
class NeighborsTask : public Task
739
{
740
  public:
741
742
    NODE *arg = NULL;
743
744
	  /** (Default) using angle distance from the Gram vector space. */
745
	  DistanceMetric metric = ANGLE_DISTANCE;
746
747
    void Set( NODE *user_arg )
748
    {
749
      arg = user_arg;
750
      name = string( "Neighbors" );
751
      label = to_string( arg->treelist_id );
752
      /** Use the same distance as the tree. */
753
      metric = arg->setup->MetricType();
754
755
      //--------------------------------------
756
      double flops, mops;
757
      auto &gids = arg->gids;
758
      auto &NN = *arg->setup->NN;
759
      flops = gids.size();
760
      flops *= ( 4.0 * gids.size() );
761
      // Heap select worst case
762
      mops = (size_t)std::log( NN.row() ) * gids.size();
763
      mops *= gids.size();
764
      // Access K
765
      mops += flops;
766
      event.Set( name + label, flops, mops );
767
      //--------------------------------------
768
769
      // TODO: Need an accurate cost model.
770
      cost = mops / 1E+9;
771
    };
772
773
    void DependencyAnalysis() { arg->DependOnNoOne( this ); };
774
775
    void Execute( Worker* user_worker ) { FindNeighbors( arg, metric ); };
776
777
}; /** end class NeighborsTask */
778
779
780
781
/** @brief This is the ANN routine design for CSC matrices. */
782
template<bool DOAPPROXIMATE, bool SORTED, typename T, typename CSCMATRIX>
783
Data<pair<T, size_t>> SparsePattern( size_t n, size_t k, CSCMATRIX &K )
784
{
785
  pair<T, size_t> initNN( numeric_limits<T>::max(), n );
786
  Data<pair<T, size_t>> NN( k, n, initNN );
787
788
  printf( "SparsePattern k %lu n %lu, NN.row %lu NN.col %lu ...",
789
      k, n, NN.row(), NN.col() ); fflush( stdout );
790
791
  #pragma omp parallel for schedule( dynamic )
792
  for ( size_t j = 0; j < n; j ++ )
793
  {
794
    std::set<size_t> NNset;
795
    size_t nnz = K.ColPtr( j + 1 ) - K.ColPtr( j );
796
    if ( DOAPPROXIMATE && nnz > 2 * k ) nnz = 2 * k;
797
798
    //printf( "j %lu nnz %lu\n", j, nnz );
799
800
    for ( size_t i = 0; i < nnz; i ++ )
801
    {
802
      // TODO: this is lid. Need to be gid.
803
      auto row_ind = K.RowInd( K.ColPtr( j ) + i );
804
      auto val     = K.Value( K.ColPtr( j ) + i );
805
806
      if ( val ) val = 1.0 / std::abs( val );
807
      else       val = std::numeric_limits<T>::max() - 1.0;
808
809
      NNset.insert( row_ind );
810
      std::pair<T, std::size_t> query( val, row_ind );
811
      if ( nnz < k ) // not enough candidates
812
      {
813
        NN[ j * k + i  ] = query;
814
      }
815
      else
816
      {
817
        hmlp::HeapSelect( 1, NN.row(), &query, NN.data() + j * NN.row() );
818
      }
819
    }
820
821
    while ( nnz < k )
822
    {
823
      std::size_t row_ind = rand() % n;
824
      if ( !NNset.count( row_ind ) )
825
      {
826
        T val = std::numeric_limits<T>::max() - 1.0;
827
        std::pair<T, std::size_t> query( val, row_ind );
828
        NNset.insert( row_ind );
829
        NN[ j * k + nnz ] = query;
830
        nnz ++;
831
      }
832
    }
833
  }
834
  printf( "Done.\n" ); fflush( stdout );
835
836
  if ( SORTED )
837
  {
838
    printf( "Sorting ... " ); fflush( stdout );
839
    struct
840
    {
841
      bool operator () ( std::pair<T, size_t> a, std::pair<T, size_t> b )
842
      {
843
        return a.first < b.first;
844
      }
845
    } ANNLess;
846
847
    //printf( "SparsePattern k %lu n %lu, NN.row %lu NN.col %lu\n", k, n, NN.row(), NN.col() );
848
849
    #pragma omp parallel for
850
    for ( size_t j = 0; j < NN.col(); j ++ )
851
    {
852
      std::sort( NN.data() + j * NN.row(), NN.data() + ( j + 1 ) * NN.row(), ANNLess );
853
    }
854
    printf( "Done.\n" ); fflush( stdout );
855
  }
856
857
  return NN;
858
}; /** end SparsePattern() */
859
860
861
/* @brief Helper functions for sorting sampling neighbors. */
862
template<typename TA, typename TB>
863
pair<TB, TA> flip_pair( const pair<TA, TB> &p )
864
{
865
  return pair<TB, TA>( p.second, p.first );
866
}; /** end flip_pair() */
867
868
869
template<typename TA, typename TB>
870
multimap<TB, TA> flip_map( const map<TA, TB> &src )
871
{
872
  multimap<TB, TA> dst;
873
  transform( src.begin(), src.end(), inserter( dst, dst.begin() ),
874
                 flip_pair<TA, TB> );
875
  return dst;
876
}; /** end flip_map() */
877
878
879
880
881
/** @brief Compute the cofficient matrix by R11^{-1} * proj. */
882
template<typename NODE>
883
void Interpolate( NODE *node )
884
{
885
  /** Derive type T from NODE. */
886
  using T = typename NODE::T;
887
  /** Early return if possible. */
888
  if ( !node ) return;
889
890
  auto &K = *node->setup->K;
891
  auto &data = node->data;
892
  auto &skels = data.skels;
893
  auto &proj = data.proj;
894
  auto &jpvt = data.jpvt;
895
  auto s = proj.row();
896
  auto n = proj.col();
897
898
  /** Early return if the node is incompressible or all zeros. */
899
  if ( !data.isskel || proj[ 0 ] == 0 ) return;
900
901
  assert( s );
902
  assert( s <= n );
903
  assert( jpvt.size() == n );
904
905
  /** If is skeletonized, reserve space for w_skel and u_skel */
906
  if ( data.isskel )
907
  {
908
    data.w_skel.reserve( skels.size(), MAX_NRHS );
909
    data.u_skel.reserve( skels.size(), MAX_NRHS );
910
  }
911
912
913
  /** Fill in R11. */
914
  Data<T> R1( s, s, 0.0 );
915
916
  for ( int j = 0; j < s; j ++ )
917
  {
918
    for ( int i = 0; i < s; i ++ )
919
    {
920
      if ( i <= j ) R1[ j * s + i ] = proj[ j * s + i ];
921
    }
922
  }
923
924
  /** Copy proj to tmp. */
925
  Data<T> tmp = proj;
926
927
  /** proj = inv( R1 ) * proj */
928
  xtrsm( "L", "U", "N", "N", s, n, 1.0, R1.data(), s, tmp.data(), s );
929
930
  /** Fill in proj. */
931
  for ( int j = 0; j < n; j ++ )
932
  {
933
    for ( int i = 0; i < s; i ++ )
934
    {
935
  	  proj[ jpvt[ j ] * s + i ] = tmp[ j * s + i ];
936
    }
937
  }
938
939
940
}; /** end Interpolate() */
941
942
943
/** @brief The correponding task of Interpolate(). */
944
template<typename NODE>
945
class InterpolateTask : public Task
946
{
947
  public:
948
949
    NODE *arg = NULL;
950
951
    void Set( NODE *user_arg )
952
    {
953
      arg = user_arg;
954
      name = string( "it" );
955
      label = to_string( arg->treelist_id );
956
      // Need an accurate cost model.
957
      cost = 1.0;
958
    };
959
960
    void DependencyAnalysis() { arg->DependOnNoOne( this ); };
961
962
    void Execute( Worker* user_worker ) { Interpolate( arg ); };
963
964
}; /** end class InterpolateTask */
965
966
967
968
969
/**
970
 *  TODO: I decided not to use the sampling pool
971
 */
972
template<bool NNPRUNE, typename NODE>
973
void RowSamples( NODE *node, size_t nsamples )
974
{
975
  /** Derive type T from NODE. */
976
  using T = typename NODE::T;
977
  auto &setup = *(node->setup);
978
  auto &data = node->data;
979
  auto &K = *(setup.K);
980
981
  /** amap contains nsamples of row gids of K. */
982
  auto &amap = data.candidate_rows;
983
984
  /** Clean up candidates from previous iteration. */
985
  amap.clear();
986
987
  /** Construct snids from neighbors. */
988
  if ( setup.NN )
989
  {
990
    //printf( "construct snids NN.row() %lu NN.col() %lu\n",
991
    //    node->setup->NN->row(), node->setup->NN->col() ); fflush( stdout );
992
    auto &NN = *(setup.NN);
993
    auto &gids = node->gids;
994
    auto &snids = data.snids;
995
    size_t knum = NN.row();
996
997
    if ( node->isleaf )
998
    {
999
      snids.clear();
1000
1001
      vector<pair<T, size_t>> tmp( knum * gids.size() );
1002
      for ( size_t j = 0; j < gids.size(); j ++ )
1003
        for ( size_t i = 0; i < knum; i ++ )
1004
          tmp[ j * knum + i ] = NN( i, gids[ j ] );
1005
1006
      /** Create a sorted list. */
1007
      sort( tmp.begin(), tmp.end() );
1008
1009
      /** Each candidate is a pair of (distance, gid). */
1010
      for ( auto it : tmp )
1011
      {
1012
        size_t it_gid = it.second;
1013
        size_t it_morton = setup.morton[ it_gid ];
1014
1015
        if ( snids.size() >= nsamples ) break;
1016
1017
        /** Accept the sample if it does not belong to any near node */
1018
        bool is_near;
1019
        if ( NNPRUNE ) is_near = node->NNNearNodeMortonIDs.count( it_morton );
1020
        else           is_near = (it_morton == node->morton );
1021
1022
        if ( !is_near )
1023
        {
1024
          /** Duplication is handled by std::map. */
1025
          auto ret = snids.insert( make_pair( it.second, it.first ) );
1026
        }
1027
      }
1028
    }
1029
    else
1030
    {
1031
      auto &lsnids = node->lchild->data.snids;
1032
      auto &rsnids = node->rchild->data.snids;
1033
1034
      /** Merge left children's sampling neighbors */
1035
      snids = lsnids;
1036
1037
      /**
1038
       *  TODO: Exclude lsnids (rsnids) that are near rchild (lchild),
1039
       *  perhaps using a NearNodes list defined for interior nodes.
1040
       **/
1041
      /** Merge right child's sample neighbors and update duplicate. */
1042
      for ( auto it = rsnids.begin(); it != rsnids.end(); it ++ )
1043
      {
1044
        auto ret = snids.insert( *it );
1045
        if ( !ret.second )
1046
        {
1047
          if ( ret.first->second > (*it).first )
1048
            ret.first->second = (*it).first;
1049
        }
1050
      }
1051
1052
      /** Remove on-diagonal indices (gids) */
1053
      for ( auto gid : gids ) snids.erase( gid );
1054
    }
1055
1056
1057
    if ( nsamples < K.col() - node->n )
1058
    {
1059
      /** Create an order snids by flipping the std::map */
1060
      multimap<T, size_t> ordered_snids = flip_map( snids );
1061
      /** Reserve space for insertion. */
1062
      amap.reserve( nsamples );
1063
1064
      /** First we use important samples from snids. */
1065
      for ( auto it : ordered_snids )
1066
      {
1067
        if ( amap.size() >= nsamples ) break;
1068
        /** it has type pair<T, size_t> */
1069
        amap.push_back( it.second );
1070
      }
1071
1072
      /** Use uniform samples with replacement if there are not enough samples. */
1073
      while ( amap.size() < nsamples )
1074
      {
1075
        //size_t sample = rand() % K.col();
1076
        auto important_sample = K.ImportantSample( 0 );
1077
        size_t sample_gid = important_sample.second;
1078
        size_t sample_morton = setup.morton[ sample_gid ];
1079
1080
        if ( !MortonHelper::IsMyParent( sample_morton, node->morton ) )
1081
        {
1082
          amap.push_back( sample_gid );
1083
        }
1084
      }
1085
    }
1086
    else /** use all off-diagonal blocks without samples */
1087
    {
1088
      for ( size_t sample = 0; sample < K.col(); sample ++ )
1089
      {
1090
        size_t sample_morton = setup.morton[ sample ];
1091
        if ( !MortonHelper::IsMyParent( sample_morton, node->morton ) )
1092
        {
1093
          amap.push_back( sample );
1094
        }
1095
      }
1096
    }
1097
  } /** end if ( node->setup->NN ) */
1098
1099
}; /** end RowSamples() */
1100
1101
1102
1103
1104
1105
1106
1107
1108
template<bool NNPRUNE, typename NODE>
1109
void SkeletonKIJ( NODE *node )
1110
{
1111
  /** Derive type T from NODE. */
1112
  using T = typename NODE::T;
1113
  /** Gather shared data and create reference. */
1114
  auto &K = *(node->setup->K);
1115
  /** Gather per node data and create reference. */
1116
  auto &data = node->data;
1117
  auto &candidate_rows = data.candidate_rows;
1118
  auto &candidate_cols = data.candidate_cols;
1119
  auto &KIJ = data.KIJ;
1120
  /** This node belongs to the local tree. */
1121
  auto *lchild = node->lchild;
1122
  auto *rchild = node->rchild;
1123
1124
  if ( node->isleaf )
1125
  {
1126
    /** Use all columns. */
1127
    candidate_cols = node->gids;
1128
  }
1129
  else
1130
  {
1131
    auto &lskels = lchild->data.skels;
1132
    auto &rskels = rchild->data.skels;
1133
    /** If either child is not skeletonized, then return. */
1134
    if ( !lskels.size() || !rskels.size() ) return;
1135
    /** Concatinate [ lskels, rskels ]. */
1136
    candidate_cols = lskels;
1137
    candidate_cols.insert( candidate_cols.end(),
1138
        rskels.begin(), rskels.end() );
1139
  }
1140
1141
  /** Decide number of rows to sample. */
1142
  size_t nsamples = 2 * candidate_cols.size();
1143
1144
  /** Make sure we at least m samples. */
1145
  if ( nsamples < 2 * node->setup->LeafNodeSize() )
1146
    nsamples = 2 * node->setup->LeafNodeSize();
1147
1148
  /** Sample off-diagonal rows. */
1149
  RowSamples<NNPRUNE>( node, nsamples );
1150
1151
  /** Compute (or fetch) submatrix KIJ. */
1152
  KIJ = K( candidate_rows, candidate_cols );
1153
1154
1155
1156
}; /** end SkeletonKIJ() */
1157
1158
1159
1160
1161
1162
1163
1164
/**
1165
 *
1166
 */
1167
template<bool NNPRUNE, typename NODE, typename T>
1168
class SkeletonKIJTask : public Task
1169
{
1170
  public:
1171
1172
    NODE *arg = NULL;
1173
1174
    void Set( NODE *user_arg )
1175
    {
1176
      arg = user_arg;
1177
      name = string( "par-gskm" );
1178
      label = to_string( arg->treelist_id );
1179
      /** we don't know the exact cost here */
1180
      cost = 5.0;
1181
      /** high priority */
1182
      priority = true;
1183
    };
1184
1185
    void DependencyAnalysis() { arg->DependOnChildren( this ); };
1186
1187
    void Execute( Worker* user_worker ) { SkeletonKIJ<NNPRUNE>( arg ); };
1188
1189
}; /** end class SkeletonKIJTask */
1190
1191
1192
/** @brief Compress with interpolative decomposition (ID). */
1193
template<typename NODE>
1194
void Skeletonize( NODE *node )
1195
{
1196
  /** Derive type T from NODE. */
1197
  using T = typename NODE::T;
1198
  /** Early return if we do not need to skeletonize. */
1199
  if ( !node->parent ) return;
1200
1201
  /** Gather shared data and create reference. */
1202
  auto &K   = *(node->setup->K);
1203
  auto &NN  = *(node->setup->NN);
1204
  auto maxs = node->setup->MaximumRank();
1205
  auto stol = node->setup->Tolerance();
1206
  bool secure_accuracy = node->setup->SecureAccuracy();
1207
  bool use_adaptive_ranks = node->setup->UseAdaptiveRanks();
1208
1209
  /** Gather per node data and create reference. */
1210
  auto &data  = node->data;
1211
  auto &skels = data.skels;
1212
  auto &proj  = data.proj;
1213
  auto &jpvt  = data.jpvt;
1214
  auto &KIJ   = data.KIJ;
1215
  auto &candidate_cols = data.candidate_cols;
1216
1217
  /** Interpolative decomposition (ID). */
1218
  size_t N = K.col();
1219
  size_t m = KIJ.row();
1220
  size_t n = KIJ.col();
1221
  size_t q = node->n;
1222
1223
  if ( secure_accuracy )
1224
  {
1225
    if ( !node->isleaf && ( !node->lchild->data.isskel || !node->rchild->data.isskel ) )
1226
    {
1227
      skels.clear();
1228
      proj.resize( 0, 0 );
1229
      data.isskel = false;
1230
      return;
1231
    }
1232
  }
1233
1234
  /** Bill's l2 norm scaling factor. */
1235
  T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol;
1236
  /** Account for uniform sampling. */
1237
  scaled_stol *= std::sqrt( (T)q / N );
1238
1239
  /** Call adaptive interpolative decomposition primitive. */
1240
  lowrank::id( use_adaptive_ranks, secure_accuracy,
1241
    KIJ.row(), KIJ.col(), maxs, scaled_stol, KIJ, skels, proj, jpvt );
1242
1243
  /** free KIJ for spaces */
1244
  KIJ.resize( 0, 0 );
1245
1246
  /** Depending on the flag, decide isskel or not. */
1247
  if ( secure_accuracy )
1248
  {
1249
    /** TODO: this needs to be bcast to other nodes */
1250
    data.isskel = (skels.size() != 0);
1251
  }
1252
  else
1253
  {
1254
    assert( skels.size() && proj.size() && jpvt.size() );
1255
    data.isskel = true;
1256
  }
1257
1258
  /** Relabel skeletions with the real gids. */
1259
  for ( size_t i = 0; i < skels.size(); i ++ ) skels[ i ] = candidate_cols[ skels[ i ] ];
1260
1261
}; /** end Skeletonize() */
1262
1263
1264
template<typename NODE, typename T>
1265
class SkeletonizeTask : public Task
1266
{
1267
  public:
1268
1269
    NODE *arg = NULL;
1270
1271
    void Set( NODE *user_arg )
1272
    {
1273
      arg = user_arg;
1274
      name = string( "sk" );
1275
      label = to_string( arg->treelist_id );
1276
      /** we don't know the exact cost here */
1277
      cost = 5.0;
1278
      /** high priority */
1279
      priority = true;
1280
    };
1281
1282
    void GetEventRecord()
1283
    {
1284
      double flops = 0.0, mops = 0.0;
1285
1286
      auto &K = *arg->setup->K;
1287
      size_t n = arg->data.proj.col();
1288
      size_t m = 2 * n;
1289
      size_t k = arg->data.proj.row();
1290
1291
      /** GEQP3 */
1292
      flops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
1293
      mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
1294
1295
      /* TRSM */
1296
      flops += k * ( k - 1 ) * ( n + 1 );
1297
      mops  += 2.0 * ( k * k + k * n );
1298
1299
      //flops += ( 2.0 / 3.0 ) * k * k * ( 3 * m - k );
1300
      //mops += 2.0 * m * k;
1301
      //flops += 2.0 * m * n * k;
1302
      //mops += 2.0 * ( m * k + k * n + m * n );
1303
      //flops += ( 1.0 / 3.0 ) * k * k * n;
1304
      //mops += 2.0 * ( k * k + k * n );
1305
1306
      event.Set( label + name, flops, mops );
1307
      arg->data.skeletonize = event;
1308
    };
1309
1310
    void DependencyAnalysis() { arg->DependOnNoOne( this ); };
1311
1312
    void Execute( Worker* user_worker ) { Skeletonize( arg ); };
1313
1314
}; /** end class SkeletonizeTask */
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
/** @brief Compute skeleton weights for each node. */
1349
template<typename NODE>
1350
void UpdateWeights( NODE *node )
1351
{
1352
  /** Derive type T from NODE. */
1353
  using T = typename NODE::T;
1354
  /** Early return if possible. */
1355
  if ( !node->parent || !node->data.isskel ) return;
1356
1357
  /** Gather shared data and create reference */
1358
  auto &w = *node->setup->w;
1359
1360
  /** Gather per node data and create reference */
1361
  auto &data = node->data;
1362
  auto &proj = data.proj;
1363
  auto &skels = data.skels;
1364
  auto &w_skel = data.w_skel;
1365
  auto &w_leaf = data.w_leaf;
1366
  auto *lchild = node->lchild;
1367
  auto *rchild = node->rchild;
1368
1369
1370
  size_t nrhs = w.col();
1371
1372
  /** w_skel is s-by-nrhs, initial values are not important */
1373
  w_skel.resize( skels.size(), nrhs );
1374
1375
  //printf( "%lu UpdateWeight w_skel.num() %lu\n", node->treelist_id, w_skel.num() );
1376
1377
  if ( node->isleaf )
1378
  {
1379
    if ( w_leaf.size() )
1380
    {
1381
      //printf( "%8lu w_leaf allocated [%lu %lu]\n",
1382
      //    node->morton, w_leaf.row(), w_leaf.col() ); fflush( stdout );
1383
1384
      /** w_leaf is allocated */
1385
      xgemm
1386
      (
1387
        "N", "N",
1388
        w_skel.row(), w_skel.col(), w_leaf.row(),
1389
        1.0, proj.data(),   proj.row(),
1390
             w_leaf.data(), w_leaf.row(),
1391
        0.0, w_skel.data(), w_skel.row()
1392
      );
1393
    }
1394
    else
1395
    {
1396
      /** w_leaf is not allocated, use w_view instead */
1397
      View<T> W = data.w_view;
1398
      //printf( "%8lu n2s W[%lu %lu ld %lu]\n",
1399
      //    node->morton, W.row(), W.col(), W.ld() ); fflush( stdout );
1400
      //for ( int i = 0; i < 10; i ++ )
1401
      //  printf( "%lu W.data() + %d = %E\n", node->gids[ i ], i, *(W.data() + i) );
1402
      xgemm
1403
      (
1404
        "N", "N",
1405
        w_skel.row(), w_skel.col(), W.row(),
1406
        1.0, proj.data(),   proj.row(),
1407
                W.data(),       W.ld(),
1408
        0.0, w_skel.data(), w_skel.row()
1409
      );
1410
    }
1411
1412
    //double update_leaf_time = omp_get_wtime() - beg;
1413
    //printf( "%lu, m %lu n %lu k %lu, total %.3E\n",
1414
    //  node->treelist_id,
1415
    //  w_skel.row(), w_skel.col(), w_leaf.col(), update_leaf_time );
1416
  }
1417
  else
1418
  {
1419
    //double beg = omp_get_wtime();
1420
    auto &w_lskel = lchild->data.w_skel;
1421
    auto &w_rskel = rchild->data.w_skel;
1422
    auto &lskel = lchild->data.skels;
1423
    auto &rskel = rchild->data.skels;
1424
1425
    //if ( 1 )
1426
    if ( node->treelist_id > 6 )
1427
    {
1428
      //printf( "%8lu n2s\n", node->morton ); fflush( stdout );
1429
      xgemm
1430
      (
1431
        "N", "N",
1432
        w_skel.row(), w_skel.col(), lskel.size(),
1433
        1.0,    proj.data(),    proj.row(),
1434
             w_lskel.data(), w_lskel.row(),
1435
        0.0,  w_skel.data(),  w_skel.row()
1436
      );
1437
      xgemm
1438
      (
1439
        "N", "N",
1440
        w_skel.row(), w_skel.col(), rskel.size(),
1441
        1.0,    proj.data() + proj.row() * lskel.size(), proj.row(),
1442
             w_rskel.data(), w_rskel.row(),
1443
        1.0,  w_skel.data(),  w_skel.row()
1444
      );
1445
    }
1446
    else
1447
    {
1448
      /** create a view proj_v */
1449
      View<T> P( false,   proj ), PL,
1450
                                  PR;
1451
      View<T> W( false, w_skel ), WL( false, w_lskel ),
1452
                                  WR( false, w_rskel );
1453
      /** P = [ PL, PR ] */
1454
      P.Partition1x2( PL, PR, lskel.size(), LEFT );
1455
      /** W  = PL * WL */
1456
      gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );
1457
      W.DependencyCleanUp();
1458
      /** W += PR * WR */
1459
      gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)1.0, W );
1460
      //W.DependencyCleanUp();
1461
    }
1462
  }
1463
}; /** end UpdateWeights() */
1464
1465
1466
/**
1467
 *
1468
 */
1469
template<typename NODE, typename T>
1470
class UpdateWeightsTask : public Task
1471
{
1472
  public:
1473
1474
    NODE *arg = NULL;
1475
1476
    void Set( NODE *user_arg )
1477
    {
1478
      arg = user_arg;
1479
      name = string( "n2s" );
1480
      label = to_string( arg->treelist_id );
1481
1482
      /** Compute flops and mops */
1483
      double flops, mops;
1484
      auto &gids = arg->gids;
1485
      auto &skels = arg->data.skels;
1486
      auto &w = *arg->setup->w;
1487
      if ( arg->isleaf )
1488
      {
1489
        auto m = skels.size();
1490
        auto n = w.col();
1491
        auto k = gids.size();
1492
        flops = 2.0 * m * n * k;
1493
        mops = 2.0 * ( m * n + m * k + k * n );
1494
      }
1495
      else
1496
      {
1497
        auto &lskels = arg->lchild->data.skels;
1498
        auto &rskels = arg->rchild->data.skels;
1499
        auto m = skels.size();
1500
        auto n = w.col();
1501
        auto k = lskels.size() + rskels.size();
1502
        flops = 2.0 * m * n * k;
1503
        mops  = 2.0 * ( m * n + m * k + k * n );
1504
      }
1505
1506
      /** Setup the event */
1507
      event.Set( label + name, flops, mops );
1508
      /** Assume computation bound */
1509
      cost = flops / 1E+9;
1510
      /** "HIGH" priority (critical path) */
1511
      priority = true;
1512
    };
1513
1514
    void Prefetch( Worker* user_worker )
1515
    {
1516
      auto &proj = arg->data.proj;
1517
      __builtin_prefetch( proj.data() );
1518
      auto &w_skel = arg->data.w_skel;
1519
      __builtin_prefetch( w_skel.data() );
1520
      if ( arg->isleaf )
1521
      {
1522
        auto &w_leaf = arg->data.w_leaf;
1523
        __builtin_prefetch( w_leaf.data() );
1524
      }
1525
      else
1526
      {
1527
        auto &w_lskel = arg->lchild->data.w_skel;
1528
        __builtin_prefetch( w_lskel.data() );
1529
        auto &w_rskel = arg->rchild->data.w_skel;
1530
        __builtin_prefetch( w_rskel.data() );
1531
      }
1532
#ifdef HMLP_USE_CUDA
1533
      hmlp::Device *device = NULL;
1534
      if ( user_worker ) device = user_worker->GetDevice();
1535
      if ( device )
1536
      {
1537
        proj.CacheD( device );
1538
        proj.PrefetchH2D( device, 1 );
1539
        if ( arg->isleaf )
1540
        {
1541
          auto &w_leaf = arg->data.w_leaf;
1542
          w_leaf.CacheD( device );
1543
          w_leaf.PrefetchH2D( device, 1 );
1544
        }
1545
        else
1546
        {
1547
          auto &w_lskel = arg->lchild->data.w_skel;
1548
          w_lskel.CacheD( device );
1549
          w_lskel.PrefetchH2D( device, 1 );
1550
          auto &w_rskel = arg->rchild->data.w_skel;
1551
          w_rskel.CacheD( device );
1552
          w_rskel.PrefetchH2D( device, 1 );
1553
        }
1554
      }
1555
#endif
1556
    };
1557
1558
    void DependencyAnalysis() { arg->DependOnChildren( this ); };
1559
1560
    void Execute( Worker* user_worker )
1561
    {
1562
#ifdef HMLP_USE_CUDA
1563
      hmlp::Device *device = NULL;
1564
      if ( user_worker ) device = user_worker->GetDevice();
1565
      if ( device ) gpu::UpdateWeights( device, arg );
1566
      else               UpdateWeights<NODE, T>( arg );
1567
#else
1568
      UpdateWeights( arg );
1569
#endif
1570
    };
1571
1572
}; /** end class UpdateWeightsTask */
1573
1574
1575
1576
/**
 *  @brief Skeletons-to-skeletons (S2S) step: compute the interaction from
 *         column skeletons to row skeletons,
 *           u_skel = sum_{b in NNFarNodes} K( skels, b->skels ) * b->w_skel,
 *         and store it in the node. SkeletonsToNodes() later distributes it.
 *
 *  When FarKab is cached, the block for each far node b is the column slice
 *  of FarKab that starts at the running offset (the far nodes are consumed
 *  in NNFarNodes iteration order). Otherwise each block is evaluated from K
 *  on the fly and a diagnostic is printed.
 *  Nodes without a parent, or not skeletonized, are skipped.
 */
template<typename NODE>
void SkeletonsToSkeletons( NODE *node )
{
  /** Early return if possible. */
  if ( !node->parent || !node->data.isskel ) return;

  auto &K      = *node->setup->K;
  auto &amap   = node->data.skels;
  auto &u_skel = node->data.u_skel;
  auto &FarKab = node->data.FarKab;
  size_t nrhs  = node->setup->w->col();

  /** Initialize u_skel to zeros( s, nrhs ). */
  u_skel.resize( 0, 0 );
  u_skel.resize( amap.size(), nrhs, 0.0 );

  /** Column offset of the current Kab slice inside FarKab. */
  size_t offset = 0;

  for ( auto b : node->NNFarNodes )
  {
    auto &bmap   = b->data.skels;
    auto &w_skel = b->data.w_skel;
    assert( w_skel.col() == nrhs );
    assert( w_skel.row() == bmap.size() );
    assert( w_skel.size() == nrhs * bmap.size() );

    if ( FarKab.size() ) /** Kab is cached */
    {
      /** The whole slice FarKab( :, offset : offset + |bmap| ) must exist. */
      assert( FarKab.row() == amap.size() );
      assert( FarKab.col() >= offset + w_skel.row() );
      /** u_skel += Kab * w_skel */
      xgemm( "N", "N",
        u_skel.row(), u_skel.col(), w_skel.row(),
        1.0, FarKab.data() + u_skel.row() * offset, FarKab.row(),
             w_skel.data(),                          w_skel.row(),
        1.0, u_skel.data(),                          u_skel.row() );
      /** Move to the next submatrix Kab. */
      offset += w_skel.row();
    }
    else
    {
      printf( "Far Kab not cached treelist_id %lu, l %lu\n\n",
          node->treelist_id, node->l ); fflush( stdout );
      /** Evaluate submatrix Kab from K, then u_skel += Kab * w_skel. */
      auto Kab = K( amap, bmap );
      xgemm( "N", "N", u_skel.row(), u_skel.col(), w_skel.row(),
        1.0, Kab.data(),    Kab.row(),
             w_skel.data(), w_skel.row(),
        1.0, u_skel.data(), u_skel.row() );
    }
  }
}; /** end SkeletonsToSkeletons() */
1673
1674
1675
1676
/**
1677
 *  @brief There is no dependency between each task. However
1678
 *         there are raw (read after write) dependencies:
1679
 *
1680
 *         NodesToSkeletons (P*w)
1681
 *         SkeletonsToSkeletons ( Sum( Kab * ))
1682
 *
1683
 *  @TODO  The flops and mops of constructing Kab.
1684
 *
1685
 */
1686
template<bool NNPRUNE, typename NODE, typename T>
1687
class SkeletonsToSkeletonsTask : public Task
1688
{
1689
  public:
1690
1691
    NODE *arg = NULL;
1692
1693
    void Set( NODE *user_arg )
1694
    {
1695
      arg = user_arg;
1696
      name = string( "s2s" );
1697
      {
1698
        //label = std::to_string( arg->treelist_id );
1699
        ostringstream ss;
1700
        ss << arg->treelist_id;
1701
        label = ss.str();
1702
      }
1703
1704
      /** compute flops and mops */
1705
      double flops = 0.0, mops = 0.0;
1706
      auto &w = *arg->setup->w;
1707
      size_t m = arg->data.skels.size();
1708
      size_t n = w.col();
1709
1710
      std::set<NODE*> *FarNodes;
1711
      if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
1712
      else           FarNodes = &arg->FarNodes;
1713
1714
      for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
1715
      {
1716
        size_t k = (*it)->data.skels.size();
1717
        flops += 2.0 * m * n * k;
1718
        mops  += m * k; // cost of Kab
1719
        mops  += 2.0 * ( m * n + n * k + k * n );
1720
      }
1721
1722
      /** Setup the event */
1723
      event.Set( label + name, flops, mops );
1724
      /** Assume computation bound */
1725
      cost = flops / 1E+9;
1726
      /** High priority */
1727
      priority = true;
1728
    };
1729
1730
    void DependencyAnalysis()
1731
    {
1732
      for ( auto it : arg->NNFarNodes ) it->DependencyAnalysis( R, this );
1733
      arg->DependencyAnalysis( RW, this );
1734
      this->TryEnqueue();
1735
    };
1736
1737
    void Execute( Worker* user_worker ) { SkeletonsToSkeletons( arg ); };
1738
}; /** end class SkeletonsToSkeletonsTask */
1739
1740
1741
/**
1742
 *  @brief This is a task in Downward traversal. There is data
1743
 *         dependency on u_skel.
1744
 *
1745
 */
1746
template<typename NODE>
1747
void SkeletonsToNodes( NODE *node )
1748
{
1749
  /** Derive type T from NODE. */
1750
  using T = typename NODE::T;
1751
1752
  /** Gather shared data and create reference. */
1753
  auto &K = *node->setup->K;
1754
  auto &w = *node->setup->w;
1755
1756
  /** Gather per node data and create reference. */
1757
  auto &gids = node->gids;
1758
  auto &data = node->data;
1759
  auto &proj = data.proj;
1760
  auto &skels = data.skels;
1761
  auto &u_skel = data.u_skel;
1762
  auto *lchild = node->lchild;
1763
  auto *rchild = node->rchild;
1764
1765
  size_t nrhs = w.col();
1766
1767
1768
1769
1770
1771
  if ( node->isleaf )
1772
  {
1773
    /** Get U view of this node if initialized */
1774
    View<T> U = data.u_view;
1775
1776
    if ( U.col() == nrhs )
1777
    {
1778
      //printf( "%8lu s2n U[%lu %lu %lu]\n",
1779
      //    node->morton, U.row(), U.col(), U.ld() ); fflush( stdout );
1780
      xgemm
1781
      (
1782
        "Transpose", "Non-transpose",
1783
        U.row(), U.col(), u_skel.row(),
1784
        1.0,   proj.data(),   proj.row(),
1785
             u_skel.data(), u_skel.row(),
1786
        1.0,      U.data(),       U.ld()
1787
      );
1788
    }
1789
    else
1790
    {
1791
      //printf( "%8lu use u_leaf u_view [%lu %lu ld %lu]\n",
1792
      //    node->morton, U.row(), U.col(), U.ld()  ); fflush( stdout );
1793
1794
      auto &u_leaf = node->data.u_leaf[ 0 ];
1795
1796
      /** zero-out u_leaf */
1797
      u_leaf.resize( 0, 0 );
1798
      u_leaf.resize( gids.size(), nrhs, 0.0 );
1799
1800
      /** accumulate far interactions */
1801
      if ( data.isskel )
1802
      {
1803
        /** u_leaf += P' * u_skel */
1804
        xgemm
1805
        (
1806
          "T", "N",
1807
          u_leaf.row(), u_leaf.col(), u_skel.row(),
1808
          1.0,   proj.data(),   proj.row(),
1809
               u_skel.data(), u_skel.row(),
1810
          1.0, u_leaf.data(), u_leaf.row()
1811
        );
1812
      }
1813
    }
1814
  }
1815
  else
1816
  {
1817
    if ( !node->parent || !node->data.isskel ) return;
1818
1819
    auto &u_lskel = lchild->data.u_skel;
1820
    auto &u_rskel = rchild->data.u_skel;
1821
    auto &lskel = lchild->data.skels;
1822
    auto &rskel = rchild->data.skels;
1823
1824
    //if ( 1 )
1825
    if ( node->treelist_id > 6 )
1826
    {
1827
      //printf( "%8lu s2n\n", node->morton ); fflush( stdout );
1828
      xgemm
1829
      (
1830
        "Transpose", "No transpose",
1831
        u_lskel.row(), u_lskel.col(), proj.row(),
1832
        1.0, proj.data(),    proj.row(),
1833
             u_skel.data(),  u_skel.row(),
1834
        1.0, u_lskel.data(), u_lskel.row()
1835
      );
1836
      xgemm
1837
      (
1838
        "Transpose", "No transpose",
1839
        u_rskel.row(), u_rskel.col(), proj.row(),
1840
        1.0, proj.data() + proj.row() * lskel.size(), proj.row(),
1841
             u_skel.data(), u_skel.row(),
1842
        1.0, u_rskel.data(), u_rskel.row()
1843
      );
1844
    }
1845
    else
1846
    {
1847
      /** create a transpose view proj_v */
1848
      View<T> P(  true,   proj ), PL,
1849
                                  PR;
1850
      View<T> U( false, u_skel ), UL( false, u_lskel ),
1851
                                  UR( false, u_rskel );
1852
      /** P' = [ PL, PR ]' */
1853
      P.Partition2x1( PL,
1854
                      PR, lskel.size(), TOP );
1855
      /** UL += PL' * U */
1856
      gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
1857
      /** UR += PR' * U */
1858
      gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );
1859
    }
1860
  }
1861
  //printf( "\n" );
1862
1863
}; /** end SkeletonsToNodes() */
1864
1865
1866
template<bool NNPRUNE, typename NODE, typename T>
1867
class SkeletonsToNodesTask : public Task
1868
{
1869
  public:
1870
1871
    NODE *arg = NULL;
1872
1873
    void Set( NODE *user_arg )
1874
    {
1875
      arg = user_arg;
1876
      name = string( "s2n" );
1877
      label = to_string( arg->treelist_id );
1878
1879
      //--------------------------------------
1880
      double flops = 0.0, mops = 0.0;
1881
      auto &gids = arg->gids;
1882
      auto &data = arg->data;
1883
      auto &proj = data.proj;
1884
      auto &skels = data.skels;
1885
      auto &w = *arg->setup->w;
1886
1887
      if ( arg->isleaf )
1888
      {
1889
        size_t m = proj.col();
1890
        size_t n = w.col();
1891
        size_t k = proj.row();
1892
        flops += 2.0 * m * n * k;
1893
        mops  += 2.0 * ( m * n + n * k + m * k );
1894
      }
1895
      else
1896
      {
1897
        if ( !arg->parent || !arg->data.isskel )
1898
        {
1899
          // No computation.
1900
        }
1901
        else
1902
        {
1903
          size_t m = proj.col();
1904
          size_t n = w.col();
1905
          size_t k = proj.row();
1906
          flops += 2.0 * m * n * k;
1907
          mops  += 2.0 * ( m * n + n * k + m * k );
1908
        }
1909
      }
1910
1911
      /** Setup the event */
1912
      event.Set( label + name, flops, mops );
1913
      /** Asuume computation bound */
1914
      cost = flops / 1E+9;
1915
      /** "HIGH" priority (critical path) */
1916
      priority = true;
1917
    };
1918
1919
    void Prefetch( Worker* user_worker )
1920
    {
1921
      auto &proj = arg->data.proj;
1922
      __builtin_prefetch( proj.data() );
1923
      auto &u_skel = arg->data.u_skel;
1924
      __builtin_prefetch( u_skel.data() );
1925
      if ( arg->isleaf )
1926
      {
1927
        //__builtin_prefetch( arg->data.u_leaf[ 0 ].data() );
1928
        //__builtin_prefetch( arg->data.u_leaf[ 1 ].data() );
1929
        //__builtin_prefetch( arg->data.u_leaf[ 2 ].data() );
1930
        //__builtin_prefetch( arg->data.u_leaf[ 3 ].data() );
1931
      }
1932
      else
1933
      {
1934
        auto &u_lskel = arg->lchild->data.u_skel;
1935
        __builtin_prefetch( u_lskel.data() );
1936
        auto &u_rskel = arg->rchild->data.u_skel;
1937
        __builtin_prefetch( u_rskel.data() );
1938
      }
1939
#ifdef HMLP_USE_CUDA
1940
      hmlp::Device *device = NULL;
1941
      if ( user_worker ) device = user_worker->GetDevice();
1942
      if ( device )
1943
      {
1944
        int stream_id = arg->treelist_id % 8;
1945
        proj.CacheD( device );
1946
        proj.PrefetchH2D( device, stream_id );
1947
        u_skel.CacheD( device );
1948
        u_skel.PrefetchH2D( device, stream_id );
1949
        if ( arg->isleaf )
1950
        {
1951
        }
1952
        else
1953
        {
1954
          auto &u_lskel = arg->lchild->data.u_skel;
1955
          u_lskel.CacheD( device );
1956
          u_lskel.PrefetchH2D( device, stream_id );
1957
          auto &u_rskel = arg->rchild->data.u_skel;
1958
          u_rskel.CacheD( device );
1959
          u_rskel.PrefetchH2D( device, stream_id );
1960
        }
1961
      }
1962
#endif
1963
    };
1964
1965
    void DependencyAnalysis() { arg->DependOnParent( this ); };
1966
1967
    void Execute( Worker* user_worker )
1968
    {
1969
#ifdef HMLP_USE_CUDA
1970
      Device *device = NULL;
1971
      if ( user_worker ) device = user_worker->GetDevice();
1972
      if ( device ) gpu::SkeletonsToNodes<NNPRUNE, NODE, T>( device, arg );
1973
      else               SkeletonsToNodes<NNPRUNE, NODE, T>( arg );
1974
#else
1975
    SkeletonsToNodes( arg );
1976
#endif
1977
    };
1978
1979
}; /** end class SkeletonsToNodesTask */
1980
1981
1982
1983
template<int SUBTASKID, bool NNPRUNE, typename NODE, typename T>
1984
void LeavesToLeaves( NODE *node, size_t itbeg, size_t itend )
1985
{
1986
  assert( node->isleaf );
1987
1988
  double beg, u_leaf_time, before_writeback_time, after_writeback_time;
1989
1990
  /** gather shared data and create reference */
1991
  auto &K = *node->setup->K;
1992
  auto &w = *node->setup->w;
1993
1994
  auto &gids = node->gids;
1995
  auto &data = node->data;
1996
  auto &amap = node->gids;
1997
  auto &NearKab = data.NearKab;
1998
1999
  size_t nrhs = w.col();
2000
2001
  set<NODE*> *NearNodes;
2002
  if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2003
  else           NearNodes = &node->NearNodes;
2004
2005
  /** TODO: I think there may be a performance bug here.
2006
   *        Overall there will be 4 task
2007
   **/
2008
  auto &u_leaf = data.u_leaf[ SUBTASKID ];
2009
  u_leaf.resize( 0, 0 );
2010
2011
  /** early return if nothing to do */
2012
  if ( itbeg == itend )
2013
  {
2014
    return;
2015
  }
2016
  else
2017
  {
2018
    u_leaf.resize( gids.size(), nrhs, 0.0 );
2019
  }
2020
2021
  if ( NearKab.size() ) /** Kab is cached */
2022
  {
2023
    size_t itptr = 0;
2024
    size_t offset = 0;
2025
2026
    for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2027
    {
2028
      if ( itbeg <= itptr && itptr < itend )
2029
      {
2030
        //auto wb = w( bmap );
2031
        auto wb = (*it)->data.w_leaf;
2032
2033
        if ( wb.size() )
2034
        {
2035
          /** Kab * wb */
2036
          xgemm
2037
          (
2038
            "N", "N",
2039
            u_leaf.row(), u_leaf.col(), wb.row(),
2040
            1.0, NearKab.data() + offset * NearKab.row(), NearKab.row(),
2041
                      wb.data(),                               wb.row(),
2042
            1.0,  u_leaf.data(),                           u_leaf.row()
2043
          );
2044
        }
2045
        else
2046
        {
2047
          View<T> W = (*it)->data.w_view;
2048
          xgemm
2049
          (
2050
            "N", "N",
2051
            u_leaf.row(), u_leaf.col(), W.row(),
2052
            1.0, NearKab.data() + offset * NearKab.row(), NearKab.row(),
2053
                       W.data(),                                 W.ld(),
2054
            1.0,  u_leaf.data(),                           u_leaf.row()
2055
          );
2056
        }
2057
      }
2058
      offset += (*it)->gids.size();
2059
      itptr ++;
2060
    }
2061
  }
2062
  else /** TODO: make xgemm into NN instead of NT. Kab is not cached */
2063
  {
2064
    size_t itptr = 0;
2065
    for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2066
    {
2067
      if ( itbeg <= itptr && itptr < itend )
2068
      {
2069
        auto &bmap = (*it)->gids;
2070
        auto wb = (*it)->data.w_leaf;
2071
2072
        /** evaluate the submatrix */
2073
        auto Kab = K( amap, bmap );
2074
2075
        if ( wb.size() )
2076
        {
2077
          /** ( Kab * wb )' = wb' * Kab' */
2078
          xgemm( "N", "N", u_leaf.row(), u_leaf.col(), wb.row(),
2079
            1.0,    Kab.data(),    Kab.row(),
2080
                     wb.data(),     wb.row(),
2081
            1.0, u_leaf.data(), u_leaf.row());
2082
        }
2083
        else
2084
        {
2085
          View<T> W = (*it)->data.w_view;
2086
          xgemm( "N", "N", u_leaf.row(), u_leaf.col(), W.row(),
2087
            1.0,    Kab.data(),    Kab.row(),
2088
                      W.data(),       W.ld(),
2089
            1.0, u_leaf.data(), u_leaf.row() );
2090
        }
2091
      }
2092
      itptr ++;
2093
    }
2094
  }
2095
  before_writeback_time = omp_get_wtime() - beg;
2096
2097
}; /** end LeavesToLeaves() */
2098
2099
2100
template<int SUBTASKID, bool NNPRUNE, typename NODE, typename T>
2101
class LeavesToLeavesTask : public Task
2102
{
2103
  public:
2104
2105
    NODE *arg = NULL;
2106
2107
    size_t itbeg;
2108
2109
	  size_t itend;
2110
2111
    void Set( NODE *user_arg )
2112
    {
2113
      arg = user_arg;
2114
      name = string( "l2l" );
2115
      label = to_string( arg->treelist_id );
2116
2117
      /** TODO: fill in flops and mops */
2118
      //--------------------------------------
2119
      double flops = 0.0, mops = 0.0;
2120
      auto &gids = arg->gids;
2121
      auto &data = arg->data;
2122
      auto &proj = data.proj;
2123
      auto &skels = data.skels;
2124
      auto &w = *arg->setup->w;
2125
      auto &K = *arg->setup->K;
2126
      auto &NearKab = data.NearKab;
2127
2128
      assert( arg->isleaf );
2129
2130
      size_t m = gids.size();
2131
      size_t n = w.col();
2132
2133
      set<NODE*> *NearNodes;
2134
      if ( NNPRUNE ) NearNodes = &arg->NNNearNodes;
2135
      else           NearNodes = &arg->NearNodes;
2136
2137
      /** TODO: need to better decide the range [itbeg itend] */
2138
      size_t itptr = 0;
2139
      size_t itrange = ( NearNodes->size() + 3 ) / 4;
2140
      if ( itrange < 1 ) itrange = 1;
2141
      itbeg = ( SUBTASKID - 1 ) * itrange;
2142
      itend = ( SUBTASKID + 0 ) * itrange;
2143
      if ( itbeg > NearNodes->size() ) itbeg = NearNodes->size();
2144
      if ( itend > NearNodes->size() ) itend = NearNodes->size();
2145
      if ( SUBTASKID == 4 ) itend = NearNodes->size();
2146
2147
      for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2148
      {
2149
        if ( itbeg <= itptr && itptr < itend )
2150
        {
2151
          size_t k = (*it)->gids.size();
2152
          flops += 2.0 * m * n * k;
2153
          mops += m * k;
2154
          mops += 2.0 * ( m * n + n * k + m * k );
2155
        }
2156
        itptr ++;
2157
      }
2158
2159
      /** setup the event */
2160
      event.Set( label + name, flops, mops );
2161
2162
      /** asuume computation bound */
2163
      cost = flops / 1E+9;
2164
    };
2165
2166
    void Prefetch( Worker* user_worker )
2167
    {
2168
      auto &u_leaf = arg->data.u_leaf[ SUBTASKID ];
2169
      __builtin_prefetch( u_leaf.data() );
2170
    };
2171
2172
    void GetEventRecord()
2173
    {
2174
      /** create l2l event */
2175
      //arg->data.s2n = event;
2176
    };
2177
2178
    void DependencyAnalysis()
2179
    {
2180
      assert( arg->isleaf );
2181
      /** depends on nothing */
2182
      this->TryEnqueue();
2183
2184
      /** impose rw dependencies on multiple copies */
2185
      //auto &u_leaf = arg->data.u_leaf[ SUBTASKID ];
2186
      //u_leaf.DependencyAnalysis( hmlp::ReadWriteType::W, this );
2187
    };
2188
2189
    void Execute( Worker* user_worker )
2190
    {
2191
      LeavesToLeaves<SUBTASKID, NNPRUNE, NODE, T>( arg, itbeg, itend );
2192
    };
2193
2194
}; /** end class LeavesToLeaves */
2195
2196
2197
2198
2199
/**
 *  @brief Debug helper: print the treelist_id of every node in `nodes` on one
 *         line as "id, id, ..., " followed by a newline.
 *
 *  @param nodes  set of tree nodes; only read.
 */
template<typename NODE>
void PrintSet( const std::set<NODE*> &nodes )
{
  for ( const NODE *node : nodes )
  {
    /** Cast so the argument always matches %lu whatever integer type treelist_id has. */
    printf( "%lu, ", static_cast<unsigned long>( node->treelist_id ) );
  }
  printf( "\n" );
}; /** end PrintSet() */
2208
2209
2210
2211
2212
2213
/**
2214
 *
2215
 */
2216
template<typename NODE>
2217
multimap<size_t, size_t> NearNodeBallots( NODE *node )
2218
{
2219
  /** Must be a leaf node. */
2220
  assert( node->isleaf );
2221
2222
  auto &setup = *(node->setup);
2223
  auto &NN = *(setup.NN);
2224
  auto &gids = node->gids;
2225
2226
  /** Ballot table ( node MortonID, ids ) */
2227
  map<size_t, size_t> ballot;
2228
2229
  size_t HasMissingNeighbors = 0;
2230
2231
2232
  /** Loop over all neighbors and insert them into tables. */
2233
  for ( size_t j = 0; j < gids.size(); j ++ )
2234
  {
2235
    for ( size_t i = 0; i < NN.row(); i ++ )
2236
    {
2237
      auto value = NN( i, gids[ j ] ).first;
2238
      size_t neighbor_gid = NN( i, gids[ j ] ).second;
2239
      /** If this gid is valid, then compute its morton */
2240
      if ( neighbor_gid >= 0 && neighbor_gid < NN.col() )
2241
      {
2242
        size_t neighbor_morton = setup.morton[ neighbor_gid ];
2243
        size_t weighted_ballot = 1.0 / ( value + 1E-3 );
2244
        //printf( "gid %lu i %lu neighbor_gid %lu morton %lu\n", gids[ j ], i,
2245
        //    neighbor_gid, neighbor_morton );
2246
2247
        if (  i < NN.row() / 2 )
2248
        {
2249
          if ( ballot.find( neighbor_morton ) != ballot.end() )
2250
          {
2251
            //ballot[ neighbor_morton ] ++;
2252
            ballot[ neighbor_morton ] += weighted_ballot;
2253
          }
2254
          else
2255
          {
2256
            //ballot[ neighbor_morton ] = 1;
2257
            ballot[ neighbor_morton ] = weighted_ballot;
2258
          }
2259
        }
2260
      }
2261
      else
2262
      {
2263
        HasMissingNeighbors ++;
2264
      }
2265
    }
2266
  }
2267
2268
  if ( HasMissingNeighbors )
2269
  {
2270
    printf( "Missing %lu neighbor pairs\n", HasMissingNeighbors );
2271
    fflush( stdout );
2272
  }
2273
2274
  /** Flip ballot to create sorted_ballot. */
2275
  return flip_map( ballot );
2276
2277
}; /** end NearNodeBallots() */
2278
2279
2280
2281
2282
template<typename NODE, typename T>
2283
void NearSamples( NODE *node )
2284
{
2285
  auto &setup = *(node->setup);
2286
  auto &NN = *(setup.NN);
2287
2288
  if ( node->isleaf )
2289
  {
2290
    auto &gids = node->gids;
2291
    //double budget = setup.budget;
2292
    double budget = setup.Budget();
2293
    size_t n_nodes = ( 1 << node->l );
2294
2295
    /** Add myself to the near interaction list.  */
2296
    node->NearNodes.insert( node );
2297
    node->NNNearNodes.insert( node );
2298
    node->NNNearNodeMortonIDs.insert( node->morton );
2299
2300
    /** Compute ballots for all near interactions */
2301
    multimap<size_t, size_t> sorted_ballot = NearNodeBallots( node );
2302
2303
    /** Insert near node cadidates until reaching the budget limit. */
2304
    for ( auto it = sorted_ballot.rbegin(); it != sorted_ballot.rend(); it ++ )
2305
    {
2306
      /** Exit if we have enough. */
2307
      if ( node->NNNearNodes.size() >= n_nodes * budget ) break;
2308
      /** Insert */
2309
      auto *target = (*node->morton2node)[ (*it).second ];
2310
      node->NNNearNodeMortonIDs.insert( (*it).second );
2311
      node->NNNearNodes.insert( target );
2312
    }
2313
  }
2314
2315
}; /** void NearSamples() */
2316
2317
2318
2319
template<typename NODE, typename T>
2320
class NearSamplesTask : public Task
2321
{
2322
  public:
2323
2324
    NODE *arg = NULL;
2325
2326
    void Set( NODE *user_arg )
2327
    {
2328
      arg = user_arg;
2329
      name = string( "near" );
2330
2331
      //--------------------------------------
2332
      double flops = 0.0, mops = 0.0;
2333
2334
      /** setup the event */
2335
      event.Set( label + name, flops, mops );
2336
      /** asuume computation bound */
2337
      cost = 1.0;
2338
      /** low priority */
2339
      priority = true;
2340
    }
2341
2342
    void DependencyAnalysis() { this->TryEnqueue(); };
2343
2344
    void Execute( Worker* user_worker )
2345
    {
2346
      NearSamples<NODE, T>( arg );
2347
    };
2348
2349
}; /** end class NearSamplesTask */
2350
2351
2352
/**
 *  @brief Make the pruned near lists of all leaves symmetric.
 *
 *  For every leaf `node` and every MortonID in node->NNNearNodeMortonIDs, the
 *  owning leaf `target` receives `node` in target->NNNearNodes and node's
 *  MortonID in target->NNNearNodeMortonIDs, so both lists stay consistent.
 */
template<typename TREE>
void SymmetrizeNearInteractions( TREE & tree )
{
  const size_t n_nodes = static_cast<size_t>( 1 ) << tree.depth;
  auto level_beg = tree.treelist.begin() + n_nodes - 1;

  for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
  {
    auto *node = *(level_beg + node_ind);
    for ( auto morton : node->NNNearNodeMortonIDs )
    {
      /** at() reports an unknown MortonID instead of inserting a null node. */
      auto *target = tree.morton2node.at( morton );
      target->NNNearNodes.insert( node );
      /** Record node's own MortonID (not target's) so the ID list mirrors NNNearNodes. */
      target->NNNearNodeMortonIDs.insert( node->morton );
    }
  }
}; /** end SymmetrizeNearInteractions() */
2370
2371
2372
/** @brief Task wrapper for CacheNearNodes(). */
2373
template<bool NNPRUNE, typename NODE>
2374
class CacheNearNodesTask : public Task
2375
{
2376
  public:
2377
2378
    NODE *arg;
2379
2380
    void Set( NODE *user_arg )
2381
    {
2382
      arg = user_arg;
2383
      name = string( "c-n" );
2384
      label = to_string( arg->treelist_id );
2385
      /** asuume computation bound */
2386
      cost = 1.0;
2387
    };
2388
2389
    void GetEventRecord()
2390
    {
2391
      double flops = 0.0, mops = 0.0;
2392
2393
      NODE *node = arg;
2394
      auto *NearNodes = &node->NearNodes;
2395
      if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2396
      auto &K = *node->setup->K;
2397
2398
      size_t m = node->gids.size();
2399
      size_t n = 0;
2400
      for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2401
      {
2402
        n += (*it)->gids.size();
2403
      }
2404
      /** setup the event */
2405
      event.Set( label + name, flops, mops );
2406
    };
2407
2408
    void DependencyAnalysis() { arg->DependOnNoOne( this ); };
2409
2410
    void Execute( Worker* user_worker )
2411
    {
2412
      //printf( "%lu CacheNearNodes beg\n", arg->treelist_id ); fflush( stdout );
2413
2414
      NODE *node = arg;
2415
      auto *NearNodes = &node->NearNodes;
2416
      if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2417
      auto &K = *node->setup->K;
2418
      auto &data = node->data;
2419
      auto &amap = node->gids;
2420
      vector<size_t> bmap;
2421
      for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2422
      {
2423
        bmap.insert( bmap.end(), (*it)->gids.begin(), (*it)->gids.end() );
2424
      }
2425
      data.NearKab = K( amap, bmap );
2426
2427
      /** */
2428
      data.Nearbmap.resize( bmap.size(), 1 );
2429
      for ( size_t i = 0; i < bmap.size(); i ++ )
2430
        data.Nearbmap[ i ] = bmap[ i ];
2431
2432
#ifdef HMLP_USE_CUDA
2433
      auto *device = hmlp_get_device( 0 );
2434
      /** prefetch Nearbmap to GPU */
2435
      node->data.Nearbmap.PrefetchH2D( device, 8 );
2436
2437
      size_t preserve_size = 3000000000;
2438
      //if ( data.NearKab.col() * MAX_NRHS < 1200000000 &&
2439
      //     data.NearKab.size() * 8 + preserve_size < device->get_memory_left() &&
2440
      //     data.NearKab.size() * 8 > 4096 * 4096 * 8 * 4 )
2441
      if ( data.NearKab.col() * MAX_NRHS < 1200000000 &&
2442
           data.NearKab.size() * 8 + preserve_size < device->get_memory_left() )
2443
      {
2444
        /** prefetch NearKab to GPU */
2445
        data.NearKab.PrefetchH2D( device, 8 );
2446
      }
2447
      else
2448
      {
2449
        printf( "Kab %lu %lu not cache\n", data.NearKab.row(), data.NearKab.col() );
2450
      }
2451
#endif
2452
2453
      //printf( "%lu CacheNearNodesTask end\n", arg->treelist_id ); fflush( stdout );
2454
    };
2455
}; /** end class CacheNearNodesTask */
2456
2457
2458
2459
/**
 *  @brief One top-down pass of FindFarNodes() for a single near list.
 *
 *  NNPRUNE == false (HSS): uses target->NearNodes and fills target->FarNodes.
 *  NNPRUNE == true  (FMM): uses target->NNNearNodes and fills
 *                          target->NNFarNodes; with a symmetric setup, nodes
 *                          whose morton is smaller than target's are skipped
 *                          because that interaction is computed from the
 *                          other side.
 */
template<bool NNPRUNE, typename NODE>
void FindFarNodesPass( NODE *node, NODE *target )
{
  auto &NearNodes = NNPRUNE ? target->NNNearNodes : target->NearNodes;

  /** If this node contains any Near( target ) or isn't skeletonized, recurse. */
  if ( !node->data.isskel || node->ContainAny( NearNodes ) )
  {
    if ( !node->isleaf )
    {
      FindFarNodesPass<NNPRUNE>( node->lchild, target );
      FindFarNodesPass<NNPRUNE>( node->rchild, target );
    }
    return;
  }

  if ( !NNPRUNE )
  {
    target->FarNodes.insert( node );
  }
  else if ( !( node->setup->IsSymmetric() && ( node->morton < target->morton ) ) )
  {
    target->NNFarNodes.insert( node );
  }
}; /** end FindFarNodesPass() */


/**
 *  @brief (FMM specific) find Far( target ) by traversing all treenodes
 *         top-down.
 *         If the visiting ``node'' does not contain any near node
 *         of ``target'' (by MORTON helper function ContainAny() ),
 *         then we add ``node'' to Far( target ).
 *
 *         Otherwise, recurse to two children.
 *
 *         Both lists are built: FarNodes from NearNodes (HSS) and NNFarNodes
 *         from NNNearNodes (FMM). Each list gets its own traversal, so every
 *         node is visited at most once per list.
 */
template<typename NODE>
void FindFarNodes( NODE *node, NODE *target )
{
  /** all assertions, ``target'' must be a leaf node */
  assert( target->isleaf );

  FindFarNodesPass<false>( node, target );
  FindFarNodesPass<true>( node, target );

}; /** end FindFarNodes() */
2540
2541
2542
2543
/**
 *  @brief (FMM specific) bottom-up construction of the Far lists.
 *
 *  Levels are visited from tree.depth up to the root; nodes without
 *  skeletons are skipped. A skeletonized leaf builds its lists with
 *  FindFarNodes() from the root. A skeletonized inner node takes the
 *  intersection of its children's lists as its own and removes those entries
 *  from both children, for FarNodes (HSS) and NNFarNodes (FMM) alike.
 *  With a symmetric setup NNFarNodes is then made symmetric.
 *
 *  @TODO  change to task.
 */
template<typename TREE>
void MergeFarNodes( TREE &tree )
{
  for ( int level = tree.depth; level >= 0; level -- )
  {
    const size_t count = ( 1 << level );
    auto first = tree.treelist.begin() + count - 1;

    for ( size_t k = 0; k < count; k ++ )
    {
      auto *node = first[ k ];

      /** if I don't have any skeleton, then I'm nobody's far field */
      if ( !node->data.isskel ) continue;

      if ( node->isleaf )
      {
        FindFarNodes( tree.treelist[ 0 ] /** root */, node );
        continue;
      }

      auto *left  = node->lchild;
      auto *right = node->rchild;

      /** HSS: Far( parent ) = Far( lchild ) intersects Far( rchild ), then remove it from both children. */
      for ( auto *cand : left->FarNodes )
        if ( right->FarNodes.count( cand ) ) node->FarNodes.insert( cand );
      for ( auto *shared : node->FarNodes )
      {
        left->FarNodes.erase( shared );
        right->FarNodes.erase( shared );
      }

      /** FMM: same merge on the pruned lists. */
      for ( auto *cand : left->NNFarNodes )
        if ( right->NNFarNodes.count( cand ) ) node->NNFarNodes.insert( cand );
      for ( auto *shared : node->NNFarNodes )
      {
        left->NNFarNodes.erase( shared );
        right->NNFarNodes.erase( shared );
      }
    }
  }

  /** symmetrize FarNodes to FarNodes interaction */
  if ( tree.setup.IsSymmetric() )
  {
    for ( int level = tree.depth; level >= 0; level -- )
    {
      const std::size_t count = 1 << level;
      auto first = tree.treelist.begin() + count - 1;
      for ( size_t k = 0; k < count; k ++ )
      {
        auto *node = first[ k ];
        for ( auto *other : node->NNFarNodes ) other->NNFarNodes.insert( node );
      }
    }
  }

#ifdef DEBUG_SPDASKIT
  for ( int level = tree.depth; level >= 0; level -- )
  {
    const std::size_t count = 1 << level;
    auto first = tree.treelist.begin() + count - 1;
    for ( size_t k = 0; k < count; k ++ )
    {
      auto *node = first[ k ];
      auto &pFarNodes = node->NNFarNodes;
      for ( auto *other : pFarNodes )
      {
        if ( other->NNFarNodes.count( node ) ) continue;
        printf( "Unsymmetric FarNodes %lu, %lu\n", node->treelist_id, other->treelist_id );
        printf( "Near\n" );
        PrintSet(  node->NNNearNodes );
        PrintSet( other->NNNearNodes );
        printf( "Far\n" );
        PrintSet(  node->NNFarNodes );
        PrintSet( other->NNFarNodes );
        printf( "======\n" );
        break;
      }
      if ( pFarNodes.size() )
      {
        printf( "l %2lu FarNodes(%lu) ", node->l, node->treelist_id );
        PrintSet( pFarNodes );
      }
    }
  }
#endif
};
2677
2678
2679
/**
2680
 *  @brief Evaluate and store all submatrices Kba used in the Far
2681
 *         interaction.
2682
 *
2683
 *  @TODO  Take care the HSS case i.e. (!NNPRUNE)
2684
 *
2685
 */
2686
template<bool NNPRUNE, bool CACHE = true, typename TREE>
2687
void CacheFarNodes( TREE &tree )
2688
{
2689
  /** reserve space for w_leaf and u_leaf */
2690
  #pragma omp parallel for schedule( dynamic )
2691
  for ( size_t i = 0; i < tree.treelist.size(); i ++ )
2692
  {
2693
    auto *node = tree.treelist[ i ];
2694
    if ( node->isleaf )
2695
    {
2696
      node->data.w_leaf.reserve( node->gids.size(), MAX_NRHS );
2697
      node->data.u_leaf[ 0 ].reserve( MAX_NRHS, node->gids.size() );
2698
    }
2699
  }
2700
2701
  /** cache Kab by request */
2702
  if ( CACHE )
2703
  {
2704
    /** cache FarKab */
2705
    #pragma omp parallel for schedule( dynamic )
2706
    for ( size_t i = 0; i < tree.treelist.size(); i ++ )
2707
    {
2708
      auto *node = tree.treelist[ i ];
2709
      auto *FarNodes = &node->FarNodes;
2710
      if ( NNPRUNE ) FarNodes = &node->NNFarNodes;
2711
      auto &K = *node->setup->K;
2712
      auto &data = node->data;
2713
      auto &amap = data.skels;
2714
      std::vector<size_t> bmap;
2715
      for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
2716
      {
2717
        bmap.insert( bmap.end(), (*it)->data.skels.begin(),
2718
                                 (*it)->data.skels.end() );
2719
      }
2720
      data.FarKab = K( amap, bmap );
2721
    }
2722
  }
2723
}; /** end CacheFarNodes() */
2724
2725
2726
/**
 *  @brief Write the MATLAB script "interaction.m" that draws the interaction
 *         pattern, and return the fraction of the n-by-n matrix evaluated
 *         exactly.
 *
 *  With NNPRUNE, each far pair ( node, Far( node ) ) is drawn as a rectangle
 *  shaded by the shallower of the two levels, and each near pair
 *  ( node, Near( node ) ) as a blue rectangle; near pairs also add
 *  |node| * |near| to the exact count. Without NNPRUNE nothing is drawn or
 *  counted.
 *
 *  If the script cannot be opened, a message goes to stderr, drawing is
 *  skipped, and the ratio is still returned.
 *
 *  @return ( sum over near pairs of |node| * |near| ) / ( tree.n * tree.n ).
 */
template<bool NNPRUNE, typename TREE>
double DrawInteraction( TREE &tree )
{
  double exact_ratio = 0.0;

  FILE *pFile = fopen( "interaction.m", "w" );
  if ( !pFile )
  {
    fprintf( stderr, "DrawInteraction: cannot open interaction.m\n" );
  }
  else
  {
    fprintf( pFile, "figure('Position',[100,100,800,800]);" );
    fprintf( pFile, "hold on;" );
    fprintf( pFile, "axis square;" );
    fprintf( pFile, "axis ij;" );
  }

  for ( int l = tree.depth; l >= 0; l -- )
  {
    std::size_t n_nodes = static_cast<std::size_t>( 1 ) << l;
    auto level_beg = tree.treelist.begin() + n_nodes - 1;

    for ( std::size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
    {
      auto *node = *(level_beg + node_ind);

      if ( !NNPRUNE ) continue;

      auto &pNearNodes = node->NNNearNodes;
      auto &pFarNodes = node->NNFarNodes;
      for ( auto it = pFarNodes.begin(); it != pFarNodes.end(); it ++ )
      {
        /** Guard depth 0, where the shade would otherwise be 0 / 0. */
        double gb = tree.depth ?
          (double)std::min( node->l, (*it)->l ) / tree.depth : 0.0;
        if ( pFile )
        {
          fprintf( pFile, "rectangle('position',[%lu %lu %lu %lu],'facecolor',[1.0,%lf,%lf]);\n",
              static_cast<unsigned long>( node->offset ),
              static_cast<unsigned long>( (*it)->offset ),
              static_cast<unsigned long>( node->gids.size() ),
              static_cast<unsigned long>( (*it)->gids.size() ),
              gb, gb );
        }
      }
      for ( auto it = pNearNodes.begin(); it != pNearNodes.end(); it ++ )
      {
        if ( pFile )
        {
          fprintf( pFile, "rectangle('position',[%lu %lu %lu %lu],'facecolor',[0.2,0.4,1.0]);\n",
              static_cast<unsigned long>( node->offset ),
              static_cast<unsigned long>( (*it)->offset ),
              static_cast<unsigned long>( node->gids.size() ),
              static_cast<unsigned long>( (*it)->gids.size() ) );
        }

        /** accumulate exact evaluation */
        exact_ratio += (double)node->gids.size() * (*it)->gids.size();
      }
    }
  }

  if ( pFile )
  {
    fprintf( pFile, "hold off;" );
    fclose( pFile );
  }

  return exact_ratio / ( (double)tree.n * tree.n );
}; /** end DrawInteration() */
2786
2787
2788
2789
2790
/**
2791
 *  @breif This is a fake evaluation setup aimming to figure out
2792
 *         which tree node will prun which points. The result
2793
 *         will be stored in each node as two lists, prune and noprune.
2794
 *
2795
 */
2796
template<bool SYMBOLIC, bool NNPRUNE, typename NODE, typename T>
2797
void Evaluate
2798
(
2799
  NODE *node,
2800
  size_t gid,
2801
  vector<size_t> &nnandi, // k + 1 non-prunable lists
2802
  Data<T> &potentials
2803
)
2804
{
2805
  auto &K = *node->setup->K;
2806
  auto &w = *node->setup->w;
2807
  auto &gids = node->gids;
2808
  auto &data = node->data;
2809
  auto *lchild = node->lchild;
2810
  auto *rchild = node->rchild;
2811
2812
  size_t nrhs = w.col();
2813
2814
  auto amap = std::vector<size_t>( 1 );
2815
  amap[ 0 ] = gid;
2816
2817
  if ( !SYMBOLIC ) // No potential evaluation.
2818
  {
2819
    assert( potentials.size() == amap.size() * nrhs );
2820
  }
2821
2822
  if ( !data.isskel || node->ContainAny( nnandi ) )
2823
  {
2824
    if ( node->isleaf )
2825
    {
2826
      if ( SYMBOLIC )
2827
      {
2828
        /** add gid to notprune list. We use a lock */
2829
        data.lock.Acquire();
2830
        {
2831
          if ( NNPRUNE ) node->NNNearIDs.insert( gid );
2832
          else           node->NearIDs.insert(   gid );
2833
        }
2834
        data.lock.Release();
2835
      }
2836
      else
2837
      {
2838
#ifdef DEBUG_SPDASKIT
2839
        printf( "level %lu direct evaluation\n", node->l );
2840
#endif
2841
        /** amap.size()-by-gids.size() */
2842
        auto Kab = K( amap, gids );
2843
2844
        /** all right hand sides */
2845
        std::vector<size_t> bmap( nrhs );
2846
        for ( size_t j = 0; j < bmap.size(); j ++ )
2847
          bmap[ j ] = j;
2848
2849
        /** gids.size()-by-nrhs */
2850
        auto wb  = w( gids, bmap );
2851
2852
        xgemm
2853
        (
2854
          "N", "N",
2855
          Kab.row(), wb.col(), wb.row(),
2856
          1.0, Kab.data(),        Kab.row(),
2857
                wb.data(),         wb.row(),
2858
          1.0, potentials.data(), potentials.row()
2859
        );
2860
      }
2861
    }
2862
    else
2863
    {
2864
      Evaluate<SYMBOLIC, NNPRUNE>( lchild, gid, nnandi, potentials );
2865
      Evaluate<SYMBOLIC, NNPRUNE>( rchild, gid, nnandi, potentials );
2866
    }
2867
  }
2868
  else // need gid's morton and neighbors' mortons
2869
  {
2870
    //printf( "level %lu is prunable\n", node->l );
2871
    if ( SYMBOLIC )
2872
    {
2873
      data.lock.Acquire();
2874
      {
2875
        // Add gid to prunable list.
2876
        if ( NNPRUNE ) node->FarIDs.insert(   gid );
2877
        else           node->NNFarIDs.insert( gid );
2878
      }
2879
      data.lock.Release();
2880
    }
2881
    else
2882
    {
2883
#ifdef DEBUG_SPDASKIT
2884
      printf( "level %lu is prunable\n", node->l );
2885
#endif
2886
      auto Kab = K( amap, node->data.skels );
2887
      auto &w_skel = node->data.w_skel;
2888
      xgemm
2889
      (
2890
        "N", "N",
2891
        Kab.row(), w_skel.col(), w_skel.row(),
2892
        1.0, Kab.data(),        Kab.row(),
2893
          w_skel.data(),     w_skel.row(),
2894
        1.0, potentials.data(), potentials.row()
2895
      );
2896
    }
2897
  }
2898
2899
2900
2901
}; /** end Evaluate() */
2902
2903
2904
/** @brief Evaluate potentials( gid ) using treecode.
2905
 *         Notice in this case, the approximation is unsymmetric.
2906
 *
2907
 **/
2908
template<bool SYMBOLIC, bool NNPRUNE, typename TREE, typename T>
2909
void Evaluate
2910
(
2911
  TREE &tree,
2912
  size_t gid,
2913
  Data<T> &potentials
2914
)
2915
{
2916
  vector<size_t> nnandi;
2917
  auto &w = *tree.setup.w;
2918
2919
  potentials.clear();
2920
  potentials.resize( 1, w.col(), 0.0 );
2921
2922
  if ( NNPRUNE )
2923
  {
2924
    auto &NN = *tree.setup.NN;
2925
    nnandi.reserve( NN.row() + 1 );
2926
    nnandi.push_back( gid );
2927
    for ( size_t i = 0; i < NN.row(); i ++ )
2928
    {
2929
      nnandi.push_back( NN( i, gid ).second );
2930
    }
2931
#ifdef DEBUG_SPDASKIT
2932
    printf( "nnandi.size() %lu\n", nnandi.size() );
2933
#endif
2934
  }
2935
  else
2936
  {
2937
    nnandi.reserve( 1 );
2938
    nnandi.push_back( gid );
2939
  }
2940
2941
  Evaluate<SYMBOLIC, NNPRUNE>( tree.treelist[ 0 ], gid, nnandi, potentials );
2942
2943
}; /** end Evaluate() */
2944
2945
2946
/**
2947
 *  @brief ComputeAll
2948
 */
2949
template<
2950
  bool     USE_RUNTIME = true,
2951
  bool     USE_OMP_TASK = false,
2952
  bool     NNPRUNE = true,
2953
  bool     CACHE = true,
2954
  typename TREE,
2955
  typename T>
2956
Data<T> Evaluate
2957
(
2958
  TREE &tree,
2959
  Data<T> &weights
2960
)
2961
{
2962
  const bool AUTO_DEPENDENCY = true;
2963
2964
  /** get type NODE = TREE::NODE */
2965
  using NODE = typename TREE::NODE;
2966
2967
  /** all timers */
2968
  double beg, time_ratio, evaluation_time = 0.0;
2969
  double allocate_time, computeall_time;
2970
  double forward_permute_time, backward_permute_time;
2971
2972
  /** clean up all r/w dependencies left on tree nodes */
2973
  tree.DependencyCleanUp();
2974
2975
  /** n-by-nrhs initialize potentials */
2976
  size_t n    = weights.row();
2977
  size_t nrhs = weights.col();
2978
2979
  beg = omp_get_wtime();
2980
  hmlp::Data<T> potentials( n, nrhs, 0.0 );
2981
  tree.setup.w = &weights;
2982
  tree.setup.u = &potentials;
2983
  allocate_time = omp_get_wtime() - beg;
2984
2985
  /** permute weights into w_leaf */
2986
  if ( REPORT_EVALUATE_STATUS )
2987
  {
2988
    printf( "Forward permute ...\n" ); fflush( stdout );
2989
  }
2990
  beg = omp_get_wtime();
2991
  int n_nodes = ( 1 << tree.depth );
2992
  auto level_beg = tree.treelist.begin() + n_nodes - 1;
2993
  #pragma omp parallel for
2994
  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2995
  {
2996
    auto *node = *(level_beg + node_ind);
2997
2998
2999
    auto &gids = node->gids;
3000
    auto &w_leaf = node->data.w_leaf;
3001
3002
    if ( w_leaf.row() != gids.size() || w_leaf.col() != weights.col() )
3003
    {
3004
      w_leaf.resize( gids.size(), weights.col() );
3005
    }
3006
3007
    for ( size_t j = 0; j < w_leaf.col(); j ++ )
3008
    {
3009
      for ( size_t i = 0; i < w_leaf.row(); i ++ )
3010
      {
3011
        w_leaf( i, j ) = weights( gids[ i ], j );
3012
      }
3013
    };
3014
  }
3015
  forward_permute_time = omp_get_wtime() - beg;
3016
3017
3018
3019
  /** Compute all N2S, S2S, S2N, L2L */
3020
  if ( REPORT_EVALUATE_STATUS )
3021
  {
3022
    printf( "N2S, S2S, S2N, L2L (HMLP Runtime) ...\n" ); fflush( stdout );
3023
  }
3024
  if ( tree.setup.IsSymmetric() )
3025
  {
3026
    beg = omp_get_wtime();
3027
#ifdef HMLP_USE_CUDA
3028
    potentials.AllocateD( hmlp_get_device( 0 ) );
3029
    using LEAFTOLEAFVER2TASK = gpu::LeavesToLeavesVer2Task<CACHE, NNPRUNE, NODE, T>;
3030
    LEAFTOLEAFVER2TASK leaftoleafver2task;
3031
#endif
3032
    using LEAFTOLEAFTASK1 = LeavesToLeavesTask<1, NNPRUNE, NODE, T>;
3033
    using LEAFTOLEAFTASK2 = LeavesToLeavesTask<2, NNPRUNE, NODE, T>;
3034
    using LEAFTOLEAFTASK3 = LeavesToLeavesTask<3, NNPRUNE, NODE, T>;
3035
    using LEAFTOLEAFTASK4 = LeavesToLeavesTask<4, NNPRUNE, NODE, T>;
3036
3037
    using NODETOSKELTASK  = UpdateWeightsTask<NODE, T>;
3038
    using SKELTOSKELTASK  = SkeletonsToSkeletonsTask<NNPRUNE, NODE, T>;
3039
    using SKELTONODETASK  = SkeletonsToNodesTask<NNPRUNE, NODE, T>;
3040
3041
    LEAFTOLEAFTASK1 leaftoleaftask1;
3042
    LEAFTOLEAFTASK2 leaftoleaftask2;
3043
    LEAFTOLEAFTASK3 leaftoleaftask3;
3044
    LEAFTOLEAFTASK4 leaftoleaftask4;
3045
3046
    NODETOSKELTASK  nodetoskeltask;
3047
    SKELTOSKELTASK  skeltoskeltask;
3048
    SKELTONODETASK  skeltonodetask;
3049
3050
3051
//    if ( USE_OMP_TASK )
3052
//    {
3053
//      assert( !USE_RUNTIME );
3054
//      tree.template TraverseLeafs<false, false>( leaftoleaftask1 );
3055
//      tree.template TraverseLeafs<false, false>( leaftoleaftask2 );
3056
//      tree.template TraverseLeafs<false, false>( leaftoleaftask3 );
3057
//      tree.template TraverseLeafs<false, false>( leaftoleaftask4 );
3058
//      tree.template UpDown<true, true, true>( nodetoskeltask, skeltoskeltask, skeltonodetask );
3059
//    }
3060
//    else
3061
//    {
3062
//      assert( !USE_OMP_TASK );
3063
//
3064
//#ifdef HMLP_USE_CUDA
3065
//      tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleafver2task );
3066
//#else
3067
//      tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask1 );
3068
//      tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask2 );
3069
//      tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask3 );
3070
//      tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask4 );
3071
//#endif
3072
//
3073
//      /** check scheduler */
3074
//      //hmlp_get_runtime_handle()->scheduler->ReportRemainingTime();
3075
//      tree.template TraverseUp       <AUTO_DEPENDENCY, USE_RUNTIME>( nodetoskeltask );
3076
//      tree.template TraverseUnOrdered<AUTO_DEPENDENCY, USE_RUNTIME>( skeltoskeltask );
3077
//      tree.template TraverseDown     <AUTO_DEPENDENCY, USE_RUNTIME>( skeltonodetask );
3078
//      /** check scheduler */
3079
//      //hmlp_get_runtime_handle()->scheduler->ReportRemainingTime();
3080
//
3081
//      if ( USE_RUNTIME ) hmlp_run();
3082
//
3083
//
3084
//
3085
//#ifdef HMLP_USE_CUDA
3086
//      hmlp::Device *device = hmlp_get_device( 0 );
3087
//      for ( int stream_id = 0; stream_id < 10; stream_id ++ )
3088
//        device->wait( stream_id );
3089
//      //potentials.PrefetchD2H( device, 0 );
3090
//      potentials.FetchD2H( device );
3091
//#endif
3092
//    }
3093
3094
3095
3096
3097
    /** CPU-GPU hybrid uses a different kind of L2L task */
3098
#ifdef HMLP_USE_CUDA
3099
    tree.TraverseLeafs( leaftoleafver2task );
3100
#else
3101
    tree.TraverseLeafs( leaftoleaftask1 );
3102
    tree.TraverseLeafs( leaftoleaftask2 );
3103
    tree.TraverseLeafs( leaftoleaftask3 );
3104
    tree.TraverseLeafs( leaftoleaftask4 );
3105
#endif
3106
    tree.TraverseUp( nodetoskeltask );
3107
    tree.TraverseUnOrdered( skeltoskeltask );
3108
    tree.TraverseDown( skeltonodetask );
3109
    tree.ExecuteAllTasks();
3110
    //if ( USE_RUNTIME ) hmlp_run();
3111
3112
3113
    double d2h_beg_t = omp_get_wtime();
3114
#ifdef HMLP_USE_CUDA
3115
    hmlp::Device *device = hmlp_get_device( 0 );
3116
    for ( int stream_id = 0; stream_id < 10; stream_id ++ )
3117
      device->wait( stream_id );
3118
    //potentials.PrefetchD2H( device, 0 );
3119
    potentials.FetchD2H( device );
3120
#endif
3121
    double d2h_t = omp_get_wtime() - d2h_beg_t;
3122
    printf( "d2h_t %lfs\n", d2h_t );
3123
3124
3125
    double aggregate_beg_t = omp_get_wtime();
3126
    /** reduce direct iteractions from 4 copies */
3127
    #pragma omp parallel for
3128
    for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
3129
    {
3130
      auto *node = *(level_beg + node_ind);
3131
      auto &u_leaf = node->data.u_leaf[ 0 ];
3132
      /** reduce all u_leaf[0:4] */
3133
      for ( size_t p = 1; p < 20; p ++ )
3134
      {
3135
        for ( size_t i = 0; i < node->data.u_leaf[ p ].size(); i ++ )
3136
          u_leaf[ i ] += node->data.u_leaf[ p ][ i ];
3137
      }
3138
    }
3139
    double aggregate_t = omp_get_wtime() - aggregate_beg_t;
3140
    printf( "aggregate_t %lfs\n", d2h_t );
3141
3142
#ifdef HMLP_USE_CUDA
3143
    device->wait( 0 );
3144
#endif
3145
    computeall_time = omp_get_wtime() - beg;
3146
  }
3147
  else // TODO: implement unsymmetric prunning
3148
  {
3149
    /** Not yet implemented. */
3150
    printf( "Non symmetric ComputeAll is not yet implemented\n" );
3151
    exit( 1 );
3152
  }
3153
3154
3155
3156
  /** permute back */
3157
  if ( REPORT_EVALUATE_STATUS )
3158
  {
3159
    printf( "Backward permute ...\n" ); fflush( stdout );
3160
  }
3161
  beg = omp_get_wtime();
3162
  #pragma omp parallel for
3163
  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
3164
  {
3165
    auto *node = *(level_beg + node_ind);
3166
    auto &amap = node->gids;
3167
    auto &u_leaf = node->data.u_leaf[ 0 ];
3168
3169
3170
3171
    /** assemble u_leaf back to u */
3172
    //for ( size_t j = 0; j < amap.size(); j ++ )
3173
    //  for ( size_t i = 0; i < potentials.row(); i ++ )
3174
    //    potentials[ amap[ j ] * potentials.row() + i ] += u_leaf( j, i );
3175
3176
3177
    for ( size_t j = 0; j < potentials.col(); j ++ )
3178
      for ( size_t i = 0; i < amap.size(); i ++ )
3179
        potentials( amap[ i ], j ) += u_leaf( i, j );
3180
3181
3182
  }
3183
  backward_permute_time = omp_get_wtime() - beg;
3184
3185
  evaluation_time += allocate_time;
3186
  evaluation_time += forward_permute_time;
3187
  evaluation_time += computeall_time;
3188
  evaluation_time += backward_permute_time;
3189
  time_ratio = 100 / evaluation_time;
3190
3191
  if ( REPORT_EVALUATE_STATUS )
3192
  {
3193
    printf( "========================================================\n");
3194
    printf( "GOFMM evaluation phase\n" );
3195
    printf( "========================================================\n");
3196
    printf( "Allocate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3197
        allocate_time, allocate_time * time_ratio );
3198
    printf( "Forward permute ----------------------- %5.2lfs (%5.1lf%%)\n",
3199
        forward_permute_time, forward_permute_time * time_ratio );
3200
    printf( "N2S, S2S, S2N, L2L -------------------- %5.2lfs (%5.1lf%%)\n",
3201
        computeall_time, computeall_time * time_ratio );
3202
    printf( "Backward permute ---------------------- %5.2lfs (%5.1lf%%)\n",
3203
        backward_permute_time, backward_permute_time * time_ratio );
3204
    printf( "========================================================\n");
3205
    printf( "Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3206
        evaluation_time, evaluation_time * time_ratio );
3207
    printf( "========================================================\n\n");
3208
  }
3209
3210
3211
  /** clean up all r/w dependencies left on tree nodes */
3212
  tree.DependencyCleanUp();
3213
3214
  /** return nrhs-by-N outputs */
3215
  return potentials;
3216
3217
}; /** end Evaluate() */
3218
3219
3220
3221
template<typename SPLITTER, typename T, typename SPDMATRIX>
3222
Data<pair<T, size_t>> FindNeighbors
3223
(
3224
  SPDMATRIX &K,
3225
  SPLITTER splitter,
3226
	Configuration<T> &config,
3227
  size_t n_iter = 10
3228
)
3229
{
3230
  /** Instantiation for the randomisze tree. */
3231
  using DATA  = gofmm::NodeData<T>;
3232
  using SETUP = gofmm::Setup<SPDMATRIX, SPLITTER, T>;
3233
  using TREE  = tree::Tree<SETUP, DATA>;
3234
  /** Derive type NODE from TREE. */
3235
  using NODE  = typename TREE::NODE;
3236
  /** Get all user-defined parameters. */
3237
  DistanceMetric metric = config.MetricType();
3238
  size_t n = config.ProblemSize();
3239
	size_t k = config.NeighborSize();
3240
  /** Iterative all nearnest-neighbor (ANN). */
3241
  pair<T, size_t> init( numeric_limits<T>::max(), n );
3242
  gofmm::NeighborsTask<NODE, T> NEIGHBORStask;
3243
  TREE rkdt;
3244
  rkdt.setup.FromConfiguration( config, K, splitter, NULL );
3245
  return rkdt.AllNearestNeighbor( n_iter, k, n_iter, init, NEIGHBORStask );
3246
}; /** end FindNeighbors() */
3247
3248
3249
/**
3250
 *  @brielf template of the compress routine
3251
 */
3252
template<typename SPLITTER, typename RKDTSPLITTER, typename T, typename SPDMATRIX>
3253
tree::Tree< gofmm::Setup<SPDMATRIX, SPLITTER, T>, gofmm::NodeData<T>>
3254
*Compress
3255
(
3256
  SPDMATRIX &K,
3257
  Data<pair<T, size_t>> &NN,
3258
  SPLITTER splitter,
3259
  RKDTSPLITTER rkdtsplitter,
3260
	Configuration<T> &config
3261
)
3262
{
3263
  /** Get all user-defined parameters. */
3264
  DistanceMetric metric = config.MetricType();
3265
  size_t n = config.ProblemSize();
3266
	size_t m = config.LeafNodeSize();
3267
	size_t k = config.NeighborSize();
3268
	size_t s = config.MaximumRank();
3269
  T stol = config.Tolerance();
3270
	T budget = config.Budget();
3271
3272
  /** options */
3273
  const bool NNPRUNE   = true;
3274
  const bool CACHE     = true;
3275
3276
  /** instantiation for the Spd-Askit tree */
3277
  using SETUP = gofmm::Setup<SPDMATRIX, SPLITTER, T>;
3278
  using DATA  = gofmm::NodeData<T>;
3279
  using TREE  = tree::Tree<SETUP, DATA>;
3280
  /** Derive type NODE from TREE. */
3281
  using NODE  = typename TREE::NODE;
3282
3283
3284
  /** all timers */
3285
  double beg, omptask45_time, omptask_time, ref_time;
3286
  double time_ratio, compress_time = 0.0, other_time = 0.0;
3287
  double ann_time, tree_time, skel_time, mergefarnodes_time, cachefarnodes_time;
3288
  double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
3289
3290
  /** Iterative all nearnest-neighbor (ANN). */
3291
  beg = omp_get_wtime();
3292
  if ( NN.size() != n * k )
3293
  {
3294
    NN = gofmm::FindNeighbors( K, rkdtsplitter, config );
3295
  }
3296
  ann_time = omp_get_wtime() - beg;
3297
3298
3299
  /** Initialize metric ball tree using approximate center split. */
3300
  auto *tree_ptr = new TREE();
3301
	auto &tree = *tree_ptr;
3302
  tree.setup.FromConfiguration( config, K, splitter, &NN );
3303
3304
3305
  if ( REPORT_COMPRESS_STATUS )
3306
  {
3307
    printf( "TreePartitioning ...\n" ); fflush( stdout );
3308
  }
3309
  beg = omp_get_wtime();
3310
  tree.TreePartition();
3311
  tree_time = omp_get_wtime() - beg;
3312
3313
3314
#ifdef HMLP_AVX512
3315
  /** if we are using KNL, use nested omp construct */
3316
  assert( omp_get_max_threads() == 68 );
3317
  //mkl_set_dynamic( 0 );
3318
  //mkl_set_num_threads( 4 );
3319
  hmlp_set_num_workers( 17 );
3320
#else
3321
  //if ( omp_get_max_threads() > 8 )
3322
  //{
3323
  //  hmlp_set_num_workers( omp_get_max_threads() / 2 );
3324
  //}
3325
  if ( REPORT_COMPRESS_STATUS )
3326
  {
3327
    printf( "omp_get_max_threads() %d\n", omp_get_max_threads() );
3328
  }
3329
#endif
3330
3331
3332
3333
3334
  /** Build near interaction lists. */
3335
  NearSamplesTask<NODE, T> NEARSAMPLEStask;
3336
  tree.DependencyCleanUp();
3337
  printf( "Dependency clean up\n" ); fflush( stdout );
3338
  tree.TraverseLeafs( NEARSAMPLEStask );
3339
  tree.ExecuteAllTasks();
3340
  //hmlp_run();
3341
  printf( "Finish NearSamplesTask\n" ); fflush( stdout );
3342
  SymmetrizeNearInteractions( tree );
3343
  printf( "Finish SymmetrizeNearInteractions\n" ); fflush( stdout );
3344
3345
3346
3347
  /** Skeletonization */
3348
  if ( REPORT_COMPRESS_STATUS )
3349
  {
3350
    printf( "Skeletonization (HMLP Runtime) ...\n" ); fflush( stdout );
3351
  }
3352
  beg = omp_get_wtime();
3353
  gofmm::SkeletonKIJTask<NNPRUNE, NODE, T> GETMTXtask;
3354
  gofmm::SkeletonizeTask<NODE, T> SKELtask;
3355
  gofmm::InterpolateTask<NODE> PROJtask;
3356
  tree.DependencyCleanUp();
3357
  tree.TraverseUp( GETMTXtask, SKELtask );
3358
  tree.TraverseUnOrdered( PROJtask );
3359
  if ( CACHE )
3360
  {
3361
    gofmm::CacheNearNodesTask<NNPRUNE, NODE> KIJtask;
3362
    tree.template TraverseLeafs( KIJtask );
3363
  }
3364
  other_time += omp_get_wtime() - beg;
3365
  hmlp_run();
3366
  skel_time = omp_get_wtime() - beg;
3367
3368
3369
3370
3371
  /** MergeFarNodes */
3372
  beg = omp_get_wtime();
3373
  if ( REPORT_COMPRESS_STATUS )
3374
  {
3375
    printf( "MergeFarNodes ...\n" ); fflush( stdout );
3376
  }
3377
  gofmm::MergeFarNodes( tree );
3378
  mergefarnodes_time = omp_get_wtime() - beg;
3379
3380
  /** CacheFarNodes */
3381
  beg = omp_get_wtime();
3382
  if ( REPORT_COMPRESS_STATUS )
3383
  {
3384
    printf( "CacheFarNodes ...\n" ); fflush( stdout );
3385
  }
3386
  gofmm::CacheFarNodes<NNPRUNE, CACHE>( tree );
3387
  cachefarnodes_time = omp_get_wtime() - beg;
3388
3389
  /** plot iteraction matrix */
3390
  auto exact_ratio = hmlp::gofmm::DrawInteraction<true>( tree );
3391
3392
  compress_time += ann_time;
3393
  compress_time += tree_time;
3394
  compress_time += skel_time;
3395
  compress_time += mergefarnodes_time;
3396
  compress_time += cachefarnodes_time;
3397
  time_ratio = 100.0 / compress_time;
3398
  if ( REPORT_COMPRESS_STATUS )
3399
  {
3400
    printf( "========================================================\n");
3401
    printf( "GOFMM compression phase\n" );
3402
    printf( "========================================================\n");
3403
    printf( "NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
3404
    printf( "TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
3405
    printf( "Skeletonization ----------------------- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
3406
    printf( "MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
3407
    printf( "CacheFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
3408
    printf( "========================================================\n");
3409
    printf( "Compress (%4.2lf not compressed) -------- %5.2lfs (%5.1lf%%)\n",
3410
        exact_ratio, compress_time, compress_time * time_ratio );
3411
    printf( "========================================================\n\n");
3412
  }
3413
3414
  /** Clean up all r/w dependencies left on tree nodes. */
3415
  tree_ptr->DependencyCleanUp();
3416
3417
  /** Return the hierarhical compreesion of K as a binary tree. */
3418
  return tree_ptr;
3419
3420
}; /** end Compress() */
3421
3422
3423
3424
3425
3426
3427
3428
3429
3430
3431
3432
3433
3434
3435
3436
3437
3438
3439
3440
3441
3442
3443
3444
3445
3446
3447
3448
3449
3450
3451
3452
3453
3454
3455
3456
3457
3458
3459
3460
3461
3462
/**
3463
 *  @brielf A simple template for the compress routine.
3464
 */
3465
template<typename T, typename SPDMATRIX>
3466
tree::Tree<
3467
  gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
3468
  gofmm::NodeData<T>>
3469
*Compress( SPDMATRIX &K, T stol, T budget, size_t m, size_t k, size_t s )
3470
{
3471
  using SPLITTER     = centersplit<SPDMATRIX, 2, T>;
3472
  using RKDTSPLITTER = randomsplit<SPDMATRIX, 2, T>;
3473
  Data<pair<T, size_t>> NN;
3474
	/** GOFMM tree splitter */
3475
  SPLITTER splitter( K );
3476
  splitter.Kptr = &K;
3477
	splitter.metric = ANGLE_DISTANCE;
3478
	/** randomized tree splitter */
3479
  RKDTSPLITTER rkdtsplitter( K );
3480
  rkdtsplitter.Kptr = &K;
3481
	rkdtsplitter.metric = ANGLE_DISTANCE;
3482
  size_t n = K.row();
3483
3484
	/** creatgin configuration for all user-define arguments */
3485
	Configuration<T> config( ANGLE_DISTANCE, n, m, k, s, stol, budget );
3486
3487
	/** call the complete interface and return tree_ptr */
3488
  return Compress<SPLITTER, RKDTSPLITTER>
3489
         ( K, NN, //ANGLE_DISTANCE,
3490
					 splitter, rkdtsplitter, //n, m, k, s, stol, budget,
3491
					 config );
3492
}; /** end Compress */
3493
3494
3495
3496
3497
3498
3499
3500
3501
/**
3502
 *  @brielf A simple template for the compress routine.
3503
 */
3504
template<typename T, typename SPDMATRIX>
3505
tree::Tree<
3506
  gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
3507
  gofmm::NodeData<T>>
3508
*Compress( SPDMATRIX &K, T stol, T budget )
3509
{
3510
  using SPLITTER     = centersplit<SPDMATRIX, 2, T>;
3511
  using RKDTSPLITTER = randomsplit<SPDMATRIX, 2, T>;
3512
  Data<pair<T, std::size_t>> NN;
3513
	/** GOFMM tree splitter */
3514
  SPLITTER splitter( K );
3515
  splitter.Kptr = &K;
3516
  splitter.metric = ANGLE_DISTANCE;
3517
	/** randomized tree splitter */
3518
  RKDTSPLITTER rkdtsplitter( K );
3519
  rkdtsplitter.Kptr = &K;
3520
  rkdtsplitter.metric = ANGLE_DISTANCE;
3521
  size_t n = K.row();
3522
  size_t m = 128;
3523
  size_t k = 16;
3524
  size_t s = m;
3525
3526
  /** */
3527
  if ( n >= 16384 )
3528
  {
3529
    m = 128;
3530
    k = 20;
3531
    s = 256;
3532
  }
3533
3534
  if ( n >= 32768 )
3535
  {
3536
    m = 256;
3537
    k = 24;
3538
    s = 384;
3539
  }
3540
3541
  if ( n >= 65536 )
3542
  {
3543
    m = 512;
3544
    k = 32;
3545
    s = 512;
3546
  }
3547
3548
	/** creatgin configuration for all user-define arguments */
3549
	Configuration<T> config( ANGLE_DISTANCE, n, m, k, s, stol, budget );
3550
3551
	/** call the complete interface and return tree_ptr */
3552
  return Compress<SPLITTER, RKDTSPLITTER>
3553
         ( K, NN, //ANGLE_DISTANCE,
3554
					 splitter, rkdtsplitter, config );
3555
3556
}; /** end Compress() */
3557
3558
/**
3559
 *
3560
 */
3561
template<typename T>
3562
tree::Tree<
3563
  gofmm::Setup<SPDMatrix<T>, centersplit<SPDMatrix<T>, 2, T>, T>,
3564
  gofmm::NodeData<T>>
3565
*Compress( SPDMatrix<T> &K, T stol, T budget )
3566
{
3567
	return Compress<T, SPDMatrix<T>>( K, stol, budget );
3568
}; /** end Compress() */
3569
3570
3571
3572
3573
3574
3575
3576
template<typename NODE, typename T>
3577
void ComputeError( NODE *node, Data<T> potentials )
3578
{
3579
  auto &K = *node->setup->K;
3580
  auto &w = node->setup->w;
3581
3582
  auto &amap = node->gids;
3583
  std::vector<size_t> bmap = std::vector<size_t>( K.col() );
3584
3585
  for ( size_t j = 0; j < bmap.size(); j ++ ) bmap[ j ] = j;
3586
3587
  auto Kab = K( amap, bmap );
3588
3589
  auto nrm2 = hmlp_norm( potentials.row(), potentials.col(),
3590
                         potentials.data(), potentials.row() );
3591
3592
  xgemm
3593
  (
3594
    "N", "T",
3595
    Kab.row(), w.row(), w.col(),
3596
    -1.0, Kab.data(),        Kab.row(),
3597
          w.data(),          w.row(),
3598
     1.0, potentials.data(), potentials.row()
3599
  );
3600
3601
  auto err = hmlp_norm( potentials.row(), potentials.col(),
3602
                        potentials.data(), potentials.row() );
3603
3604
  printf( "node relative error %E, nrm2 %E\n", err / nrm2, nrm2 );
3605
3606
3607
}; // end void ComputeError()
3608
3609
3610
3611
3612
3613
3614
3615
3616
/**
3617
 *  @brief
3618
 */
3619
template<typename TREE, typename T>
3620
T ComputeError( TREE &tree, size_t gid, Data<T> potentials )
3621
{
3622
  auto &K = *tree.setup.K;
3623
  auto &w = *tree.setup.w;
3624
3625
  auto amap = std::vector<size_t>( 1, gid );
3626
  auto bmap = std::vector<size_t>( K.col() );
3627
  for ( size_t j = 0; j < bmap.size(); j ++ ) bmap[ j ] = j;
3628
3629
  auto Kab = K( amap, bmap );
3630
  auto exact = potentials;
3631
3632
  xgemm
3633
  (
3634
    "N", "N",
3635
    Kab.row(), w.col(), w.row(),
3636
    1.0,   Kab.data(),   Kab.row(),
3637
             w.data(),     w.row(),
3638
    0.0, exact.data(), exact.row()
3639
  );
3640
3641
3642
  auto nrm2 = hmlp_norm( exact.row(),  exact.col(),
3643
                         exact.data(), exact.row() );
3644
3645
  xgemm
3646
  (
3647
    "N", "N",
3648
    Kab.row(), w.col(), w.row(),
3649
    -1.0, Kab.data(),       Kab.row(),
3650
          w.data(),          w.row(),
3651
     1.0, potentials.data(), potentials.row()
3652
  );
3653
3654
  auto err = hmlp_norm( potentials.row(), potentials.col(),
3655
                        potentials.data(), potentials.row() );
3656
3657
  return err / nrm2;
3658
}; /** end ComputeError() */
3659
3660
3661
3662
template<typename TREE>
3663
void SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
3664
{
3665
  /** Derive type T from TREE. */
3666
  using T = typename TREE::T;
3667
  /** Size of right hand sides. */
3668
  size_t n = tree.n;
3669
  /** Shrink ntest if ntest > n. */
3670
  if ( ntest > n ) ntest = n;
3671
  /** all_rhs = [ 0, 1, ..., nrhs - 1 ]. */
3672
  vector<size_t> all_rhs( nrhs );
3673
  for ( size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;
3674
3675
  //auto A = tree.CheckAllInteractions();
3676
3677
  /** Evaluate u ~ K * w. */
3678
  Data<T> w( n, nrhs ); w.rand();
3679
  auto u = Evaluate<true, false, true, true>( tree, w );
3680
3681
  /** Examine accuracy with 3 setups, ASKIT, HODLR, and GOFMM. */
3682
  T nnerr_avg = 0.0;
3683
  T nonnerr_avg = 0.0;
3684
  T fmmerr_avg = 0.0;
3685
  printf( "========================================================\n");
3686
  printf( "Accuracy report\n" );
3687
  printf( "========================================================\n");
3688
  for ( size_t i = 0; i < ntest; i ++ )
3689
  {
3690
    size_t tar = i * n / ntest;
3691
    Data<T> potentials;
3692
    /** ASKIT treecode with NN pruning. */
3693
    Evaluate<false, true>( tree, tar, potentials );
3694
    auto nnerr = ComputeError( tree, tar, potentials );
3695
    /** ASKIT treecode without NN pruning. */
3696
    Evaluate<false, false>( tree, tar, potentials );
3697
    auto nonnerr = ComputeError( tree, tar, potentials );
3698
    /** Get results from GOFMM */
3699
    //potentials = u( vector<size_t>( i ), all_rhs );
3700
    for ( size_t p = 0; p < potentials.col(); p ++ )
3701
    {
3702
      potentials[ p ] = u( tar, p );
3703
    }
3704
    auto fmmerr = ComputeError( tree, tar, potentials );
3705
3706
    /** Only print 10 values. */
3707
    if ( i < 10 )
3708
    {
3709
      printf( "gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
3710
          tar, nnerr, nonnerr, fmmerr );
3711
    }
3712
    nnerr_avg += nnerr;
3713
    nonnerr_avg += nonnerr;
3714
    fmmerr_avg += fmmerr;
3715
  }
3716
  printf( "========================================================\n");
3717
  printf( "            ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
3718
      nnerr_avg / ntest , nonnerr_avg / ntest, fmmerr_avg / ntest );
3719
  printf( "========================================================\n");
3720
3721
  if ( !tree.setup.SecureAccuracy() )
3722
  {
3723
    /** Factorization */
3724
    T lambda = 5.0;
3725
    gofmm::Factorize( tree, lambda );
3726
    /** Compute error. */
3727
    gofmm::ComputeError( tree, lambda, w, u );
3728
  }
3729
3730
}; /** end SelfTesting() */
3731
3732
3733
/** @brief Instantiate the splitters here. */
3734
template<typename SPDMATRIX>
3735
void LaunchHelper( SPDMATRIX &K, CommandLineHelper &cmd )
3736
{
3737
  using T = typename SPDMATRIX::T;
3738
3739
  const int N_CHILDREN = 2;
3740
  /** Use geometric-oblivious splitters. */
3741
  using SPLITTER     = gofmm::centersplit<SPDMATRIX, N_CHILDREN, T>;
3742
  using RKDTSPLITTER = gofmm::randomsplit<SPDMATRIX, N_CHILDREN, T>;
3743
  /** GOFMM tree splitter. */
3744
  SPLITTER splitter( K );
3745
  splitter.Kptr = &K;
3746
  splitter.metric = cmd.metric;
3747
  /** Randomized tree splitter. */
3748
  RKDTSPLITTER rkdtsplitter( K );
3749
  rkdtsplitter.Kptr = &K;
3750
  rkdtsplitter.metric = cmd.metric;
3751
	/** Create configuration for all user-define arguments. */
3752
  gofmm::Configuration<T> config( cmd.metric,
3753
      cmd.n, cmd.m, cmd.k, cmd.s, cmd.stol, cmd.budget );
3754
  /** (Optional) provide neighbors, leave uninitialized otherwise. */
3755
  Data<pair<T, size_t>> NN;
3756
  /** Compress K. */
3757
  //auto *tree_ptr = gofmm::Compress( X, K, NN, splitter, rkdtsplitter, config );
3758
  auto *tree_ptr = gofmm::Compress( K, NN, splitter, rkdtsplitter, config );
3759
	auto &tree = *tree_ptr;
3760
  /** Examine accuracies. */
3761
  gofmm::SelfTesting( tree, 100, cmd.nrhs );
3762
3763
3764
//  //#ifdef DUMP_ANALYSIS_DATA
3765
//  gofmm::Summary<NODE> summary;
3766
//  tree.Summary( summary );
3767
//  summary.Print();
3768
3769
	/** delete tree_ptr */
3770
  delete tree_ptr;
3771
}; /** end LaunchHelper() */
3772
3773
3774
3775
3776
3777
3778
template<typename T, typename SPDMATRIX>
3779
class SimpleGOFMM
3780
{
3781
  public:
3782
3783
    SimpleGOFMM( SPDMATRIX &K, T stol, T budget )
3784
    {
3785
      tree_ptr = Compress( K, stol, budget );
3786
    };
3787
3788
    ~SimpleGOFMM()
3789
    {
3790
      if ( tree_ptr ) delete tree_ptr;
3791
    };
3792
3793
    void Multiply( Data<T> &y, Data<T> &x )
3794
    {
3795
      //hmlp::Data<T> weights( x.col(), x.row() );
3796
3797
      //for ( size_t j = 0; j < x.col(); j ++ )
3798
      //  for ( size_t i = 0; i < x.row(); i ++ )
3799
      //    weights( j, i ) = x( i, j );
3800
3801
3802
      y = gofmm::Evaluate( *tree_ptr, x );
3803
      //auto potentials = hmlp::gofmm::Evaluate( *tree_ptr, weights );
3804
3805
      //for ( size_t j = 0; j < y.col(); j ++ )
3806
      //  for ( size_t i = 0; i < y.row(); i ++ )
3807
      //    y( i, j ) = potentials( j, i );
3808
3809
    };
3810
3811
  private:
3812
3813
    /** GOFMM tree */
3814
    tree::Tree<
3815
      gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
3816
      gofmm::NodeData<T>> *tree_ptr = NULL;
3817
3818
}; /** end class SimpleGOFMM */
3819
3820
3821
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
3832
3833
/**
3834
 *  Instantiation types for double and single precision
3835
 */
3836
typedef SPDMatrix<double> dSPDMatrix_t;
3837
typedef SPDMatrix<float > sSPDMatrix_t;
3838
3839
typedef hmlp::gofmm::Setup<SPDMatrix<double>,
3840
    centersplit<SPDMatrix<double>, 2, double>, double> dSetup_t;
3841
3842
typedef hmlp::gofmm::Setup<SPDMatrix<float>,
3843
    centersplit<SPDMatrix<float >, 2,  float>,  float> sSetup_t;
3844
3845
typedef tree::Tree<dSetup_t, gofmm::NodeData<double>> dTree_t;
3846
typedef tree::Tree<sSetup_t, gofmm::NodeData<float >> sTree_t;
3847
3848
3849
3850
3851
3852
/**
3853
 *  PyCompress prototype. Notice that all pass-by-reference
3854
 *  arguments are replaced by pass-by-pointer. There implementaion
3855
 *  can be found at hmlp/package/$HMLP_ARCH/gofmm.gpp
3856
 **/
3857
Data<double> Evaluate( dTree_t *tree, Data<double> *weights );
3858
Data<float>  Evaluate( dTree_t *tree, Data<float > *weights );
3859
3860
dTree_t *Compress( dSPDMatrix_t *K, double stol, double budget );
3861
sTree_t *Compress( sSPDMatrix_t *K,  float stol,  float budget );
3862
3863
dTree_t *Compress( dSPDMatrix_t *K, double stol, double budget,
3864
		size_t m, size_t k, size_t s );
3865
sTree_t *Compress( sSPDMatrix_t *K,  float stol,  float budget,
3866
		size_t m, size_t k, size_t s );
3867
3868
3869
double ComputeError( dTree_t *tree, size_t gid, hmlp::Data<double> *potentials );
3870
float  ComputeError( sTree_t *tree, size_t gid, hmlp::Data<float>  *potentials );
3871
3872
3873
3874
3875
3876
3877
3878
3879
3880
3881
}; /** end namespace gofmm */
3882
}; /** end namespace hmlp */
3883
3884
#endif /** define GOFMM_HPP */