HMLP: High-performance Machine Learning Primitives
gofmm.hpp
1 
21 #ifndef GOFMM_HPP
22 #define GOFMM_HPP
23 
25 #include <future>
26 #include <thread>
27 #include <chrono>
28 #include <set>
29 #include <vector>
30 #include <map>
31 #include <unordered_set>
32 #include <deque>
33 #include <assert.h>
34 #include <typeinfo>
35 #include <algorithm>
36 #include <functional>
37 #include <array>
38 #include <random>
39 #include <numeric>
40 #include <sstream>
41 #include <iostream>
42 #include <string>
43 #include <stdio.h>
44 #include <omp.h>
45 #include <time.h>
46 
48 #include <hmlp.h>
49 #include <hmlp_base.hpp>
51 #include <primitives/lowrank.hpp>
52 #include <primitives/combinatorics.hpp>
53 #include <primitives/gemm.hpp>
55 #include <containers/VirtualMatrix.hpp>
56 #include <containers/SPDMatrix.hpp>
58 #include <tree.hpp>
59 #include <igofmm.hpp>
61 #ifdef HMLP_USE_CUDA
62 #include <cuda_runtime.h>
63 #include <gofmm_gpu.hpp>
64 #endif
65 
66 using namespace std;
67 using namespace hmlp;
68 
69 
70 
72 #define MAX_NRHS 1024
73 
74 #define GEMM_NB 256
75 
76 
77 //#define DEBUG_SPDASKIT 1
78 #define REPORT_ANN_ACCURACY 1
79 #define REPORT_COMPRESS_STATUS 1
80 #define REPORT_EVALUATE_STATUS 1
81 
82 
83 namespace hmlp
84 {
85 namespace gofmm
86 {
87 
90 {
91  public:
92 
94  CommandLineHelper( int argc, char *argv[] )
95  {
97  sscanf( argv[ 1 ], "%lu", &n );
99  sscanf( argv[ 2 ], "%lu", &m );
101  sscanf( argv[ 3 ], "%lu", &k );
103  sscanf( argv[ 4 ], "%lu", &s );
105  sscanf( argv[ 5 ], "%lu", &nrhs );
107  sscanf( argv[ 6 ], "%lf", &stol );
109  sscanf( argv[ 7 ], "%lf", &budget );
111  distance_type = argv[ 8 ];
112  if ( !distance_type.compare( "geometry" ) )
113  {
114  metric = GEOMETRY_DISTANCE;
115  }
116  else if ( !distance_type.compare( "kernel" ) )
117  {
118  metric = KERNEL_DISTANCE;
119  }
120  else if ( !distance_type.compare( "angle" ) )
121  {
122  metric = ANGLE_DISTANCE;
123  }
124  else
125  {
126  printf( "%s is not supported\n", argv[ 8 ] );
127  exit( 1 );
128  }
130  spdmatrix_type = argv[ 9 ];
131  if ( !spdmatrix_type.compare( "testsuit" ) )
132  {
134  }
135  else if ( !spdmatrix_type.compare( "userdefine" ) )
136  {
138  }
139  else if ( !spdmatrix_type.compare( "pvfmm" ) )
140  {
142  }
143  else if ( !spdmatrix_type.compare( "dense" ) || !spdmatrix_type.compare( "ooc" ) )
144  {
146  user_matrix_filename = argv[ 10 ];
147  if ( argc > 11 )
148  {
150  user_points_filename = argv[ 11 ];
152  sscanf( argv[ 12 ], "%lu", &d );
153  }
154  }
155  else if ( !spdmatrix_type.compare( "mlp" ) )
156  {
157  hidden_layers = argv[ 10 ];
158  user_points_filename = argv[ 11 ];
160  sscanf( argv[ 12 ], "%lu", &d );
161  }
162  else if ( !spdmatrix_type.compare( "cov" ) )
163  {
164  kernelmatrix_type = argv[ 10 ];
165  user_points_filename = argv[ 11 ];
167  sscanf( argv[ 12 ], "%lu", &d );
169  sscanf( argv[ 13 ], "%lu", &nb );
170  }
171  else if ( !spdmatrix_type.compare( "kernel" ) )
172  {
173  kernelmatrix_type = argv[ 10 ];
174  user_points_filename = argv[ 11 ];
176  sscanf( argv[ 12 ], "%lu", &d );
178  if ( argc > 13 ) sscanf( argv[ 13 ], "%lf", &h );
179  }
180  else
181  {
182  printf( "%s is not supported\n", argv[ 9 ] );
183  exit( 1 );
184  }
185  };
188  size_t n, m, k, s, nrhs;
190  double stol = 1E-3;
191  double budget = 0.0;
193  DistanceMetric metric = ANGLE_DISTANCE;
194 
196  size_t d, nb;
198  double h = 1.0;
199  string distance_type;
200  string spdmatrix_type;
201  string kernelmatrix_type;
202  string hidden_layers;
203  string user_matrix_filename;
204  string user_points_filename;
205 };
211 template<typename T>
213 {
214  public:
215 
216  Configuration() {};
217 
218  Configuration( DistanceMetric metric_type,
219  size_t problem_size, size_t leaf_node_size,
220  size_t neighbor_size, size_t maximum_rank,
221  T tolerance, T budget )
222  {
223  Set( metric_type, problem_size, leaf_node_size,
224  neighbor_size, maximum_rank, tolerance, budget );
225  };
226 
227  void Set( DistanceMetric metric_type,
228  size_t problem_size, size_t leaf_node_size,
229  size_t neighbor_size, size_t maximum_rank,
230  T tolerance, T budget )
231  {
232  this->metric_type = metric_type;
233  this->problem_size = problem_size;
234  this->leaf_node_size = leaf_node_size;
235  this->neighbor_size = neighbor_size;
236  this->maximum_rank = maximum_rank;
237  this->tolerance = tolerance;
238  this->budget = budget;
239  };
240 
241  void CopyFrom( Configuration<T> &config ) { *this = config; };
242 
243  DistanceMetric MetricType() { return metric_type; };
244 
245  size_t ProblemSize() { return problem_size; };
246 
247  size_t LeafNodeSize() { return leaf_node_size; };
248 
249  size_t NeighborSize() { return neighbor_size; };
250 
251  size_t MaximumRank() { return maximum_rank; };
252 
253  T Tolerance() { return tolerance; };
254 
255  T Budget() { return budget; };
256 
257  bool IsSymmetric() { return is_symmetric; };
258 
259  bool UseAdaptiveRanks() { return use_adaptive_ranks; };
260 
261  bool SecureAccuracy() { return secure_accuracy; };
262 
263  private:
264 
266  DistanceMetric metric_type = ANGLE_DISTANCE;
267 
269  size_t problem_size = 0;
270 
272  size_t leaf_node_size = 64;
273 
275  size_t neighbor_size = 32;
276 
278  size_t maximum_rank = 64;
279 
281  T tolerance = 1E-3;
282 
284  T budget = 0.03;
285 
287  bool is_symmetric = true;
288 
290  bool use_adaptive_ranks = true;
291 
293  bool secure_accuracy = false;
294 
295 };
// Setup: combines the tree-level setup (tree::Setup<SPLITTER, T>) with the
// GOFMM run-time parameters (Configuration<T>). It also holds pointers to the
// matrix and to the weight/potential buffers used by the routines below.
300 template<typename SPDMATRIX, typename SPLITTER, typename T>
301 class Setup : public tree::Setup<SPLITTER, T>,
302  public Configuration<T>
303 {
304  public:
305 
306  Setup() {};
307 
// NOTE(review): the listing line that opens the following member signature
// (line 309, which declares the `config` parameter) is missing from this
// copy. The visible part copies config, stores the address of K, copies the
// splitter and stores NN. `splitter` and `NN` are presumably members of
// tree::Setup; confirm against the original header.
310  SPDMATRIX &K, SPLITTER &splitter, Data<pair<T, size_t>> *NN )
311  {
312  this->CopyFrom( config );
313  this->K = &K;
314  this->splitter = splitter;
315  this->NN = NN;
316  };
317 
// Address of the caller's SPD matrix (stored above, not allocated here).
319  SPDMATRIX *K = NULL;
320 
// Weight buffer w (UpdateWeights reads w->col() as the number of right-hand
// sides) and potential buffer u; both are viewed through TreeViewTask.
322  Data<T> *w = NULL;
323  Data<T> *u = NULL;
324 
// Not referenced in the visible part of the file.
326  Data<T> *input = NULL;
327  Data<T> *output = NULL;
328 
// Scalar, default 0.0. Not used in the visible code (presumably a
// regularization shift for the factorization; TODO confirm).
330  T lambda = 0.0;
331 
// Flag, default true; not read in the visible code.
333  bool do_ulv_factorization = true;
334 
335  private:
336 
337 
338 
339 
340 };
// NodeData: per-node state used by the GOFMM routines in this file. It covers
// skeletonization results, sampling buffers, evaluation buffers and event
// records.
// NOTE(review): routines in this file also access data.proj, data.KIJ,
// data.w_skel, data.w_leaf, data.w_view and data.skeletonize. None of these
// is declared in the visible lines. They are either inherited from Factor<T>
// or were lost from this listing (e.g. around lines 364-365 and 374-378).
// Confirm against the original header.
344 template<typename T>
345 class NodeData : public Factor<T>
346 {
347  public:
348 
350  NodeData() {};
351 
// Lock object; not used in the visible code.
353  Lock lock;
354 
// Set by Skeletonize(): true when this node has a usable skeleton.
356  bool isskel = false;
357 
// Skeleton gids; Skeletonize() maps ID column positions back through
// candidate_cols.
359  vector<size_t> skels;
360 
// Column pivots returned by lowrank::id(); Interpolate() uses them to undo
// the pivoting.
362  vector<int> jpvt;
363 
366 
// Sampling neighbors, gid -> distance, maintained by RowSamples().
368  map<size_t, T> snids;
369 
// Row gids sampled by RowSamples() and column gids chosen by SkeletonKIJ().
371  vector<size_t> candidate_rows;
372  vector<size_t> candidate_cols;
373 
376 
// Skeleton potentials (|skels|-by-nrhs), zeroed and accumulated by
// SkeletonsToSkeletons().
379  Data<T> u_skel;
380 
// Array of 20 Data<T> buffers; not used in the visible code.
383  Data<T> u_leaf[ 20 ];
384 
// Hierarchical view of a global buffer, set up by TreeViewTask.
387  View<T> u_view;
388 
// Cached kernel blocks. SkeletonsToSkeletons() reads FarKab as the
// column-wise concatenation of K( skels, far-node skels ).
391  Data<T> NearKab;
392  Data<T> FarKab;
393 
394 
// Event records; Summary reads updateweight.GetDuration().
397  Event updateweight;
398  Event skeltoskel;
399  Event skeltonode;
400  Event s2s;
401  Event s2n;
402 
// Accumulators (presumably neighbor-search accuracy); not used in the
// visible code.
404  double knn_acc = 0.0;
405  size_t num_acc = 0;
406 
407 };
413 template<typename NODE>
414 class TreeViewTask : public Task
415 {
416  public:
417 
418  NODE *arg = NULL;
419 
420  void Set( NODE *user_arg )
421  {
422  arg = user_arg;
423  name = string( "TreeView" );
424  label = to_string( arg->treelist_id );
425  cost = 1.0;
426  };
427 
429  void DependencyAnalysis() { arg->DependOnParent( this ); };
430 
431  void Execute( Worker* user_worker )
432  {
433  //printf( "TreeView %lu\n", node->treelist_id );
434  auto *node = arg;
435  auto &data = node->data;
436  auto *setup = node->setup;
437 
439  auto &w = *(setup->u);
440  auto &u = *(setup->w);
441 
443  auto &U = data.u_view;
444  auto &W = data.w_view;
445 
447  if ( !node->parent )
448  {
450  U.Set( u );
451  W.Set( w );
452  }
453 
455  if ( !node->isleaf )
456  {
457  auto &UL = node->lchild->data.u_view;
458  auto &UR = node->rchild->data.u_view;
459  auto &WL = node->lchild->data.w_view;
460  auto &WR = node->rchild->data.w_view;
465  U.Partition2x1( UL,
466  UR, node->lchild->n, TOP );
467  W.Partition2x1( WL,
468  WR, node->lchild->n, TOP );
469  }
470  };
471 };
492 template<typename NODE>
493 class Summary
494 {
495 
496  public:
497 
498  Summary() {};
499 
500  deque<Statistic> rank;
501 
502  deque<Statistic> skeletonize;
503 
505  deque<Statistic> updateweight;
506 
508  deque<Statistic> s2s_kij_t;
509  deque<Statistic> s2s_t;
510  deque<Statistic> s2s_gfp;
511 
513  deque<Statistic> s2n_kij_t;
514  deque<Statistic> s2n_t;
515  deque<Statistic> s2n_gfp;
516 
517 
518  void operator() ( NODE *node )
519  {
520  if ( rank.size() <= node->l )
521  {
522  rank.push_back( hmlp::Statistic() );
523  skeletonize.push_back( hmlp::Statistic() );
524  updateweight.push_back( hmlp::Statistic() );
525  }
526 
527  rank[ node->l ].Update( (double)node->data.skels.size() );
528  skeletonize[ node->l ].Update( node->data.skeletonize.GetDuration() );
529  updateweight[ node->l ].Update( node->data.updateweight.GetDuration() );
530 
531 #ifdef DUMP_ANALYSIS_DATA
532  if ( node->parent )
533  {
534  auto *parent = node->parent;
535  printf( "@TREE\n" );
536  printf( "#%lu (s%lu), #%lu (s%lu), %lu, %lu\n",
537  node->treelist_id, node->data.skels.size(),
538  parent->treelist_id, parent->data.skels.size(),
539  node->data.skels.size(), node->l );
540  }
541  else
542  {
543  printf( "@TREE\n" );
544  printf( "#%lu (s%lu), , %lu, %lu\n",
545  node->treelist_id, node->data.skels.size(),
546  node->data.skels.size(), node->l );
547  }
548 #endif
549  };
550 
551  void Print()
552  {
553  for ( size_t l = 1; l < rank.size(); l ++ )
554  {
555  printf( "@SUMMARY\n" );
556  printf( "level %2lu, ", l ); rank[ l ].Print();
557  //printf( "skel_t: " ); skeletonize[ l ].Print();
558  //printf( "... ... ...\n" );
559  //printf( "n2s_t: " ); updateweight[ l ].Print();
561  //printf( "s2s_t: " ); s2s_t[ l ].Print();
562  //printf( "s2s_gfp: " ); s2s_gfp[ l ].Print();
564  //printf( "s2n_t: " ); s2n_t[ l ].Print();
565  //printf( "s2n_gfp: " ); s2n_gfp[ l ].Print();
566  }
567  };
568 
569 };
579 template<typename SPDMATRIX, int N_SPLIT, typename T>
581 {
583  SPDMATRIX *Kptr = NULL;
585  DistanceMetric metric = ANGLE_DISTANCE;
587  size_t n_centroid_samples = 5;
588 
589  centersplit() {};
590 
591  centersplit( SPDMATRIX& K ) { this->Kptr = &K; };
592 
594  vector<vector<size_t>> operator() ( vector<size_t>& gids ) const
595  {
597  assert( N_SPLIT == 2 );
598  assert( Kptr );
599 
600 
601  SPDMATRIX &K = *Kptr;
602  vector<vector<size_t>> split( N_SPLIT );
603  size_t n = gids.size();
604  vector<T> temp( n, 0.0 );
605 
607  auto column_samples = combinatorics::SampleWithoutReplacement(
608  n_centroid_samples, gids );
609 
610 
612  auto DIC = K.Distances( this->metric, gids, column_samples );
613 
615  for ( auto & it : temp ) it = 0;
616 
618  for ( size_t j = 0; j < DIC.col(); j ++ )
619  for ( size_t i = 0; i < DIC.row(); i ++ )
620  temp[ i ] += DIC( i, j );
621 
623  auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );
624 
626  vector<size_t> P( 1, gids[ idf2c ] );
627 
629  auto DIP = K.Distances( this->metric, gids, P );
630 
632  auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );
633 
635  vector<size_t> Q( 1, gids[ idf2f ] );
636 
638  auto DIQ = K.Distances( this->metric, gids, P );
639 
640  for ( size_t i = 0; i < temp.size(); i ++ )
641  temp[ i ] = DIP[ i ] - DIQ[ i ];
642 
643  return combinatorics::MedianSplit( temp );
644  };
645 };
657 template<typename SPDMATRIX, int N_SPLIT, typename T>
659 {
661  SPDMATRIX *Kptr = NULL;
662 
664  DistanceMetric metric = ANGLE_DISTANCE;
665 
666  randomsplit() {};
667 
668  randomsplit( SPDMATRIX& K ) { this->Kptr = &K; };
669 
671  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
672  {
673  assert( Kptr && ( N_SPLIT == 2 ) );
674 
675  SPDMATRIX &K = *Kptr;
676  size_t n = gids.size();
677  vector<vector<size_t> > split( N_SPLIT );
678  vector<T> temp( n, 0.0 );
679 
681  size_t idf2c = std::rand() % n;
682  size_t idf2f = std::rand() % n;
683  while ( idf2c == idf2f ) idf2f = std::rand() % n;
684 
685 
686  vector<size_t> P( 1, gids[ idf2c ] );
687  vector<size_t> Q( 1, gids[ idf2f ] );
688 
689 
691  auto DIP = K.Distances( this->metric, gids, P );
692  auto DIQ = K.Distances( this->metric, gids, Q );
693 
694  for ( size_t i = 0; i < temp.size(); i ++ )
695  temp[ i ] = DIP[ i ] - DIQ[ i ];
696 
697  return combinatorics::MedianSplit( temp );
698 
699  };
700 };
705 template<typename NODE>
706 void FindNeighbors( NODE *node, DistanceMetric metric )
707 {
709  using T = typename NODE::T;
710  auto & setup = *(node->setup);
711  auto & K = *(setup.K);
712  auto & NN = *(setup.NN);
713  auto & I = node->gids;
715  size_t kappa = NN.row();
717  pair<T, size_t> init( numeric_limits<T>::max(), NN.col() );
719  auto candidates = K.NeighborSearch( metric, kappa, I, I, init );
721  #pragma omp parallel
722  {
723  vector<pair<T, size_t> > aux( 2 * kappa );
724  #pragma omp for
725  for ( size_t j = 0; j < I.size(); j ++ )
726  {
727  MergeNeighbors( kappa, NN.columndata( I[ j ] ),
728  candidates.columndata( j ), aux );
729  }
730  }
731 };
737 template<class NODE, typename T>
738 class NeighborsTask : public Task
739 {
740  public:
741 
742  NODE *arg = NULL;
743 
745  DistanceMetric metric = ANGLE_DISTANCE;
746 
747  void Set( NODE *user_arg )
748  {
749  arg = user_arg;
750  name = string( "Neighbors" );
751  label = to_string( arg->treelist_id );
753  metric = arg->setup->MetricType();
754 
755  //--------------------------------------
756  double flops, mops;
757  auto &gids = arg->gids;
758  auto &NN = *arg->setup->NN;
759  flops = gids.size();
760  flops *= ( 4.0 * gids.size() );
761  // Heap select worst case
762  mops = (size_t)std::log( NN.row() ) * gids.size();
763  mops *= gids.size();
764  // Access K
765  mops += flops;
766  event.Set( name + label, flops, mops );
767  //--------------------------------------
768 
769  // TODO: Need an accurate cost model.
770  cost = mops / 1E+9;
771  };
772 
773  void DependencyAnalysis() { arg->DependOnNoOne( this ); };
774 
775  void Execute( Worker* user_worker ) { FindNeighbors( arg, metric ); };
776 
777 };
782 template<bool DOAPPROXIMATE, bool SORTED, typename T, typename CSCMATRIX>
783 Data<pair<T, size_t>> SparsePattern( size_t n, size_t k, CSCMATRIX &K )
784 {
785  pair<T, size_t> initNN( numeric_limits<T>::max(), n );
786  Data<pair<T, size_t>> NN( k, n, initNN );
787 
788  printf( "SparsePattern k %lu n %lu, NN.row %lu NN.col %lu ...",
789  k, n, NN.row(), NN.col() ); fflush( stdout );
790 
791  #pragma omp parallel for schedule( dynamic )
792  for ( size_t j = 0; j < n; j ++ )
793  {
794  std::set<size_t> NNset;
795  size_t nnz = K.ColPtr( j + 1 ) - K.ColPtr( j );
796  if ( DOAPPROXIMATE && nnz > 2 * k ) nnz = 2 * k;
797 
798  //printf( "j %lu nnz %lu\n", j, nnz );
799 
800  for ( size_t i = 0; i < nnz; i ++ )
801  {
802  // TODO: this is lid. Need to be gid.
803  auto row_ind = K.RowInd( K.ColPtr( j ) + i );
804  auto val = K.Value( K.ColPtr( j ) + i );
805 
806  if ( val ) val = 1.0 / std::abs( val );
807  else val = std::numeric_limits<T>::max() - 1.0;
808 
809  NNset.insert( row_ind );
810  std::pair<T, std::size_t> query( val, row_ind );
811  if ( nnz < k ) // not enough candidates
812  {
813  NN[ j * k + i ] = query;
814  }
815  else
816  {
817  hmlp::HeapSelect( 1, NN.row(), &query, NN.data() + j * NN.row() );
818  }
819  }
820 
821  while ( nnz < k )
822  {
823  std::size_t row_ind = rand() % n;
824  if ( !NNset.count( row_ind ) )
825  {
826  T val = std::numeric_limits<T>::max() - 1.0;
827  std::pair<T, std::size_t> query( val, row_ind );
828  NNset.insert( row_ind );
829  NN[ j * k + nnz ] = query;
830  nnz ++;
831  }
832  }
833  }
834  printf( "Done.\n" ); fflush( stdout );
835 
836  if ( SORTED )
837  {
838  printf( "Sorting ... " ); fflush( stdout );
839  struct
840  {
841  bool operator () ( std::pair<T, size_t> a, std::pair<T, size_t> b )
842  {
843  return a.first < b.first;
844  }
845  } ANNLess;
846 
847  //printf( "SparsePattern k %lu n %lu, NN.row %lu NN.col %lu\n", k, n, NN.row(), NN.col() );
848 
849  #pragma omp parallel for
850  for ( size_t j = 0; j < NN.col(); j ++ )
851  {
852  std::sort( NN.data() + j * NN.row(), NN.data() + ( j + 1 ) * NN.row(), ANNLess );
853  }
854  printf( "Done.\n" ); fflush( stdout );
855  }
856 
857  return NN;
858 };
/**
 *  @brief Swaps the members of a pair: ( a, b ) becomes ( b, a ).
 *
 *  Helper for flip_map(), which re-keys a map by its values so that
 *  sampling neighbors can be ordered by distance.
 */
template<typename TA, typename TB>
pair<TB, TA> flip_pair( const pair<TA, TB> &p )
{
  return { p.second, p.first };
};
/**
 *  @brief Re-keys a map by its values.
 *
 *  Every ( key, value ) entry of src becomes ( value, key ) in the returned
 *  multimap, so entries with equal values are all kept. Insertion uses a
 *  hint that advances past each inserted element; this is exactly what
 *  std::inserter does, so equal values are placed in the same order.
 */
template<typename TA, typename TB>
multimap<TB, TA> flip_map( const map<TA, TB> &src )
{
  multimap<TB, TA> flipped;
  auto hint = flipped.begin();
  for ( const auto &entry : src )
  {
    hint = flipped.emplace_hint( hint, entry.second, entry.first );
    ++ hint;
  }
  return flipped;
};
882 template<typename NODE>
883 void Interpolate( NODE *node )
884 {
886  using T = typename NODE::T;
888  if ( !node ) return;
889 
890  auto &K = *node->setup->K;
891  auto &data = node->data;
892  auto &skels = data.skels;
893  auto &proj = data.proj;
894  auto &jpvt = data.jpvt;
895  auto s = proj.row();
896  auto n = proj.col();
897 
899  if ( !data.isskel || proj[ 0 ] == 0 ) return;
900 
901  assert( s );
902  assert( s <= n );
903  assert( jpvt.size() == n );
904 
906  if ( data.isskel )
907  {
908  data.w_skel.reserve( skels.size(), MAX_NRHS );
909  data.u_skel.reserve( skels.size(), MAX_NRHS );
910  }
911 
912 
914  Data<T> R1( s, s, 0.0 );
915 
916  for ( int j = 0; j < s; j ++ )
917  {
918  for ( int i = 0; i < s; i ++ )
919  {
920  if ( i <= j ) R1[ j * s + i ] = proj[ j * s + i ];
921  }
922  }
923 
925  Data<T> tmp = proj;
926 
928  xtrsm( "L", "U", "N", "N", s, n, 1.0, R1.data(), s, tmp.data(), s );
929 
931  for ( int j = 0; j < n; j ++ )
932  {
933  for ( int i = 0; i < s; i ++ )
934  {
935  proj[ jpvt[ j ] * s + i ] = tmp[ j * s + i ];
936  }
937  }
938 
939 
940 };
944 template<typename NODE>
945 class InterpolateTask : public Task
946 {
947  public:
948 
949  NODE *arg = NULL;
950 
951  void Set( NODE *user_arg )
952  {
953  arg = user_arg;
954  name = string( "it" );
955  label = to_string( arg->treelist_id );
956  // Need an accurate cost model.
957  cost = 1.0;
958  };
959 
960  void DependencyAnalysis() { arg->DependOnNoOne( this ); };
961 
962  void Execute( Worker* user_worker ) { Interpolate( arg ); };
963 
964 };
972 template<bool NNPRUNE, typename NODE>
973 void RowSamples( NODE *node, size_t nsamples )
974 {
976  using T = typename NODE::T;
977  auto &setup = *(node->setup);
978  auto &data = node->data;
979  auto &K = *(setup.K);
980 
982  auto &amap = data.candidate_rows;
983 
985  amap.clear();
986 
988  if ( setup.NN )
989  {
990  //printf( "construct snids NN.row() %lu NN.col() %lu\n",
991  // node->setup->NN->row(), node->setup->NN->col() ); fflush( stdout );
992  auto &NN = *(setup.NN);
993  auto &gids = node->gids;
994  auto &snids = data.snids;
995  size_t knum = NN.row();
996 
997  if ( node->isleaf )
998  {
999  snids.clear();
1000 
1001  vector<pair<T, size_t>> tmp( knum * gids.size() );
1002  for ( size_t j = 0; j < gids.size(); j ++ )
1003  for ( size_t i = 0; i < knum; i ++ )
1004  tmp[ j * knum + i ] = NN( i, gids[ j ] );
1005 
1007  sort( tmp.begin(), tmp.end() );
1008 
1010  for ( auto it : tmp )
1011  {
1012  size_t it_gid = it.second;
1013  size_t it_morton = setup.morton[ it_gid ];
1014 
1015  if ( snids.size() >= nsamples ) break;
1016 
1018  bool is_near;
1019  if ( NNPRUNE ) is_near = node->NNNearNodeMortonIDs.count( it_morton );
1020  else is_near = (it_morton == node->morton );
1021 
1022  if ( !is_near )
1023  {
1025  auto ret = snids.insert( make_pair( it.second, it.first ) );
1026  }
1027  }
1028  }
1029  else
1030  {
1031  auto &lsnids = node->lchild->data.snids;
1032  auto &rsnids = node->rchild->data.snids;
1033 
1035  snids = lsnids;
1036 
1042  for ( auto it = rsnids.begin(); it != rsnids.end(); it ++ )
1043  {
1044  auto ret = snids.insert( *it );
1045  if ( !ret.second )
1046  {
1047  if ( ret.first->second > (*it).first )
1048  ret.first->second = (*it).first;
1049  }
1050  }
1051 
1053  for ( auto gid : gids ) snids.erase( gid );
1054  }
1055 
1056 
1057  if ( nsamples < K.col() - node->n )
1058  {
1060  multimap<T, size_t> ordered_snids = flip_map( snids );
1062  amap.reserve( nsamples );
1063 
1065  for ( auto it : ordered_snids )
1066  {
1067  if ( amap.size() >= nsamples ) break;
1069  amap.push_back( it.second );
1070  }
1071 
1073  while ( amap.size() < nsamples )
1074  {
1075  //size_t sample = rand() % K.col();
1076  auto important_sample = K.ImportantSample( 0 );
1077  size_t sample_gid = important_sample.second;
1078  size_t sample_morton = setup.morton[ sample_gid ];
1079 
1080  if ( !MortonHelper::IsMyParent( sample_morton, node->morton ) )
1081  {
1082  amap.push_back( sample_gid );
1083  }
1084  }
1085  }
1086  else
1087  {
1088  for ( size_t sample = 0; sample < K.col(); sample ++ )
1089  {
1090  size_t sample_morton = setup.morton[ sample ];
1091  if ( !MortonHelper::IsMyParent( sample_morton, node->morton ) )
1092  {
1093  amap.push_back( sample );
1094  }
1095  }
1096  }
1097  }
1099 };
1108 template<bool NNPRUNE, typename NODE>
1109 void SkeletonKIJ( NODE *node )
1110 {
1112  using T = typename NODE::T;
1114  auto &K = *(node->setup->K);
1116  auto &data = node->data;
1117  auto &candidate_rows = data.candidate_rows;
1118  auto &candidate_cols = data.candidate_cols;
1119  auto &KIJ = data.KIJ;
1121  auto *lchild = node->lchild;
1122  auto *rchild = node->rchild;
1123 
1124  if ( node->isleaf )
1125  {
1127  candidate_cols = node->gids;
1128  }
1129  else
1130  {
1131  auto &lskels = lchild->data.skels;
1132  auto &rskels = rchild->data.skels;
1134  if ( !lskels.size() || !rskels.size() ) return;
1136  candidate_cols = lskels;
1137  candidate_cols.insert( candidate_cols.end(),
1138  rskels.begin(), rskels.end() );
1139  }
1140 
1142  size_t nsamples = 2 * candidate_cols.size();
1143 
1145  if ( nsamples < 2 * node->setup->LeafNodeSize() )
1146  nsamples = 2 * node->setup->LeafNodeSize();
1147 
1149  RowSamples<NNPRUNE>( node, nsamples );
1150 
1152  KIJ = K( candidate_rows, candidate_cols );
1153 
1154 
1155 
1156 };
1167 template<bool NNPRUNE, typename NODE, typename T>
1168 class SkeletonKIJTask : public Task
1169 {
1170  public:
1171 
1172  NODE *arg = NULL;
1173 
1174  void Set( NODE *user_arg )
1175  {
1176  arg = user_arg;
1177  name = string( "par-gskm" );
1178  label = to_string( arg->treelist_id );
1180  cost = 5.0;
1182  priority = true;
1183  };
1184 
1185  void DependencyAnalysis() { arg->DependOnChildren( this ); };
1186 
1187  void Execute( Worker* user_worker ) { SkeletonKIJ<NNPRUNE>( arg ); };
1188 
1189 };
1193 template<typename NODE>
1194 void Skeletonize( NODE *node )
1195 {
1197  using T = typename NODE::T;
1199  if ( !node->parent ) return;
1200 
1202  auto &K = *(node->setup->K);
1203  auto &NN = *(node->setup->NN);
1204  auto maxs = node->setup->MaximumRank();
1205  auto stol = node->setup->Tolerance();
1206  bool secure_accuracy = node->setup->SecureAccuracy();
1207  bool use_adaptive_ranks = node->setup->UseAdaptiveRanks();
1208 
1210  auto &data = node->data;
1211  auto &skels = data.skels;
1212  auto &proj = data.proj;
1213  auto &jpvt = data.jpvt;
1214  auto &KIJ = data.KIJ;
1215  auto &candidate_cols = data.candidate_cols;
1216 
1218  size_t N = K.col();
1219  size_t m = KIJ.row();
1220  size_t n = KIJ.col();
1221  size_t q = node->n;
1222 
1223  if ( secure_accuracy )
1224  {
1225  if ( !node->isleaf && ( !node->lchild->data.isskel || !node->rchild->data.isskel ) )
1226  {
1227  skels.clear();
1228  proj.resize( 0, 0 );
1229  data.isskel = false;
1230  return;
1231  }
1232  }
1233 
1235  T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol;
1237  scaled_stol *= std::sqrt( (T)q / N );
1238 
1240  lowrank::id( use_adaptive_ranks, secure_accuracy,
1241  KIJ.row(), KIJ.col(), maxs, scaled_stol, KIJ, skels, proj, jpvt );
1242 
1244  KIJ.resize( 0, 0 );
1245 
1247  if ( secure_accuracy )
1248  {
1250  data.isskel = (skels.size() != 0);
1251  }
1252  else
1253  {
1254  assert( skels.size() && proj.size() && jpvt.size() );
1255  data.isskel = true;
1256  }
1257 
1259  for ( size_t i = 0; i < skels.size(); i ++ ) skels[ i ] = candidate_cols[ skels[ i ] ];
1260 
1261 };
// SkeletonizeTask: task wrapper that runs Skeletonize( arg ) on one node.
// Set() uses the name "sk", the node's treelist_id as label, a fixed cost
// of 5.0, and marks the task as priority.
1264 template<typename NODE, typename T>
1265 class SkeletonizeTask : public Task
1266 {
1267  public:
1268 
1269  NODE *arg = NULL;
1270 
1271  void Set( NODE *user_arg )
1272  {
1273  arg = user_arg;
1274  name = string( "sk" );
1275  label = to_string( arg->treelist_id );
1277  cost = 5.0;
1279  priority = true;
1280  };
1281 
// NOTE(review): the signature line of the following member (listing line
// 1282) is missing from this copy. The body builds a flop/mop estimate,
// sets `event`, and stores it in arg->data.skeletonize.
1283  {
1284  double flops = 0.0, mops = 0.0;
1285 
// NOTE(review): K is bound but never used below.
1286  auto &K = *arg->setup->K;
// n = columns of proj, k = rows of proj. m = 2n mirrors the
// nsamples = 2 * |candidate_cols| rule used in SkeletonKIJ.
1287  size_t n = arg->data.proj.col();
1288  size_t m = 2 * n;
1289  size_t k = arg->data.proj.row();
1290 
// Householder QR flop count for an m-by-n matrix: (2/3) n^2 (3m - n).
1292  flops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
1293  mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
1294 
// NOTE(review): k is size_t, so k * ( k - 1 ) wraps around when proj has no
// rows (k == 0), e.g. for a node that was not skeletonized.
1295  /* TRSM */
1296  flops += k * ( k - 1 ) * ( n + 1 );
1297  mops += 2.0 * ( k * k + k * n );
1298 
1299  //flops += ( 2.0 / 3.0 ) * k * k * ( 3 * m - k );
1300  //mops += 2.0 * m * k;
1301  //flops += 2.0 * m * n * k;
1302  //mops += 2.0 * ( m * k + k * n + m * n );
1303  //flops += ( 1.0 / 3.0 ) * k * k * n;
1304  //mops += 2.0 * ( k * k + k * n );
1305 
1306  event.Set( label + name, flops, mops );
1307  arg->data.skeletonize = event;
1308  };
1309 
// Registers no task dependency (DependOnNoOne).
1310  void DependencyAnalysis() { arg->DependOnNoOne( this ); };
1311 
1312  void Execute( Worker* user_worker ) { Skeletonize( arg ); };
1313 
1314 };
1349 template<typename NODE>
1350 void UpdateWeights( NODE *node )
1351 {
1353  using T = typename NODE::T;
1355  if ( !node->parent || !node->data.isskel ) return;
1356 
1358  auto &w = *node->setup->w;
1359 
1361  auto &data = node->data;
1362  auto &proj = data.proj;
1363  auto &skels = data.skels;
1364  auto &w_skel = data.w_skel;
1365  auto &w_leaf = data.w_leaf;
1366  auto *lchild = node->lchild;
1367  auto *rchild = node->rchild;
1368 
1369 
1370  size_t nrhs = w.col();
1371 
1373  w_skel.resize( skels.size(), nrhs );
1374 
1375  //printf( "%lu UpdateWeight w_skel.num() %lu\n", node->treelist_id, w_skel.num() );
1376 
1377  if ( node->isleaf )
1378  {
1379  if ( w_leaf.size() )
1380  {
1381  //printf( "%8lu w_leaf allocated [%lu %lu]\n",
1382  // node->morton, w_leaf.row(), w_leaf.col() ); fflush( stdout );
1383 
1385  xgemm
1386  (
1387  "N", "N",
1388  w_skel.row(), w_skel.col(), w_leaf.row(),
1389  1.0, proj.data(), proj.row(),
1390  w_leaf.data(), w_leaf.row(),
1391  0.0, w_skel.data(), w_skel.row()
1392  );
1393  }
1394  else
1395  {
1397  View<T> W = data.w_view;
1398  //printf( "%8lu n2s W[%lu %lu ld %lu]\n",
1399  // node->morton, W.row(), W.col(), W.ld() ); fflush( stdout );
1400  //for ( int i = 0; i < 10; i ++ )
1401  // printf( "%lu W.data() + %d = %E\n", node->gids[ i ], i, *(W.data() + i) );
1402  xgemm
1403  (
1404  "N", "N",
1405  w_skel.row(), w_skel.col(), W.row(),
1406  1.0, proj.data(), proj.row(),
1407  W.data(), W.ld(),
1408  0.0, w_skel.data(), w_skel.row()
1409  );
1410  }
1411 
1412  //double update_leaf_time = omp_get_wtime() - beg;
1413  //printf( "%lu, m %lu n %lu k %lu, total %.3E\n",
1414  // node->treelist_id,
1415  // w_skel.row(), w_skel.col(), w_leaf.col(), update_leaf_time );
1416  }
1417  else
1418  {
1419  //double beg = omp_get_wtime();
1420  auto &w_lskel = lchild->data.w_skel;
1421  auto &w_rskel = rchild->data.w_skel;
1422  auto &lskel = lchild->data.skels;
1423  auto &rskel = rchild->data.skels;
1424 
1425  //if ( 1 )
1426  if ( node->treelist_id > 6 )
1427  {
1428  //printf( "%8lu n2s\n", node->morton ); fflush( stdout );
1429  xgemm
1430  (
1431  "N", "N",
1432  w_skel.row(), w_skel.col(), lskel.size(),
1433  1.0, proj.data(), proj.row(),
1434  w_lskel.data(), w_lskel.row(),
1435  0.0, w_skel.data(), w_skel.row()
1436  );
1437  xgemm
1438  (
1439  "N", "N",
1440  w_skel.row(), w_skel.col(), rskel.size(),
1441  1.0, proj.data() + proj.row() * lskel.size(), proj.row(),
1442  w_rskel.data(), w_rskel.row(),
1443  1.0, w_skel.data(), w_skel.row()
1444  );
1445  }
1446  else
1447  {
1449  View<T> P( false, proj ), PL,
1450  PR;
1451  View<T> W( false, w_skel ), WL( false, w_lskel ),
1452  WR( false, w_rskel );
1454  P.Partition1x2( PL, PR, lskel.size(), LEFT );
1456  gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );
1457  W.DependencyCleanUp();
1459  gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)1.0, W );
1460  //W.DependencyCleanUp();
1461  }
1462  }
1463 };
1469 template<typename NODE, typename T>
1470 class UpdateWeightsTask : public Task
1471 {
1472  public:
1473 
1474  NODE *arg = NULL;
1475 
1476  void Set( NODE *user_arg )
1477  {
1478  arg = user_arg;
1479  name = string( "n2s" );
1480  label = to_string( arg->treelist_id );
1481 
1483  double flops, mops;
1484  auto &gids = arg->gids;
1485  auto &skels = arg->data.skels;
1486  auto &w = *arg->setup->w;
1487  if ( arg->isleaf )
1488  {
1489  auto m = skels.size();
1490  auto n = w.col();
1491  auto k = gids.size();
1492  flops = 2.0 * m * n * k;
1493  mops = 2.0 * ( m * n + m * k + k * n );
1494  }
1495  else
1496  {
1497  auto &lskels = arg->lchild->data.skels;
1498  auto &rskels = arg->rchild->data.skels;
1499  auto m = skels.size();
1500  auto n = w.col();
1501  auto k = lskels.size() + rskels.size();
1502  flops = 2.0 * m * n * k;
1503  mops = 2.0 * ( m * n + m * k + k * n );
1504  }
1505 
1507  event.Set( label + name, flops, mops );
1509  cost = flops / 1E+9;
1511  priority = true;
1512  };
1513 
1514  void Prefetch( Worker* user_worker )
1515  {
1516  auto &proj = arg->data.proj;
1517  __builtin_prefetch( proj.data() );
1518  auto &w_skel = arg->data.w_skel;
1519  __builtin_prefetch( w_skel.data() );
1520  if ( arg->isleaf )
1521  {
1522  auto &w_leaf = arg->data.w_leaf;
1523  __builtin_prefetch( w_leaf.data() );
1524  }
1525  else
1526  {
1527  auto &w_lskel = arg->lchild->data.w_skel;
1528  __builtin_prefetch( w_lskel.data() );
1529  auto &w_rskel = arg->rchild->data.w_skel;
1530  __builtin_prefetch( w_rskel.data() );
1531  }
1532 #ifdef HMLP_USE_CUDA
1533  hmlp::Device *device = NULL;
1534  if ( user_worker ) device = user_worker->GetDevice();
1535  if ( device )
1536  {
1537  proj.CacheD( device );
1538  proj.PrefetchH2D( device, 1 );
1539  if ( arg->isleaf )
1540  {
1541  auto &w_leaf = arg->data.w_leaf;
1542  w_leaf.CacheD( device );
1543  w_leaf.PrefetchH2D( device, 1 );
1544  }
1545  else
1546  {
1547  auto &w_lskel = arg->lchild->data.w_skel;
1548  w_lskel.CacheD( device );
1549  w_lskel.PrefetchH2D( device, 1 );
1550  auto &w_rskel = arg->rchild->data.w_skel;
1551  w_rskel.CacheD( device );
1552  w_rskel.PrefetchH2D( device, 1 );
1553  }
1554  }
1555 #endif
1556  };
1557 
1558  void DependencyAnalysis() { arg->DependOnChildren( this ); };
1559 
1560  void Execute( Worker* user_worker )
1561  {
1562 #ifdef HMLP_USE_CUDA
1563  hmlp::Device *device = NULL;
1564  if ( user_worker ) device = user_worker->GetDevice();
1565  if ( device ) gpu::UpdateWeights( device, arg );
1566  else UpdateWeights<NODE, T>( arg );
1567 #else
1568  UpdateWeights( arg );
1569 #endif
1570  };
1571 
1572 };
/**
 *  @brief S2S (skeletons-to-skeletons): accumulates into u_skel the
 *         contributions of every far node's skeleton weights.
 *
 *  u_skel is reset to an |skels|-by-nrhs zero matrix. For each node in
 *  node->NNFarNodes, u_skel += K( skels, far skels ) * far w_skel.
 *  - The kernel block is taken from data.FarKab when it is cached; FarKab is
 *    assumed to hold the far blocks side by side in iteration order.
 *  - Otherwise the block is evaluated on the fly after a warning.
 *  The root and non-skeletonized nodes are skipped.
 *  Cleanup: removes the unreachable `if ( 1 ) ... else` branch (and the View
 *  it used) as well as unused timing variables.
 */
template<typename NODE>
void SkeletonsToSkeletons( NODE *node )
{
  if ( !node->parent || !node->data.isskel ) return;

  auto &K = *node->setup->K;
  auto &amap = node->data.skels;
  auto &u_skel = node->data.u_skel;
  auto &FarKab = node->data.FarKab;

  size_t nrhs = node->setup->w->col();

  /** Reset u_skel to zeros of shape |amap|-by-nrhs. */
  u_skel.resize( 0, 0 );
  u_skel.resize( amap.size(), nrhs, 0.0 );

  /** Column offset of the current far node's block inside FarKab. */
  size_t offset = 0;

  for ( auto *far_node : node->NNFarNodes )
  {
    auto &bmap = far_node->data.skels;
    auto &w_skel = far_node->data.w_skel;
    assert( w_skel.col() == nrhs );
    assert( w_skel.row() == bmap.size() );
    assert( w_skel.size() == nrhs * bmap.size() );

    if ( FarKab.size() )
    {
      assert( FarKab.row() == amap.size() );
      assert( u_skel.row() * offset <= FarKab.size() );

      /** u_skel += FarKab( :, offset : offset + |bmap| ) * w_skel */
      xgemm
      (
        "N", "N",
        u_skel.row(), u_skel.col(), w_skel.row(),
        1.0, FarKab.data() + u_skel.row() * offset, FarKab.row(),
             w_skel.data(), w_skel.row(),
        1.0, u_skel.data(), u_skel.row()
      );

      offset += w_skel.row();
    }
    else
    {
      printf( "Far Kab not cached treelist_id %lu, l %lu\n\n",
          node->treelist_id, node->l ); fflush( stdout );

      /** Evaluate the far block on the fly. */
      auto Kab = K( amap, bmap );
      xgemm( "N", "N", u_skel.row(), u_skel.col(), w_skel.row(),
          1.0, Kab.data(), Kab.row(),
               w_skel.data(), w_skel.row(),
          1.0, u_skel.data(), u_skel.row() );
    }
  }
};
1686 template<bool NNPRUNE, typename NODE, typename T>
1688 {
1689  public:
1690 
1691  NODE *arg = NULL;
1692 
1693  void Set( NODE *user_arg )
1694  {
1695  arg = user_arg;
1696  name = string( "s2s" );
1697  {
1698  //label = std::to_string( arg->treelist_id );
1699  ostringstream ss;
1700  ss << arg->treelist_id;
1701  label = ss.str();
1702  }
1703 
1705  double flops = 0.0, mops = 0.0;
1706  auto &w = *arg->setup->w;
1707  size_t m = arg->data.skels.size();
1708  size_t n = w.col();
1709 
1710  std::set<NODE*> *FarNodes;
1711  if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
1712  else FarNodes = &arg->FarNodes;
1713 
1714  for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
1715  {
1716  size_t k = (*it)->data.skels.size();
1717  flops += 2.0 * m * n * k;
1718  mops += m * k; // cost of Kab
1719  mops += 2.0 * ( m * n + n * k + k * n );
1720  }
1721 
1723  event.Set( label + name, flops, mops );
1725  cost = flops / 1E+9;
1727  priority = true;
1728  };
1729 
1730  void DependencyAnalysis()
1731  {
1732  for ( auto it : arg->NNFarNodes ) it->DependencyAnalysis( R, this );
1733  arg->DependencyAnalysis( RW, this );
1734  this->TryEnqueue();
1735  };
1736 
1737  void Execute( Worker* user_worker ) { SkeletonsToSkeletons( arg ); };
1738 };
1746 template<typename NODE>
1747 void SkeletonsToNodes( NODE *node )
1748 {
1750  using T = typename NODE::T;
1751 
1753  auto &K = *node->setup->K;
1754  auto &w = *node->setup->w;
1755 
1757  auto &gids = node->gids;
1758  auto &data = node->data;
1759  auto &proj = data.proj;
1760  auto &skels = data.skels;
1761  auto &u_skel = data.u_skel;
1762  auto *lchild = node->lchild;
1763  auto *rchild = node->rchild;
1764 
1765  size_t nrhs = w.col();
1766 
1767 
1768 
1769 
1770 
1771  if ( node->isleaf )
1772  {
1774  View<T> U = data.u_view;
1775 
1776  if ( U.col() == nrhs )
1777  {
1778  //printf( "%8lu s2n U[%lu %lu %lu]\n",
1779  // node->morton, U.row(), U.col(), U.ld() ); fflush( stdout );
1780  xgemm
1781  (
1782  "Transpose", "Non-transpose",
1783  U.row(), U.col(), u_skel.row(),
1784  1.0, proj.data(), proj.row(),
1785  u_skel.data(), u_skel.row(),
1786  1.0, U.data(), U.ld()
1787  );
1788  }
1789  else
1790  {
1791  //printf( "%8lu use u_leaf u_view [%lu %lu ld %lu]\n",
1792  // node->morton, U.row(), U.col(), U.ld() ); fflush( stdout );
1793 
1794  auto &u_leaf = node->data.u_leaf[ 0 ];
1795 
1797  u_leaf.resize( 0, 0 );
1798  u_leaf.resize( gids.size(), nrhs, 0.0 );
1799 
1801  if ( data.isskel )
1802  {
1804  xgemm
1805  (
1806  "T", "N",
1807  u_leaf.row(), u_leaf.col(), u_skel.row(),
1808  1.0, proj.data(), proj.row(),
1809  u_skel.data(), u_skel.row(),
1810  1.0, u_leaf.data(), u_leaf.row()
1811  );
1812  }
1813  }
1814  }
1815  else
1816  {
1817  if ( !node->parent || !node->data.isskel ) return;
1818 
1819  auto &u_lskel = lchild->data.u_skel;
1820  auto &u_rskel = rchild->data.u_skel;
1821  auto &lskel = lchild->data.skels;
1822  auto &rskel = rchild->data.skels;
1823 
1824  //if ( 1 )
1825  if ( node->treelist_id > 6 )
1826  {
1827  //printf( "%8lu s2n\n", node->morton ); fflush( stdout );
1828  xgemm
1829  (
1830  "Transpose", "No transpose",
1831  u_lskel.row(), u_lskel.col(), proj.row(),
1832  1.0, proj.data(), proj.row(),
1833  u_skel.data(), u_skel.row(),
1834  1.0, u_lskel.data(), u_lskel.row()
1835  );
1836  xgemm
1837  (
1838  "Transpose", "No transpose",
1839  u_rskel.row(), u_rskel.col(), proj.row(),
1840  1.0, proj.data() + proj.row() * lskel.size(), proj.row(),
1841  u_skel.data(), u_skel.row(),
1842  1.0, u_rskel.data(), u_rskel.row()
1843  );
1844  }
1845  else
1846  {
1848  View<T> P( true, proj ), PL,
1849  PR;
1850  View<T> U( false, u_skel ), UL( false, u_lskel ),
1851  UR( false, u_rskel );
1853  P.Partition2x1( PL,
1854  PR, lskel.size(), TOP );
1856  gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
1858  gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );
1859  }
1860  }
1861  //printf( "\n" );
1862 
1863 };
1866 template<bool NNPRUNE, typename NODE, typename T>
1868 {
1869  public:
1870 
1871  NODE *arg = NULL;
1872 
1873  void Set( NODE *user_arg )
1874  {
1875  arg = user_arg;
1876  name = string( "s2n" );
1877  label = to_string( arg->treelist_id );
1878 
1879  //--------------------------------------
1880  double flops = 0.0, mops = 0.0;
1881  auto &gids = arg->gids;
1882  auto &data = arg->data;
1883  auto &proj = data.proj;
1884  auto &skels = data.skels;
1885  auto &w = *arg->setup->w;
1886 
1887  if ( arg->isleaf )
1888  {
1889  size_t m = proj.col();
1890  size_t n = w.col();
1891  size_t k = proj.row();
1892  flops += 2.0 * m * n * k;
1893  mops += 2.0 * ( m * n + n * k + m * k );
1894  }
1895  else
1896  {
1897  if ( !arg->parent || !arg->data.isskel )
1898  {
1899  // No computation.
1900  }
1901  else
1902  {
1903  size_t m = proj.col();
1904  size_t n = w.col();
1905  size_t k = proj.row();
1906  flops += 2.0 * m * n * k;
1907  mops += 2.0 * ( m * n + n * k + m * k );
1908  }
1909  }
1910 
1912  event.Set( label + name, flops, mops );
1914  cost = flops / 1E+9;
1916  priority = true;
1917  };
1918 
    /**
     *  Prefetch the operands of this S2N task before Execute runs.
     *
     *  CPU: issue __builtin_prefetch hints on proj and u_skel. For an internal
     *  node, also hint the children's u_skel, which S2N accumulates into. The
     *  leaf-side u_leaf hints are left commented out.
     *
     *  CUDA: when the worker owns a device, each of the same buffers goes
     *  through CacheD followed by PrefetchH2D on stream treelist_id % 8.
     *  NOTE(review): presumably CacheD reserves/marks a device copy and
     *  PrefetchH2D starts the host-to-device transfer. Their semantics are
     *  not shown in this file section.
     */
    void Prefetch( Worker* user_worker )
    {
      auto &proj = arg->data.proj;
      __builtin_prefetch( proj.data() );
      auto &u_skel = arg->data.u_skel;
      __builtin_prefetch( u_skel.data() );
      if ( arg->isleaf )
      {
        // Leaf outputs are not prefetched (hints kept disabled below).
        //__builtin_prefetch( arg->data.u_leaf[ 0 ].data() );
        //__builtin_prefetch( arg->data.u_leaf[ 1 ].data() );
        //__builtin_prefetch( arg->data.u_leaf[ 2 ].data() );
        //__builtin_prefetch( arg->data.u_leaf[ 3 ].data() );
      }
      else
      {
        // Children's skeleton potentials are the outputs of an internal S2N.
        auto &u_lskel = arg->lchild->data.u_skel;
        __builtin_prefetch( u_lskel.data() );
        auto &u_rskel = arg->rchild->data.u_skel;
        __builtin_prefetch( u_rskel.data() );
      }
#ifdef HMLP_USE_CUDA
      hmlp::Device *device = NULL;
      if ( user_worker ) device = user_worker->GetDevice();
      if ( device )
      {
        // Spread the transfers over 8 streams by tree-list position.
        int stream_id = arg->treelist_id % 8;
        proj.CacheD( device );
        proj.PrefetchH2D( device, stream_id );
        u_skel.CacheD( device );
        u_skel.PrefetchH2D( device, stream_id );
        if ( arg->isleaf )
        {
          // Nothing extra for leaves.
        }
        else
        {
          auto &u_lskel = arg->lchild->data.u_skel;
          u_lskel.CacheD( device );
          u_lskel.PrefetchH2D( device, stream_id );
          auto &u_rskel = arg->rchild->data.u_skel;
          u_rskel.CacheD( device );
          u_rskel.PrefetchH2D( device, stream_id );
        }
      }
#endif
    };
1964 
1965  void DependencyAnalysis() { arg->DependOnParent( this ); };
1966 
1967  void Execute( Worker* user_worker )
1968  {
1969 #ifdef HMLP_USE_CUDA
1970  Device *device = NULL;
1971  if ( user_worker ) device = user_worker->GetDevice();
1972  if ( device ) gpu::SkeletonsToNodes<NNPRUNE, NODE, T>( device, arg );
1973  else SkeletonsToNodes<NNPRUNE, NODE, T>( arg );
1974 #else
1975  SkeletonsToNodes( arg );
1976 #endif
1977  };
1978 
1979 };
1983 template<int SUBTASKID, bool NNPRUNE, typename NODE, typename T>
1984 void LeavesToLeaves( NODE *node, size_t itbeg, size_t itend )
1985 {
1986  assert( node->isleaf );
1987 
1988  double beg, u_leaf_time, before_writeback_time, after_writeback_time;
1989 
1991  auto &K = *node->setup->K;
1992  auto &w = *node->setup->w;
1993 
1994  auto &gids = node->gids;
1995  auto &data = node->data;
1996  auto &amap = node->gids;
1997  auto &NearKab = data.NearKab;
1998 
1999  size_t nrhs = w.col();
2000 
2001  set<NODE*> *NearNodes;
2002  if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2003  else NearNodes = &node->NearNodes;
2004 
2008  auto &u_leaf = data.u_leaf[ SUBTASKID ];
2009  u_leaf.resize( 0, 0 );
2010 
2012  if ( itbeg == itend )
2013  {
2014  return;
2015  }
2016  else
2017  {
2018  u_leaf.resize( gids.size(), nrhs, 0.0 );
2019  }
2020 
2021  if ( NearKab.size() )
2022  {
2023  size_t itptr = 0;
2024  size_t offset = 0;
2025 
2026  for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2027  {
2028  if ( itbeg <= itptr && itptr < itend )
2029  {
2030  //auto wb = w( bmap );
2031  auto wb = (*it)->data.w_leaf;
2032 
2033  if ( wb.size() )
2034  {
2036  xgemm
2037  (
2038  "N", "N",
2039  u_leaf.row(), u_leaf.col(), wb.row(),
2040  1.0, NearKab.data() + offset * NearKab.row(), NearKab.row(),
2041  wb.data(), wb.row(),
2042  1.0, u_leaf.data(), u_leaf.row()
2043  );
2044  }
2045  else
2046  {
2047  View<T> W = (*it)->data.w_view;
2048  xgemm
2049  (
2050  "N", "N",
2051  u_leaf.row(), u_leaf.col(), W.row(),
2052  1.0, NearKab.data() + offset * NearKab.row(), NearKab.row(),
2053  W.data(), W.ld(),
2054  1.0, u_leaf.data(), u_leaf.row()
2055  );
2056  }
2057  }
2058  offset += (*it)->gids.size();
2059  itptr ++;
2060  }
2061  }
2062  else
2063  {
2064  size_t itptr = 0;
2065  for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2066  {
2067  if ( itbeg <= itptr && itptr < itend )
2068  {
2069  auto &bmap = (*it)->gids;
2070  auto wb = (*it)->data.w_leaf;
2071 
2073  auto Kab = K( amap, bmap );
2074 
2075  if ( wb.size() )
2076  {
2078  xgemm( "N", "N", u_leaf.row(), u_leaf.col(), wb.row(),
2079  1.0, Kab.data(), Kab.row(),
2080  wb.data(), wb.row(),
2081  1.0, u_leaf.data(), u_leaf.row());
2082  }
2083  else
2084  {
2085  View<T> W = (*it)->data.w_view;
2086  xgemm( "N", "N", u_leaf.row(), u_leaf.col(), W.row(),
2087  1.0, Kab.data(), Kab.row(),
2088  W.data(), W.ld(),
2089  1.0, u_leaf.data(), u_leaf.row() );
2090  }
2091  }
2092  itptr ++;
2093  }
2094  }
2095  before_writeback_time = omp_get_wtime() - beg;
2096 
2097 };
2100 template<int SUBTASKID, bool NNPRUNE, typename NODE, typename T>
2101 class LeavesToLeavesTask : public Task
2102 {
2103  public:
2104 
2105  NODE *arg = NULL;
2106 
2107  size_t itbeg;
2108 
2109  size_t itend;
2110 
2111  void Set( NODE *user_arg )
2112  {
2113  arg = user_arg;
2114  name = string( "l2l" );
2115  label = to_string( arg->treelist_id );
2116 
2118  //--------------------------------------
2119  double flops = 0.0, mops = 0.0;
2120  auto &gids = arg->gids;
2121  auto &data = arg->data;
2122  auto &proj = data.proj;
2123  auto &skels = data.skels;
2124  auto &w = *arg->setup->w;
2125  auto &K = *arg->setup->K;
2126  auto &NearKab = data.NearKab;
2127 
2128  assert( arg->isleaf );
2129 
2130  size_t m = gids.size();
2131  size_t n = w.col();
2132 
2133  set<NODE*> *NearNodes;
2134  if ( NNPRUNE ) NearNodes = &arg->NNNearNodes;
2135  else NearNodes = &arg->NearNodes;
2136 
2138  size_t itptr = 0;
2139  size_t itrange = ( NearNodes->size() + 3 ) / 4;
2140  if ( itrange < 1 ) itrange = 1;
2141  itbeg = ( SUBTASKID - 1 ) * itrange;
2142  itend = ( SUBTASKID + 0 ) * itrange;
2143  if ( itbeg > NearNodes->size() ) itbeg = NearNodes->size();
2144  if ( itend > NearNodes->size() ) itend = NearNodes->size();
2145  if ( SUBTASKID == 4 ) itend = NearNodes->size();
2146 
2147  for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2148  {
2149  if ( itbeg <= itptr && itptr < itend )
2150  {
2151  size_t k = (*it)->gids.size();
2152  flops += 2.0 * m * n * k;
2153  mops += m * k;
2154  mops += 2.0 * ( m * n + n * k + m * k );
2155  }
2156  itptr ++;
2157  }
2158 
2160  event.Set( label + name, flops, mops );
2161 
2163  cost = flops / 1E+9;
2164  };
2165 
2166  void Prefetch( Worker* user_worker )
2167  {
2168  auto &u_leaf = arg->data.u_leaf[ SUBTASKID ];
2169  __builtin_prefetch( u_leaf.data() );
2170  };
2171 
2173  {
2175  //arg->data.s2n = event;
2176  };
2177 
2179  {
2180  assert( arg->isleaf );
2182  this->TryEnqueue();
2183 
2185  //auto &u_leaf = arg->data.u_leaf[ SUBTASKID ];
2186  //u_leaf.DependencyAnalysis( hmlp::ReadWriteType::W, this );
2187  };
2188 
2189  void Execute( Worker* user_worker )
2190  {
2191  LeavesToLeaves<SUBTASKID, NNPRUNE, NODE, T>( arg, itbeg, itend );
2192  };
2193 
2194 };
/**
 *  Debug helper: print the treelist_id of every node in `set` (in set
 *  order), each followed by ", ", then end the line.
 */
template<typename NODE>
void PrintSet( set<NODE*> &set )
{
  for ( NODE *member : set ) printf( "%lu, ", member->treelist_id );
  printf( "\n" );
};
2216 template<typename NODE>
2217 multimap<size_t, size_t> NearNodeBallots( NODE *node )
2218 {
2220  assert( node->isleaf );
2221 
2222  auto &setup = *(node->setup);
2223  auto &NN = *(setup.NN);
2224  auto &gids = node->gids;
2225 
2227  map<size_t, size_t> ballot;
2228 
2229  size_t HasMissingNeighbors = 0;
2230 
2231 
2233  for ( size_t j = 0; j < gids.size(); j ++ )
2234  {
2235  for ( size_t i = 0; i < NN.row(); i ++ )
2236  {
2237  auto value = NN( i, gids[ j ] ).first;
2238  size_t neighbor_gid = NN( i, gids[ j ] ).second;
2240  if ( neighbor_gid >= 0 && neighbor_gid < NN.col() )
2241  {
2242  size_t neighbor_morton = setup.morton[ neighbor_gid ];
2243  size_t weighted_ballot = 1.0 / ( value + 1E-3 );
2244  //printf( "gid %lu i %lu neighbor_gid %lu morton %lu\n", gids[ j ], i,
2245  // neighbor_gid, neighbor_morton );
2246 
2247  if ( i < NN.row() / 2 )
2248  {
2249  if ( ballot.find( neighbor_morton ) != ballot.end() )
2250  {
2251  //ballot[ neighbor_morton ] ++;
2252  ballot[ neighbor_morton ] += weighted_ballot;
2253  }
2254  else
2255  {
2256  //ballot[ neighbor_morton ] = 1;
2257  ballot[ neighbor_morton ] = weighted_ballot;
2258  }
2259  }
2260  }
2261  else
2262  {
2263  HasMissingNeighbors ++;
2264  }
2265  }
2266  }
2267 
2268  if ( HasMissingNeighbors )
2269  {
2270  printf( "Missing %lu neighbor pairs\n", HasMissingNeighbors );
2271  fflush( stdout );
2272  }
2273 
2275  return flip_map( ballot );
2276 
2277 };
/**
 *  Build the near lists of a leaf (no-op for internal nodes).
 *
 *  The leaf is always near itself (NearNodes, NNNearNodes and its morton id
 *  in NNNearNodeMortonIDs). Leaves are then taken from NearNodeBallots in
 *  reverse (largest-key-first) order into NNNearNodes and
 *  NNNearNodeMortonIDs. This stops once NNNearNodes holds at least
 *  2^l * setup.Budget() entries.
 */
template<typename NODE, typename T>
void NearSamples( NODE *node )
{
  if ( !node->isleaf ) return;

  auto &setup = *node->setup;
  double budget = setup.Budget();
  size_t n_nodes = ( 1 << node->l );

  /** A leaf is always in its own near lists. */
  node->NearNodes.insert( node );
  node->NNNearNodes.insert( node );
  node->NNNearNodeMortonIDs.insert( node->morton );

  multimap<size_t, size_t> sorted_ballot = NearNodeBallots( node );

  for ( auto rit = sorted_ballot.rbegin(); rit != sorted_ballot.rend(); ++ rit )
  {
    if ( node->NNNearNodes.size() >= n_nodes * budget ) break;
    size_t candidate_morton = rit->second;
    auto *candidate = (*node->morton2node)[ candidate_morton ];
    node->NNNearNodeMortonIDs.insert( candidate_morton );
    node->NNNearNodes.insert( candidate );
  }
};
2319 template<typename NODE, typename T>
2320 class NearSamplesTask : public Task
2321 {
2322  public:
2323 
2324  NODE *arg = NULL;
2325 
2326  void Set( NODE *user_arg )
2327  {
2328  arg = user_arg;
2329  name = string( "near" );
2330 
2331  //--------------------------------------
2332  double flops = 0.0, mops = 0.0;
2333 
2335  event.Set( label + name, flops, mops );
2337  cost = 1.0;
2339  priority = true;
2340  }
2341 
2342  void DependencyAnalysis() { this->TryEnqueue(); };
2343 
2344  void Execute( Worker* user_worker )
2345  {
2346  NearSamples<NODE, T>( arg );
2347  };
2348 
2349 };
/**
 *  Symmetrize the NN-pruned near lists of all leaves.
 *
 *  Leaves are treelist[ 2^depth - 1, 2^(depth+1) - 1 ). For every leaf
 *  `node` and every morton id in node->NNNearNodeMortonIDs, the leaf
 *  `target` with that id receives `node` in NNNearNodes and node's morton id
 *  in NNNearNodeMortonIDs. This keeps the pointer set and the morton-id set
 *  consistent.
 */
template<typename TREE>
void SymmetrizeNearInteractions( TREE & tree )
{
  int n_nodes = 1 << tree.depth;
  auto level_beg = tree.treelist.begin() + n_nodes - 1;

  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
  {
    auto *node = *(level_beg + node_ind);
    auto & NearMortonIDs = node->NNNearNodeMortonIDs;
    for ( auto & it : NearMortonIDs )
    {
      auto *target = tree.morton2node[ it ];
      target->NNNearNodes.insert( node );
      /**
       *  Record node's own morton id on target (`it` is target's id, which
       *  target already holds). std::set::insert invalidates no iterators,
       *  so this is safe even when target == node.
       */
      target->NNNearNodeMortonIDs.insert( node->morton );
    }
  }
};
2373 template<bool NNPRUNE, typename NODE>
2374 class CacheNearNodesTask : public Task
2375 {
2376  public:
2377 
2378  NODE *arg;
2379 
2380  void Set( NODE *user_arg )
2381  {
2382  arg = user_arg;
2383  name = string( "c-n" );
2384  label = to_string( arg->treelist_id );
2386  cost = 1.0;
2387  };
2388 
2390  {
2391  double flops = 0.0, mops = 0.0;
2392 
2393  NODE *node = arg;
2394  auto *NearNodes = &node->NearNodes;
2395  if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2396  auto &K = *node->setup->K;
2397 
2398  size_t m = node->gids.size();
2399  size_t n = 0;
2400  for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2401  {
2402  n += (*it)->gids.size();
2403  }
2405  event.Set( label + name, flops, mops );
2406  };
2407 
2408  void DependencyAnalysis() { arg->DependOnNoOne( this ); };
2409 
2410  void Execute( Worker* user_worker )
2411  {
2412  //printf( "%lu CacheNearNodes beg\n", arg->treelist_id ); fflush( stdout );
2413 
2414  NODE *node = arg;
2415  auto *NearNodes = &node->NearNodes;
2416  if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2417  auto &K = *node->setup->K;
2418  auto &data = node->data;
2419  auto &amap = node->gids;
2420  vector<size_t> bmap;
2421  for ( auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2422  {
2423  bmap.insert( bmap.end(), (*it)->gids.begin(), (*it)->gids.end() );
2424  }
2425  data.NearKab = K( amap, bmap );
2426 
2428  data.Nearbmap.resize( bmap.size(), 1 );
2429  for ( size_t i = 0; i < bmap.size(); i ++ )
2430  data.Nearbmap[ i ] = bmap[ i ];
2431 
2432 #ifdef HMLP_USE_CUDA
2433  auto *device = hmlp_get_device( 0 );
2435  node->data.Nearbmap.PrefetchH2D( device, 8 );
2436 
2437  size_t preserve_size = 3000000000;
2438  //if ( data.NearKab.col() * MAX_NRHS < 1200000000 &&
2439  // data.NearKab.size() * 8 + preserve_size < device->get_memory_left() &&
2440  // data.NearKab.size() * 8 > 4096 * 4096 * 8 * 4 )
2441  if ( data.NearKab.col() * MAX_NRHS < 1200000000 &&
2442  data.NearKab.size() * 8 + preserve_size < device->get_memory_left() )
2443  {
2445  data.NearKab.PrefetchH2D( device, 8 );
2446  }
2447  else
2448  {
2449  printf( "Kab %lu %lu not cache\n", data.NearKab.row(), data.NearKab.col() );
2450  }
2451 #endif
2452 
2453  //printf( "%lu CacheNearNodesTask end\n", arg->treelist_id ); fflush( stdout );
2454  };
2455 };
/**
 *  Recursive half of FindFarNodes for the non-pruned lists.
 *
 *  Walks down from `node`. A node that is skeletonized and contains none of
 *  target->NearNodes is added to target->FarNodes and not descended further.
 *  Otherwise the walk recurses into its children (if any).
 */
template<typename NODE>
void FindFarNodesFromNearNodes( NODE *node, NODE *target )
{
  if ( !node->data.isskel || node->ContainAny( target->NearNodes ) )
  {
    if ( !node->isleaf )
    {
      FindFarNodesFromNearNodes( node->lchild, target );
      FindFarNodesFromNearNodes( node->rchild, target );
    }
  }
  else
  {
    target->FarNodes.insert( node );
  }
};

/**
 *  Recursive half of FindFarNodes for the NN-pruned lists.
 *
 *  Same walk as above against target->NNNearNodes, filling
 *  target->NNFarNodes. For a symmetric setup, far nodes with
 *  morton < target->morton are not recorded here (MergeFarNodes mirrors
 *  NNFarNodes for symmetric setups).
 */
template<typename NODE>
void FindFarNodesFromNNNearNodes( NODE *node, NODE *target )
{
  if ( !node->data.isskel || node->ContainAny( target->NNNearNodes ) )
  {
    if ( !node->isleaf )
    {
      FindFarNodesFromNNNearNodes( node->lchild, target );
      FindFarNodesFromNNNearNodes( node->rchild, target );
    }
  }
  else
  {
    bool skip = node->setup->IsSymmetric() && ( node->morton < target->morton );
    if ( !skip ) target->NNFarNodes.insert( node );
  }
};

/**
 *  Fill target->FarNodes and target->NNFarNodes with the maximal far
 *  subtrees below `node` (normally the root).
 *
 *  The two lists are computed by independent recursions. Sharing one
 *  recursion visits every subtree twice per level, which is exponential in
 *  the depth. It also lets calls reached only through the NN walk add
 *  descendants of already-far nodes to FarNodes.
 *
 *  @param node    subtree root to search
 *  @param target  leaf whose far lists are filled (asserted to be a leaf)
 */
template<typename NODE>
void FindFarNodes( NODE *node, NODE *target )
{
  assert( target->isleaf );
  FindFarNodesFromNearNodes( node, target );
  FindFarNodesFromNNNearNodes( node, target );
};
/**
 *  Move the nodes present in BOTH child lists up to the parent list, then
 *  remove every entry of the parent list from both children.
 */
template<typename SET>
void MergeFarNodesHoistCommon( SET &parent_list, SET &left_list, SET &right_list )
{
  for ( const auto &candidate : left_list )
  {
    if ( right_list.count( candidate ) ) parent_list.insert( candidate );
  }
  for ( const auto &hoisted : parent_list )
  {
    left_list.erase( hoisted );
    right_list.erase( hoisted );
  }
};

/**
 *  Build FarNodes / NNFarNodes for the whole tree.
 *
 *  Pass 1 goes level by level from the leaves (l = depth) to the root
 *  (l = 0) and skips non-skeletonized nodes. A leaf gets its lists from
 *  FindFarNodes, searched from treelist[ 0 ]. An internal node hoists the
 *  entries shared by both children (FarNodes first, then NNFarNodes).
 *  Pass 2 (symmetric setups only): if B is in A's NNFarNodes, A is added to
 *  B's NNFarNodes.
 *  With DEBUG_SPDASKIT, asymmetric pairs and non-empty lists are printed.
 *
 *  Level l occupies treelist[ 2^l - 1, 2^(l+1) - 1 ).
 */
template<typename TREE>
void MergeFarNodes( TREE &tree )
{
  for ( int l = tree.depth; l >= 0; l -- )
  {
    size_t n_nodes = ( 1 << l );
    auto level_beg = tree.treelist.begin() + n_nodes - 1;

    for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
    {
      auto *node = *( level_beg + node_ind );
      if ( !node->data.isskel ) continue;

      if ( node->isleaf )
      {
        FindFarNodes( tree.treelist[ 0 ], node );
        continue;
      }

      MergeFarNodesHoistCommon( node->FarNodes,
          node->lchild->FarNodes, node->rchild->FarNodes );
      MergeFarNodesHoistCommon( node->NNFarNodes,
          node->lchild->NNFarNodes, node->rchild->NNFarNodes );
    }
  }

  if ( tree.setup.IsSymmetric() )
  {
    for ( int l = tree.depth; l >= 0; l -- )
    {
      std::size_t n_nodes = 1 << l;
      auto level_beg = tree.treelist.begin() + n_nodes - 1;

      for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
      {
        auto *node = *( level_beg + node_ind );
        for ( auto *peer : node->NNFarNodes ) peer->NNFarNodes.insert( node );
      }
    }
  }

#ifdef DEBUG_SPDASKIT
  for ( int l = tree.depth; l >= 0; l -- )
  {
    std::size_t n_nodes = 1 << l;
    auto level_beg = tree.treelist.begin() + n_nodes - 1;

    for ( size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
    {
      auto *node = *( level_beg + node_ind );
      auto &pFarNodes = node->NNFarNodes;
      for ( auto *peer : pFarNodes )
      {
        if ( !( peer->NNFarNodes.count( node ) ) )
        {
          printf( "Unsymmetric FarNodes %lu, %lu\n", node->treelist_id, peer->treelist_id );
          printf( "Near\n" );
          PrintSet( node->NNNearNodes );
          PrintSet( peer->NNNearNodes );
          printf( "Far\n" );
          PrintSet( node->NNFarNodes );
          PrintSet( peer->NNFarNodes );
          printf( "======\n" );
          break;
        }
      }
      if ( pFarNodes.size() )
      {
        printf( "l %2lu FarNodes(%lu) ", node->l, node->treelist_id );
        PrintSet( pFarNodes );
      }
    }
  }
#endif
};
2677 
2678 
2686 template<bool NNPRUNE, bool CACHE = true, typename TREE>
2687 void CacheFarNodes( TREE &tree )
2688 {
2690  #pragma omp parallel for schedule( dynamic )
2691  for ( size_t i = 0; i < tree.treelist.size(); i ++ )
2692  {
2693  auto *node = tree.treelist[ i ];
2694  if ( node->isleaf )
2695  {
2696  node->data.w_leaf.reserve( node->gids.size(), MAX_NRHS );
2697  node->data.u_leaf[ 0 ].reserve( MAX_NRHS, node->gids.size() );
2698  }
2699  }
2700 
2702  if ( CACHE )
2703  {
2705  #pragma omp parallel for schedule( dynamic )
2706  for ( size_t i = 0; i < tree.treelist.size(); i ++ )
2707  {
2708  auto *node = tree.treelist[ i ];
2709  auto *FarNodes = &node->FarNodes;
2710  if ( NNPRUNE ) FarNodes = &node->NNFarNodes;
2711  auto &K = *node->setup->K;
2712  auto &data = node->data;
2713  auto &amap = data.skels;
2714  std::vector<size_t> bmap;
2715  for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
2716  {
2717  bmap.insert( bmap.end(), (*it)->data.skels.begin(),
2718  (*it)->data.skels.end() );
2719  }
2720  data.FarKab = K( amap, bmap );
2721  }
2722  }
2723 };
/**
 *  Write the MATLAB script "interaction.m", which draws the block structure
 *  of the compressed matrix from the NN-pruned lists. Far blocks are shaded
 *  [1, gb, gb] with gb = min(level) / depth (lighter for deeper levels).
 *  Near blocks are drawn in [0.2, 0.4, 1.0].
 *
 *  @return fraction of the n x n matrix covered by near (exactly evaluated)
 *          blocks. It is 0 when NNPRUNE is false, because nothing is drawn
 *          then. If the script cannot be opened, a message is printed and
 *          the ratio is still computed.
 */
template<bool NNPRUNE, typename TREE>
double DrawInteraction( TREE &tree )
{
  double exact_ratio = 0.0;

  FILE *pFile = fopen( "interaction.m", "w" );
  if ( !pFile )
  {
    printf( "DrawInteraction: cannot open interaction.m, skip drawing\n" );
    fflush( stdout );
  }
  else
  {
    fprintf( pFile, "figure('Position',[100,100,800,800]);" );
    fprintf( pFile, "hold on;" );
    fprintf( pFile, "axis square;" );
    fprintf( pFile, "axis ij;" );
  }

  for ( int l = tree.depth; l >= 0; l -- )
  {
    std::size_t n_nodes = 1 << l;
    auto level_beg = tree.treelist.begin() + n_nodes - 1;

    for ( std::size_t node_ind = 0; node_ind < n_nodes; node_ind ++ )
    {
      auto *node = *(level_beg + node_ind);

      /** Only the NN-pruned lists are drawn. */
      if ( !NNPRUNE ) continue;

      auto &pNearNodes = node->NNNearNodes;
      auto &pFarNodes = node->NNFarNodes;

      if ( pFile )
      {
        for ( auto it = pFarNodes.begin(); it != pFarNodes.end(); it ++ )
        {
          double gb = static_cast<double>( std::min( node->l, (*it)->l ) ) / tree.depth;
          fprintf( pFile, "rectangle('position',[%lu %lu %lu %lu],'facecolor',[1.0,%lf,%lf]);\n",
              node->offset, (*it)->offset,
              node->gids.size(), (*it)->gids.size(),
              gb, gb );
        }
      }

      for ( auto it = pNearNodes.begin(); it != pNearNodes.end(); it ++ )
      {
        if ( pFile )
        {
          fprintf( pFile, "rectangle('position',[%lu %lu %lu %lu],'facecolor',[0.2,0.4,1.0]);\n",
              node->offset, (*it)->offset,
              node->gids.size(), (*it)->gids.size() );
        }
        /** Near blocks are evaluated exactly: accumulate their area. */
        exact_ratio += node->gids.size() * (*it)->gids.size();
      }
    }
  }

  if ( pFile )
  {
    fprintf( pFile, "hold off;" );
    fclose( pFile );
  }

  /** Square in floating point so a large n cannot overflow the integer type. */
  return exact_ratio / ( static_cast<double>( tree.n ) * tree.n );
};
2796 template<bool SYMBOLIC, bool NNPRUNE, typename NODE, typename T>
2797 void Evaluate
2798 (
2799  NODE *node,
2800  size_t gid,
2801  vector<size_t> &nnandi, // k + 1 non-prunable lists
2802  Data<T> &potentials
2803 )
2804 {
2805  auto &K = *node->setup->K;
2806  auto &w = *node->setup->w;
2807  auto &gids = node->gids;
2808  auto &data = node->data;
2809  auto *lchild = node->lchild;
2810  auto *rchild = node->rchild;
2811 
2812  size_t nrhs = w.col();
2813 
2814  auto amap = std::vector<size_t>( 1 );
2815  amap[ 0 ] = gid;
2816 
2817  if ( !SYMBOLIC ) // No potential evaluation.
2818  {
2819  assert( potentials.size() == amap.size() * nrhs );
2820  }
2821 
2822  if ( !data.isskel || node->ContainAny( nnandi ) )
2823  {
2824  if ( node->isleaf )
2825  {
2826  if ( SYMBOLIC )
2827  {
2829  data.lock.Acquire();
2830  {
2831  if ( NNPRUNE ) node->NNNearIDs.insert( gid );
2832  else node->NearIDs.insert( gid );
2833  }
2834  data.lock.Release();
2835  }
2836  else
2837  {
2838 #ifdef DEBUG_SPDASKIT
2839  printf( "level %lu direct evaluation\n", node->l );
2840 #endif
2841 
2842  auto Kab = K( amap, gids );
2843 
2845  std::vector<size_t> bmap( nrhs );
2846  for ( size_t j = 0; j < bmap.size(); j ++ )
2847  bmap[ j ] = j;
2848 
2850  auto wb = w( gids, bmap );
2851 
2852  xgemm
2853  (
2854  "N", "N",
2855  Kab.row(), wb.col(), wb.row(),
2856  1.0, Kab.data(), Kab.row(),
2857  wb.data(), wb.row(),
2858  1.0, potentials.data(), potentials.row()
2859  );
2860  }
2861  }
2862  else
2863  {
2864  Evaluate<SYMBOLIC, NNPRUNE>( lchild, gid, nnandi, potentials );
2865  Evaluate<SYMBOLIC, NNPRUNE>( rchild, gid, nnandi, potentials );
2866  }
2867  }
2868  else // need gid's morton and neighbors' mortons
2869  {
2870  //printf( "level %lu is prunable\n", node->l );
2871  if ( SYMBOLIC )
2872  {
2873  data.lock.Acquire();
2874  {
2875  // Add gid to prunable list.
2876  if ( NNPRUNE ) node->FarIDs.insert( gid );
2877  else node->NNFarIDs.insert( gid );
2878  }
2879  data.lock.Release();
2880  }
2881  else
2882  {
2883 #ifdef DEBUG_SPDASKIT
2884  printf( "level %lu is prunable\n", node->l );
2885 #endif
2886  auto Kab = K( amap, node->data.skels );
2887  auto &w_skel = node->data.w_skel;
2888  xgemm
2889  (
2890  "N", "N",
2891  Kab.row(), w_skel.col(), w_skel.row(),
2892  1.0, Kab.data(), Kab.row(),
2893  w_skel.data(), w_skel.row(),
2894  1.0, potentials.data(), potentials.row()
2895  );
2896  }
2897  }
2898 
2899 
2900 
2901 };
2908 template<bool SYMBOLIC, bool NNPRUNE, typename TREE, typename T>
2909 void Evaluate
2910 (
2911  TREE &tree,
2912  size_t gid,
2913  Data<T> &potentials
2914 )
2915 {
2916  vector<size_t> nnandi;
2917  auto &w = *tree.setup.w;
2918 
2919  potentials.clear();
2920  potentials.resize( 1, w.col(), 0.0 );
2921 
2922  if ( NNPRUNE )
2923  {
2924  auto &NN = *tree.setup.NN;
2925  nnandi.reserve( NN.row() + 1 );
2926  nnandi.push_back( gid );
2927  for ( size_t i = 0; i < NN.row(); i ++ )
2928  {
2929  nnandi.push_back( NN( i, gid ).second );
2930  }
2931 #ifdef DEBUG_SPDASKIT
2932  printf( "nnandi.size() %lu\n", nnandi.size() );
2933 #endif
2934  }
2935  else
2936  {
2937  nnandi.reserve( 1 );
2938  nnandi.push_back( gid );
2939  }
2940 
2941  Evaluate<SYMBOLIC, NNPRUNE>( tree.treelist[ 0 ], gid, nnandi, potentials );
2942 
2943 };
2949 template<
2950  bool USE_RUNTIME = true,
2951  bool USE_OMP_TASK = false,
2952  bool NNPRUNE = true,
2953  bool CACHE = true,
2954  typename TREE,
2955  typename T>
2956 Data<T> Evaluate
2957 (
2958  TREE &tree,
2959  Data<T> &weights
2960 )
2961 {
2962  const bool AUTO_DEPENDENCY = true;
2963 
2965  using NODE = typename TREE::NODE;
2966 
2968  double beg, time_ratio, evaluation_time = 0.0;
2969  double allocate_time, computeall_time;
2970  double forward_permute_time, backward_permute_time;
2971 
2973  tree.DependencyCleanUp();
2974 
2976  size_t n = weights.row();
2977  size_t nrhs = weights.col();
2978 
2979  beg = omp_get_wtime();
2980  hmlp::Data<T> potentials( n, nrhs, 0.0 );
2981  tree.setup.w = &weights;
2982  tree.setup.u = &potentials;
2983  allocate_time = omp_get_wtime() - beg;
2984 
2986  if ( REPORT_EVALUATE_STATUS )
2987  {
2988  printf( "Forward permute ...\n" ); fflush( stdout );
2989  }
2990  beg = omp_get_wtime();
2991  int n_nodes = ( 1 << tree.depth );
2992  auto level_beg = tree.treelist.begin() + n_nodes - 1;
2993  #pragma omp parallel for
2994  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2995  {
2996  auto *node = *(level_beg + node_ind);
2997 
2998 
2999  auto &gids = node->gids;
3000  auto &w_leaf = node->data.w_leaf;
3001 
3002  if ( w_leaf.row() != gids.size() || w_leaf.col() != weights.col() )
3003  {
3004  w_leaf.resize( gids.size(), weights.col() );
3005  }
3006 
3007  for ( size_t j = 0; j < w_leaf.col(); j ++ )
3008  {
3009  for ( size_t i = 0; i < w_leaf.row(); i ++ )
3010  {
3011  w_leaf( i, j ) = weights( gids[ i ], j );
3012  }
3013  };
3014  }
3015  forward_permute_time = omp_get_wtime() - beg;
3016 
3017 
3018 
3020  if ( REPORT_EVALUATE_STATUS )
3021  {
3022  printf( "N2S, S2S, S2N, L2L (HMLP Runtime) ...\n" ); fflush( stdout );
3023  }
3024  if ( tree.setup.IsSymmetric() )
3025  {
3026  beg = omp_get_wtime();
3027 #ifdef HMLP_USE_CUDA
3028  potentials.AllocateD( hmlp_get_device( 0 ) );
3029  using LEAFTOLEAFVER2TASK = gpu::LeavesToLeavesVer2Task<CACHE, NNPRUNE, NODE, T>;
3030  LEAFTOLEAFVER2TASK leaftoleafver2task;
3031 #endif
3032  using LEAFTOLEAFTASK1 = LeavesToLeavesTask<1, NNPRUNE, NODE, T>;
3033  using LEAFTOLEAFTASK2 = LeavesToLeavesTask<2, NNPRUNE, NODE, T>;
3034  using LEAFTOLEAFTASK3 = LeavesToLeavesTask<3, NNPRUNE, NODE, T>;
3035  using LEAFTOLEAFTASK4 = LeavesToLeavesTask<4, NNPRUNE, NODE, T>;
3036 
3037  using NODETOSKELTASK = UpdateWeightsTask<NODE, T>;
3038  using SKELTOSKELTASK = SkeletonsToSkeletonsTask<NNPRUNE, NODE, T>;
3039  using SKELTONODETASK = SkeletonsToNodesTask<NNPRUNE, NODE, T>;
3040 
3041  LEAFTOLEAFTASK1 leaftoleaftask1;
3042  LEAFTOLEAFTASK2 leaftoleaftask2;
3043  LEAFTOLEAFTASK3 leaftoleaftask3;
3044  LEAFTOLEAFTASK4 leaftoleaftask4;
3045 
3046  NODETOSKELTASK nodetoskeltask;
3047  SKELTOSKELTASK skeltoskeltask;
3048  SKELTONODETASK skeltonodetask;
3049 
3050 
3051 // if ( USE_OMP_TASK )
3052 // {
3053 // assert( !USE_RUNTIME );
3054 // tree.template TraverseLeafs<false, false>( leaftoleaftask1 );
3055 // tree.template TraverseLeafs<false, false>( leaftoleaftask2 );
3056 // tree.template TraverseLeafs<false, false>( leaftoleaftask3 );
3057 // tree.template TraverseLeafs<false, false>( leaftoleaftask4 );
3058 // tree.template UpDown<true, true, true>( nodetoskeltask, skeltoskeltask, skeltonodetask );
3059 // }
3060 // else
3061 // {
3062 // assert( !USE_OMP_TASK );
3063 //
3064 //#ifdef HMLP_USE_CUDA
3065 // tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleafver2task );
3066 //#else
3067 // tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask1 );
3068 // tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask2 );
3069 // tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask3 );
3070 // tree.template TraverseLeafs<AUTO_DEPENDENCY, USE_RUNTIME>( leaftoleaftask4 );
3071 //#endif
3072 //
3073 // /** check scheduler */
3074 // //hmlp_get_runtime_handle()->scheduler->ReportRemainingTime();
3075 // tree.template TraverseUp <AUTO_DEPENDENCY, USE_RUNTIME>( nodetoskeltask );
3076 // tree.template TraverseUnOrdered<AUTO_DEPENDENCY, USE_RUNTIME>( skeltoskeltask );
3077 // tree.template TraverseDown <AUTO_DEPENDENCY, USE_RUNTIME>( skeltonodetask );
3078 // /** check scheduler */
3079 // //hmlp_get_runtime_handle()->scheduler->ReportRemainingTime();
3080 //
3081 // if ( USE_RUNTIME ) hmlp_run();
3082 //
3083 //
3084 //
3085 //#ifdef HMLP_USE_CUDA
3086 // hmlp::Device *device = hmlp_get_device( 0 );
3087 // for ( int stream_id = 0; stream_id < 10; stream_id ++ )
3088 // device->wait( stream_id );
3089 // //potentials.PrefetchD2H( device, 0 );
3090 // potentials.FetchD2H( device );
3091 //#endif
3092 // }
3093 
3094 
3095 
3096 
3098 #ifdef HMLP_USE_CUDA
3099  tree.TraverseLeafs( leaftoleafver2task );
3100 #else
3101  tree.TraverseLeafs( leaftoleaftask1 );
3102  tree.TraverseLeafs( leaftoleaftask2 );
3103  tree.TraverseLeafs( leaftoleaftask3 );
3104  tree.TraverseLeafs( leaftoleaftask4 );
3105 #endif
3106  tree.TraverseUp( nodetoskeltask );
3107  tree.TraverseUnOrdered( skeltoskeltask );
3108  tree.TraverseDown( skeltonodetask );
3109  tree.ExecuteAllTasks();
3110  //if ( USE_RUNTIME ) hmlp_run();
3111 
3112 
3113  double d2h_beg_t = omp_get_wtime();
3114 #ifdef HMLP_USE_CUDA
3115  hmlp::Device *device = hmlp_get_device( 0 );
3116  for ( int stream_id = 0; stream_id < 10; stream_id ++ )
3117  device->wait( stream_id );
3118  //potentials.PrefetchD2H( device, 0 );
3119  potentials.FetchD2H( device );
3120 #endif
3121  double d2h_t = omp_get_wtime() - d2h_beg_t;
3122  printf( "d2h_t %lfs\n", d2h_t );
3123 
3124 
3125  double aggregate_beg_t = omp_get_wtime();
3127  #pragma omp parallel for
3128  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
3129  {
3130  auto *node = *(level_beg + node_ind);
3131  auto &u_leaf = node->data.u_leaf[ 0 ];
3133  for ( size_t p = 1; p < 20; p ++ )
3134  {
3135  for ( size_t i = 0; i < node->data.u_leaf[ p ].size(); i ++ )
3136  u_leaf[ i ] += node->data.u_leaf[ p ][ i ];
3137  }
3138  }
3139  double aggregate_t = omp_get_wtime() - aggregate_beg_t;
3140  printf( "aggregate_t %lfs\n", d2h_t );
3141 
3142 #ifdef HMLP_USE_CUDA
3143  device->wait( 0 );
3144 #endif
3145  computeall_time = omp_get_wtime() - beg;
3146  }
3147  else // TODO: implement unsymmetric prunning
3148  {
3150  printf( "Non symmetric ComputeAll is not yet implemented\n" );
3151  exit( 1 );
3152  }
3153 
3154 
3155 
3157  if ( REPORT_EVALUATE_STATUS )
3158  {
3159  printf( "Backward permute ...\n" ); fflush( stdout );
3160  }
3161  beg = omp_get_wtime();
3162  #pragma omp parallel for
3163  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
3164  {
3165  auto *node = *(level_beg + node_ind);
3166  auto &amap = node->gids;
3167  auto &u_leaf = node->data.u_leaf[ 0 ];
3168 
3169 
3170 
3172  //for ( size_t j = 0; j < amap.size(); j ++ )
3173  // for ( size_t i = 0; i < potentials.row(); i ++ )
3174  // potentials[ amap[ j ] * potentials.row() + i ] += u_leaf( j, i );
3175 
3176 
3177  for ( size_t j = 0; j < potentials.col(); j ++ )
3178  for ( size_t i = 0; i < amap.size(); i ++ )
3179  potentials( amap[ i ], j ) += u_leaf( i, j );
3180 
3181 
3182  }
3183  backward_permute_time = omp_get_wtime() - beg;
3184 
3185  evaluation_time += allocate_time;
3186  evaluation_time += forward_permute_time;
3187  evaluation_time += computeall_time;
3188  evaluation_time += backward_permute_time;
3189  time_ratio = 100 / evaluation_time;
3190 
3191  if ( REPORT_EVALUATE_STATUS )
3192  {
3193  printf( "========================================================\n");
3194  printf( "GOFMM evaluation phase\n" );
3195  printf( "========================================================\n");
3196  printf( "Allocate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3197  allocate_time, allocate_time * time_ratio );
3198  printf( "Forward permute ----------------------- %5.2lfs (%5.1lf%%)\n",
3199  forward_permute_time, forward_permute_time * time_ratio );
3200  printf( "N2S, S2S, S2N, L2L -------------------- %5.2lfs (%5.1lf%%)\n",
3201  computeall_time, computeall_time * time_ratio );
3202  printf( "Backward permute ---------------------- %5.2lfs (%5.1lf%%)\n",
3203  backward_permute_time, backward_permute_time * time_ratio );
3204  printf( "========================================================\n");
3205  printf( "Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3206  evaluation_time, evaluation_time * time_ratio );
3207  printf( "========================================================\n\n");
3208  }
3209 
3210 
3212  tree.DependencyCleanUp();
3213 
3215  return potentials;
3216 
3217 };
3221 template<typename SPLITTER, typename T, typename SPDMATRIX>
3222 Data<pair<T, size_t>> FindNeighbors
3223 (
3224  SPDMATRIX &K,
3225  SPLITTER splitter,
3226  Configuration<T> &config,
3227  size_t n_iter = 10
3228 )
3229 {
3231  using DATA = gofmm::NodeData<T>;
3233  using TREE = tree::Tree<SETUP, DATA>;
3235  using NODE = typename TREE::NODE;
3237  DistanceMetric metric = config.MetricType();
3238  size_t n = config.ProblemSize();
3239  size_t k = config.NeighborSize();
3241  pair<T, size_t> init( numeric_limits<T>::max(), n );
3242  gofmm::NeighborsTask<NODE, T> NEIGHBORStask;
3243  TREE rkdt;
3244  rkdt.setup.FromConfiguration( config, K, splitter, NULL );
3245  return rkdt.AllNearestNeighbor( n_iter, k, n_iter, init, NEIGHBORStask );
3246 };
3252 template<typename SPLITTER, typename RKDTSPLITTER, typename T, typename SPDMATRIX>
3254 *Compress
3255 (
3256  SPDMATRIX &K,
3257  Data<pair<T, size_t>> &NN,
3258  SPLITTER splitter,
3259  RKDTSPLITTER rkdtsplitter,
3260  Configuration<T> &config
3261 )
3262 {
3264  DistanceMetric metric = config.MetricType();
3265  size_t n = config.ProblemSize();
3266  size_t m = config.LeafNodeSize();
3267  size_t k = config.NeighborSize();
3268  size_t s = config.MaximumRank();
3269  T stol = config.Tolerance();
3270  T budget = config.Budget();
3271 
3273  const bool NNPRUNE = true;
3274  const bool CACHE = true;
3275 
3278  using DATA = gofmm::NodeData<T>;
3279  using TREE = tree::Tree<SETUP, DATA>;
3281  using NODE = typename TREE::NODE;
3282 
3283 
3285  double beg, omptask45_time, omptask_time, ref_time;
3286  double time_ratio, compress_time = 0.0, other_time = 0.0;
3287  double ann_time, tree_time, skel_time, mergefarnodes_time, cachefarnodes_time;
3288  double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
3289 
3291  beg = omp_get_wtime();
3292  if ( NN.size() != n * k )
3293  {
3294  NN = gofmm::FindNeighbors( K, rkdtsplitter, config );
3295  }
3296  ann_time = omp_get_wtime() - beg;
3297 
3298 
3300  auto *tree_ptr = new TREE();
3301  auto &tree = *tree_ptr;
3302  tree.setup.FromConfiguration( config, K, splitter, &NN );
3303 
3304 
3305  if ( REPORT_COMPRESS_STATUS )
3306  {
3307  printf( "TreePartitioning ...\n" ); fflush( stdout );
3308  }
3309  beg = omp_get_wtime();
3310  tree.TreePartition();
3311  tree_time = omp_get_wtime() - beg;
3312 
3313 
3314 #ifdef HMLP_AVX512
3315 
3316  assert( omp_get_max_threads() == 68 );
3317  //mkl_set_dynamic( 0 );
3318  //mkl_set_num_threads( 4 );
3319  hmlp_set_num_workers( 17 );
3320 #else
3321  //if ( omp_get_max_threads() > 8 )
3322  //{
3323  // hmlp_set_num_workers( omp_get_max_threads() / 2 );
3324  //}
3325  if ( REPORT_COMPRESS_STATUS )
3326  {
3327  printf( "omp_get_max_threads() %d\n", omp_get_max_threads() );
3328  }
3329 #endif
3330 
3331 
3332 
3333 
3335  NearSamplesTask<NODE, T> NEARSAMPLEStask;
3336  tree.DependencyCleanUp();
3337  printf( "Dependency clean up\n" ); fflush( stdout );
3338  tree.TraverseLeafs( NEARSAMPLEStask );
3339  tree.ExecuteAllTasks();
3340  //hmlp_run();
3341  printf( "Finish NearSamplesTask\n" ); fflush( stdout );
3342  SymmetrizeNearInteractions( tree );
3343  printf( "Finish SymmetrizeNearInteractions\n" ); fflush( stdout );
3344 
3345 
3346 
3348  if ( REPORT_COMPRESS_STATUS )
3349  {
3350  printf( "Skeletonization (HMLP Runtime) ...\n" ); fflush( stdout );
3351  }
3352  beg = omp_get_wtime();
3356  tree.DependencyCleanUp();
3357  tree.TraverseUp( GETMTXtask, SKELtask );
3358  tree.TraverseUnOrdered( PROJtask );
3359  if ( CACHE )
3360  {
3362  tree.template TraverseLeafs( KIJtask );
3363  }
3364  other_time += omp_get_wtime() - beg;
3365  hmlp_run();
3366  skel_time = omp_get_wtime() - beg;
3367 
3368 
3369 
3370 
3372  beg = omp_get_wtime();
3373  if ( REPORT_COMPRESS_STATUS )
3374  {
3375  printf( "MergeFarNodes ...\n" ); fflush( stdout );
3376  }
3377  gofmm::MergeFarNodes( tree );
3378  mergefarnodes_time = omp_get_wtime() - beg;
3379 
3381  beg = omp_get_wtime();
3382  if ( REPORT_COMPRESS_STATUS )
3383  {
3384  printf( "CacheFarNodes ...\n" ); fflush( stdout );
3385  }
3386  gofmm::CacheFarNodes<NNPRUNE, CACHE>( tree );
3387  cachefarnodes_time = omp_get_wtime() - beg;
3388 
3390  auto exact_ratio = hmlp::gofmm::DrawInteraction<true>( tree );
3391 
3392  compress_time += ann_time;
3393  compress_time += tree_time;
3394  compress_time += skel_time;
3395  compress_time += mergefarnodes_time;
3396  compress_time += cachefarnodes_time;
3397  time_ratio = 100.0 / compress_time;
3398  if ( REPORT_COMPRESS_STATUS )
3399  {
3400  printf( "========================================================\n");
3401  printf( "GOFMM compression phase\n" );
3402  printf( "========================================================\n");
3403  printf( "NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
3404  printf( "TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
3405  printf( "Skeletonization ----------------------- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
3406  printf( "MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
3407  printf( "CacheFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
3408  printf( "========================================================\n");
3409  printf( "Compress (%4.2lf not compressed) -------- %5.2lfs (%5.1lf%%)\n",
3410  exact_ratio, compress_time, compress_time * time_ratio );
3411  printf( "========================================================\n\n");
3412  }
3413 
3415  tree_ptr->DependencyCleanUp();
3416 
3418  return tree_ptr;
3419 
3420 };
3465 template<typename T, typename SPDMATRIX>
3466 tree::Tree<
3469 *Compress( SPDMATRIX &K, T stol, T budget, size_t m, size_t k, size_t s )
3470 {
3471  using SPLITTER = centersplit<SPDMATRIX, 2, T>;
3472  using RKDTSPLITTER = randomsplit<SPDMATRIX, 2, T>;
3475  SPLITTER splitter( K );
3476  splitter.Kptr = &K;
3477  splitter.metric = ANGLE_DISTANCE;
3479  RKDTSPLITTER rkdtsplitter( K );
3480  rkdtsplitter.Kptr = &K;
3481  rkdtsplitter.metric = ANGLE_DISTANCE;
3482  size_t n = K.row();
3483 
3485  Configuration<T> config( ANGLE_DISTANCE, n, m, k, s, stol, budget );
3486 
3488  return Compress<SPLITTER, RKDTSPLITTER>
3489  ( K, NN, //ANGLE_DISTANCE,
3490  splitter, rkdtsplitter, //n, m, k, s, stol, budget,
3491  config );
3492 };
3504 template<typename T, typename SPDMATRIX>
3505 tree::Tree<
3506  gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
3508 *Compress( SPDMATRIX &K, T stol, T budget )
3509 {
3510  using SPLITTER = centersplit<SPDMATRIX, 2, T>;
3511  using RKDTSPLITTER = randomsplit<SPDMATRIX, 2, T>;
3514  SPLITTER splitter( K );
3515  splitter.Kptr = &K;
3516  splitter.metric = ANGLE_DISTANCE;
3518  RKDTSPLITTER rkdtsplitter( K );
3519  rkdtsplitter.Kptr = &K;
3520  rkdtsplitter.metric = ANGLE_DISTANCE;
3521  size_t n = K.row();
3522  size_t m = 128;
3523  size_t k = 16;
3524  size_t s = m;
3525 
3527  if ( n >= 16384 )
3528  {
3529  m = 128;
3530  k = 20;
3531  s = 256;
3532  }
3533 
3534  if ( n >= 32768 )
3535  {
3536  m = 256;
3537  k = 24;
3538  s = 384;
3539  }
3540 
3541  if ( n >= 65536 )
3542  {
3543  m = 512;
3544  k = 32;
3545  s = 512;
3546  }
3547 
3549  Configuration<T> config( ANGLE_DISTANCE, n, m, k, s, stol, budget );
3550 
3552  return Compress<SPLITTER, RKDTSPLITTER>
3553  ( K, NN, //ANGLE_DISTANCE,
3554  splitter, rkdtsplitter, config );
3555 
3556 };
3561 template<typename T>
3562 tree::Tree<
3565 *Compress( SPDMatrix<T> &K, T stol, T budget )
3566 {
3567  return Compress<T, SPDMatrix<T>>( K, stol, budget );
3568 };
3576 template<typename NODE, typename T>
3577 void ComputeError( NODE *node, Data<T> potentials )
3578 {
3579  auto &K = *node->setup->K;
3580  auto &w = node->setup->w;
3581 
3582  auto &amap = node->gids;
3583  std::vector<size_t> bmap = std::vector<size_t>( K.col() );
3584 
3585  for ( size_t j = 0; j < bmap.size(); j ++ ) bmap[ j ] = j;
3586 
3587  auto Kab = K( amap, bmap );
3588 
3589  auto nrm2 = hmlp_norm( potentials.row(), potentials.col(),
3590  potentials.data(), potentials.row() );
3591 
3592  xgemm
3593  (
3594  "N", "T",
3595  Kab.row(), w.row(), w.col(),
3596  -1.0, Kab.data(), Kab.row(),
3597  w.data(), w.row(),
3598  1.0, potentials.data(), potentials.row()
3599  );
3600 
3601  auto err = hmlp_norm( potentials.row(), potentials.col(),
3602  potentials.data(), potentials.row() );
3603 
3604  printf( "node relative error %E, nrm2 %E\n", err / nrm2, nrm2 );
3605 
3606 
3607 }; // end void ComputeError()
3608 
3609 
3610 
3611 
3612 
3613 
3614 
3615 
3619 template<typename TREE, typename T>
3620 T ComputeError( TREE &tree, size_t gid, Data<T> potentials )
3621 {
3622  auto &K = *tree.setup.K;
3623  auto &w = *tree.setup.w;
3624 
3625  auto amap = std::vector<size_t>( 1, gid );
3626  auto bmap = std::vector<size_t>( K.col() );
3627  for ( size_t j = 0; j < bmap.size(); j ++ ) bmap[ j ] = j;
3628 
3629  auto Kab = K( amap, bmap );
3630  auto exact = potentials;
3631 
3632  xgemm
3633  (
3634  "N", "N",
3635  Kab.row(), w.col(), w.row(),
3636  1.0, Kab.data(), Kab.row(),
3637  w.data(), w.row(),
3638  0.0, exact.data(), exact.row()
3639  );
3640 
3641 
3642  auto nrm2 = hmlp_norm( exact.row(), exact.col(),
3643  exact.data(), exact.row() );
3644 
3645  xgemm
3646  (
3647  "N", "N",
3648  Kab.row(), w.col(), w.row(),
3649  -1.0, Kab.data(), Kab.row(),
3650  w.data(), w.row(),
3651  1.0, potentials.data(), potentials.row()
3652  );
3653 
3654  auto err = hmlp_norm( potentials.row(), potentials.col(),
3655  potentials.data(), potentials.row() );
3656 
3657  return err / nrm2;
3658 };
3662 template<typename TREE>
3663 void SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
3664 {
3666  using T = typename TREE::T;
3668  size_t n = tree.n;
3670  if ( ntest > n ) ntest = n;
3672  vector<size_t> all_rhs( nrhs );
3673  for ( size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;
3674 
3675  //auto A = tree.CheckAllInteractions();
3676 
3678  Data<T> w( n, nrhs ); w.rand();
3679  auto u = Evaluate<true, false, true, true>( tree, w );
3680 
3682  T nnerr_avg = 0.0;
3683  T nonnerr_avg = 0.0;
3684  T fmmerr_avg = 0.0;
3685  printf( "========================================================\n");
3686  printf( "Accuracy report\n" );
3687  printf( "========================================================\n");
3688  for ( size_t i = 0; i < ntest; i ++ )
3689  {
3690  size_t tar = i * n / ntest;
3691  Data<T> potentials;
3693  Evaluate<false, true>( tree, tar, potentials );
3694  auto nnerr = ComputeError( tree, tar, potentials );
3696  Evaluate<false, false>( tree, tar, potentials );
3697  auto nonnerr = ComputeError( tree, tar, potentials );
3699  //potentials = u( vector<size_t>( i ), all_rhs );
3700  for ( size_t p = 0; p < potentials.col(); p ++ )
3701  {
3702  potentials[ p ] = u( tar, p );
3703  }
3704  auto fmmerr = ComputeError( tree, tar, potentials );
3705 
3707  if ( i < 10 )
3708  {
3709  printf( "gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
3710  tar, nnerr, nonnerr, fmmerr );
3711  }
3712  nnerr_avg += nnerr;
3713  nonnerr_avg += nonnerr;
3714  fmmerr_avg += fmmerr;
3715  }
3716  printf( "========================================================\n");
3717  printf( " ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
3718  nnerr_avg / ntest , nonnerr_avg / ntest, fmmerr_avg / ntest );
3719  printf( "========================================================\n");
3720 
3721  if ( !tree.setup.SecureAccuracy() )
3722  {
3724  T lambda = 5.0;
3725  gofmm::Factorize( tree, lambda );
3727  gofmm::ComputeError( tree, lambda, w, u );
3728  }
3729 
3730 };
3734 template<typename SPDMATRIX>
3735 void LaunchHelper( SPDMATRIX &K, CommandLineHelper &cmd )
3736 {
3737  using T = typename SPDMATRIX::T;
3738 
3739  const int N_CHILDREN = 2;
3742  using RKDTSPLITTER = gofmm::randomsplit<SPDMATRIX, N_CHILDREN, T>;
3744  SPLITTER splitter( K );
3745  splitter.Kptr = &K;
3746  splitter.metric = cmd.metric;
3748  RKDTSPLITTER rkdtsplitter( K );
3749  rkdtsplitter.Kptr = &K;
3750  rkdtsplitter.metric = cmd.metric;
3752  gofmm::Configuration<T> config( cmd.metric,
3753  cmd.n, cmd.m, cmd.k, cmd.s, cmd.stol, cmd.budget );
3757  //auto *tree_ptr = gofmm::Compress( X, K, NN, splitter, rkdtsplitter, config );
3758  auto *tree_ptr = gofmm::Compress( K, NN, splitter, rkdtsplitter, config );
3759  auto &tree = *tree_ptr;
3761  gofmm::SelfTesting( tree, 100, cmd.nrhs );
3762 
3763 
3764 // //#ifdef DUMP_ANALYSIS_DATA
3765 // gofmm::Summary<NODE> summary;
3766 // tree.Summary( summary );
3767 // summary.Print();
3768 
3770  delete tree_ptr;
3771 };
3778 template<typename T, typename SPDMATRIX>
3780 {
3781  public:
3782 
3783  SimpleGOFMM( SPDMATRIX &K, T stol, T budget )
3784  {
3785  tree_ptr = Compress( K, stol, budget );
3786  };
3787 
3788  ~SimpleGOFMM()
3789  {
3790  if ( tree_ptr ) delete tree_ptr;
3791  };
3792 
3793  void Multiply( Data<T> &y, Data<T> &x )
3794  {
3795  //hmlp::Data<T> weights( x.col(), x.row() );
3796 
3797  //for ( size_t j = 0; j < x.col(); j ++ )
3798  // for ( size_t i = 0; i < x.row(); i ++ )
3799  // weights( j, i ) = x( i, j );
3800 
3801 
3802  y = gofmm::Evaluate( *tree_ptr, x );
3803  //auto potentials = hmlp::gofmm::Evaluate( *tree_ptr, weights );
3804 
3805  //for ( size_t j = 0; j < y.col(); j ++ )
3806  // for ( size_t i = 0; i < y.row(); i ++ )
3807  // y( i, j ) = potentials( j, i );
3808 
3809  };
3810 
3811  private:
3812 
3814  tree::Tree<
3815  gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
3816  gofmm::NodeData<T>> *tree_ptr = NULL;
3817 
3818 };
3838 
3840  centersplit<SPDMatrix<double>, 2, double>, double> dSetup_t;
3841 
3843  centersplit<SPDMatrix<float >, 2, float>, float> sSetup_t;
3844 
3847 
3848 
3849 
3850 
3851 
3857 Data<double> Evaluate( dTree_t *tree, Data<double> *weights );
3858 Data<float> Evaluate( dTree_t *tree, Data<float > *weights );
3859 
3860 dTree_t *Compress( dSPDMatrix_t *K, double stol, double budget );
3861 sTree_t *Compress( sSPDMatrix_t *K, float stol, float budget );
3862 
3863 dTree_t *Compress( dSPDMatrix_t *K, double stol, double budget,
3864  size_t m, size_t k, size_t s );
3865 sTree_t *Compress( sSPDMatrix_t *K, float stol, float budget,
3866  size_t m, size_t k, size_t s );
3867 
3868 
3869 double ComputeError( dTree_t *tree, size_t gid, hmlp::Data<double> *potentials );
3870 float ComputeError( sTree_t *tree, size_t gid, hmlp::Data<float> *potentials );
3871 
3872 
3873 
3874 
3875 
3876 
3877 
3878 
3879 
3880 
3881 };
3882 };
3884 #endif
deque< Statistic > s2s_kij_t
Definition: gofmm.hpp:508
Definition: tree.hpp:955
Configuration contains all user-defined parameters.
Definition: gofmm.hpp:212
void Set(NODE *user_arg)
Definition: gofmm.hpp:1271
Data< T > w_leaf
Definition: gofmm.hpp:382
deque< Statistic > s2n_kij_t
Definition: gofmm.hpp:513
Definition: gofmm_gpu.hpp:497
Event skeletonize
Definition: gofmm.hpp:396
NodeData()
Definition: gofmm.hpp:350
void FromConfiguration(Configuration< T > &config, SPDMATRIX &K, SPLITTER &splitter, Data< pair< T, size_t >> *NN)
Definition: gofmm.hpp:309
void Set(NODE *user_arg)
Definition: gofmm.hpp:1476
void GetEventRecord()
Definition: gofmm.hpp:2389
Definition: gofmm.hpp:3779
vector< size_t > candidate_rows
Definition: gofmm.hpp:371
Definition: gofmm.hpp:1265
Definition: igofmm.hpp:70
Data< T > w_skel
Definition: gofmm.hpp:378
void Set(NODE *user_arg)
Definition: gofmm.hpp:747
Definition: util.hpp:572
This task creates an hierarchical tree view for w<RIDS> and u<RIDS>.
Definition: gofmm.hpp:414
void Set(NODE *user_arg)
Definition: gofmm.hpp:2111
vector< size_t > skels
Definition: gofmm.hpp:359
size_t n
Definition: tree.hpp:971
Definition: gofmm.hpp:1867
size_t row()
Definition: View.hpp:345
map< size_t, T > snids
Definition: gofmm.hpp:368
void GetEventRecord()
Definition: gofmm.hpp:1282
CommandLineHelper(int argc, char *argv[])
Definition: gofmm.hpp:94
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void DependencyAnalysis()
Definition: gofmm.hpp:2178
Data and setup that are shared with all nodes.
Definition: tree.hpp:876
void Set(NODE *user_arg)
Definition: gofmm.hpp:1174
Definition: gofmm.hpp:1168
size_t col()
Definition: View.hpp:348
deque< Statistic > updateweight
Definition: gofmm.hpp:505
View< T > w_view
Definition: gofmm.hpp:386
T * data()
Definition: View.hpp:354
class Device * GetDevice()
Definition: thread.cpp:554
Data< T > KIJ
Definition: gofmm.hpp:375
void DependencyAnalysis()
Definition: gofmm.hpp:429
void Set(NODE *user_arg)
Definition: gofmm.hpp:1693
This class does not need to inherit hmlp::Data<T>, but it should support two interfaces for data fetc...
Definition: SPDMatrix.hpp:21
This the splitter used in the randomized tree.
Definition: gofmm.hpp:658
Data< size_t > Nearbmap
Definition: gofmm.hpp:390
void Set(NODE *user_arg)
Definition: gofmm.hpp:2380
void Set(NODE *user_arg)
Definition: gofmm.hpp:2326
size_t col() const noexcept
Definition: Data.hpp:281
void GetEventRecord()
Definition: gofmm.hpp:2172
Wrapper for omp or pthread mutex.
Definition: tci.hpp:50
Provide statistics summary for the execution section.
Definition: gofmm.hpp:493
void HeapSelect(size_t n, size_t k, std::pair< T, size_t > *Query, std::pair< T, size_t > *NN)
Definition: util.hpp:520
void MergeNeighbors(size_t k, pair< T, size_t > *A, pair< T, size_t > *B, vector< pair< T, size_t >> &aux)
Definition: tree.hpp:212
Task wrapper for CacheNearNodes().
Definition: gofmm.hpp:2374
size_t row() const noexcept
Definition: Data.hpp:278
These are data that shared by the whole local tree.
Definition: gofmm.hpp:301
Data< T > proj
Definition: gofmm.hpp:365
void Set(bool TRANS, Data< T > &buff)
Definition: View.hpp:60
void Set(NODE *user_arg)
Definition: gofmm.hpp:1873
Definition: Data.hpp:134
Definition: gofmm.hpp:2101
Definition: gofmm.hpp:738
Definition: View.hpp:43
size_t ld()
Definition: View.hpp:351
There is no dependency between each task. However there are raw (read after write) dependencies: ...
Definition: gofmm.hpp:1687
The correponding task of Interpolate().
Definition: gofmm.hpp:945
DistanceMetric metric
Definition: gofmm.hpp:193
This is a helper class that parses the arguments from command lines.
Definition: gofmm.hpp:89
This class describes devices or accelerators that require a master thread to control. A device can accept tasks from multiple workers. All received tasks are expected to be executed independently in a time-sharing fashion. Whether these tasks are executed in parallel, sequential or with some built-in context switching scheme does not matter.
Definition: device.hpp:125
void Execute(Worker *user_worker)
Definition: gofmm.hpp:431
size_t n
Definition: gofmm.hpp:185
vector< int > jpvt
Definition: gofmm.hpp:362
void xtrsm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRSM wrapper.
Definition: blas_lapack.cpp:315
Definition: gofmm.hpp:1470
double stol
Definition: gofmm.hpp:190
size_t d
Definition: gofmm.hpp:196
This the main splitter used to build the Spd-Askit tree. First compute the approximate center using s...
Definition: gofmm.hpp:580
Wrapper for omp or pthread mutex.
Definition: runtime.hpp:113
Definition: gofmm.hpp:83
This class contains all GOFMM related data carried by a tree node.
Definition: gofmm.hpp:345
Definition: runtime.hpp:174
size_t col()
Definition: VirtualMatrix.hpp:85
Definition: gofmm.hpp:2320
Definition: thread.hpp:166