31 #include <unordered_set> 49 #include <hmlp_base.hpp> 51 #include <primitives/lowrank.hpp> 52 #include <primitives/combinatorics.hpp> 53 #include <primitives/gemm.hpp> 55 #include <containers/VirtualMatrix.hpp> 56 #include <containers/SPDMatrix.hpp> 62 #include <cuda_runtime.h> 63 #include <gofmm_gpu.hpp> 78 #define REPORT_ANN_ACCURACY 1 79 #define REPORT_COMPRESS_STATUS 1 80 #define REPORT_EVALUATE_STATUS 1 97 sscanf( argv[ 1 ],
"%lu", &n );
99 sscanf( argv[ 2 ],
"%lu", &m );
101 sscanf( argv[ 3 ],
"%lu", &k );
103 sscanf( argv[ 4 ],
"%lu", &s );
105 sscanf( argv[ 5 ],
"%lu", &nrhs );
107 sscanf( argv[ 6 ],
"%lf", &stol );
109 sscanf( argv[ 7 ],
"%lf", &budget );
111 distance_type = argv[ 8 ];
112 if ( !distance_type.compare(
"geometry" ) )
114 metric = GEOMETRY_DISTANCE;
116 else if ( !distance_type.compare(
"kernel" ) )
118 metric = KERNEL_DISTANCE;
120 else if ( !distance_type.compare(
"angle" ) )
122 metric = ANGLE_DISTANCE;
126 printf(
"%s is not supported\n", argv[ 8 ] );
130 spdmatrix_type = argv[ 9 ];
131 if ( !spdmatrix_type.compare(
"testsuit" ) )
135 else if ( !spdmatrix_type.compare(
"userdefine" ) )
139 else if ( !spdmatrix_type.compare(
"pvfmm" ) )
143 else if ( !spdmatrix_type.compare(
"dense" ) || !spdmatrix_type.compare(
"ooc" ) )
146 user_matrix_filename = argv[ 10 ];
150 user_points_filename = argv[ 11 ];
152 sscanf( argv[ 12 ],
"%lu", &d );
155 else if ( !spdmatrix_type.compare(
"mlp" ) )
157 hidden_layers = argv[ 10 ];
158 user_points_filename = argv[ 11 ];
160 sscanf( argv[ 12 ],
"%lu", &d );
162 else if ( !spdmatrix_type.compare(
"cov" ) )
164 kernelmatrix_type = argv[ 10 ];
165 user_points_filename = argv[ 11 ];
167 sscanf( argv[ 12 ],
"%lu", &d );
169 sscanf( argv[ 13 ],
"%lu", &nb );
171 else if ( !spdmatrix_type.compare(
"kernel" ) )
173 kernelmatrix_type = argv[ 10 ];
174 user_points_filename = argv[ 11 ];
176 sscanf( argv[ 12 ],
"%lu", &d );
178 if ( argc > 13 ) sscanf( argv[ 13 ],
"%lf", &h );
182 printf(
"%s is not supported\n", argv[ 9 ] );
188 size_t n, m, k, s, nrhs;
193 DistanceMetric metric = ANGLE_DISTANCE;
199 string distance_type;
200 string spdmatrix_type;
201 string kernelmatrix_type;
202 string hidden_layers;
203 string user_matrix_filename;
204 string user_points_filename;
219 size_t problem_size,
size_t leaf_node_size,
220 size_t neighbor_size,
size_t maximum_rank,
221 T tolerance, T budget )
223 Set( metric_type, problem_size, leaf_node_size,
224 neighbor_size, maximum_rank, tolerance, budget );
227 void Set( DistanceMetric metric_type,
228 size_t problem_size,
size_t leaf_node_size,
229 size_t neighbor_size,
size_t maximum_rank,
230 T tolerance, T budget )
232 this->metric_type = metric_type;
233 this->problem_size = problem_size;
234 this->leaf_node_size = leaf_node_size;
235 this->neighbor_size = neighbor_size;
236 this->maximum_rank = maximum_rank;
237 this->tolerance = tolerance;
238 this->budget = budget;
243 DistanceMetric MetricType() {
return metric_type; };
245 size_t ProblemSize() {
return problem_size; };
247 size_t LeafNodeSize() {
return leaf_node_size; };
249 size_t NeighborSize() {
return neighbor_size; };
251 size_t MaximumRank() {
return maximum_rank; };
253 T Tolerance() {
return tolerance; };
255 T Budget() {
return budget; };
257 bool IsSymmetric() {
return is_symmetric; };
259 bool UseAdaptiveRanks() {
return use_adaptive_ranks; };
261 bool SecureAccuracy() {
return secure_accuracy; };
266 DistanceMetric metric_type = ANGLE_DISTANCE;
269 size_t problem_size = 0;
272 size_t leaf_node_size = 64;
275 size_t neighbor_size = 32;
278 size_t maximum_rank = 64;
287 bool is_symmetric =
true;
290 bool use_adaptive_ranks =
true;
293 bool secure_accuracy =
false;
300 template<
typename SPDMATRIX,
typename SPLITTER,
typename T>
310 SPDMATRIX &K, SPLITTER &splitter,
Data<pair<T, size_t>> *NN )
312 this->CopyFrom( config );
314 this->splitter = splitter;
333 bool do_ulv_factorization =
true;
372 vector<size_t> candidate_cols;
404 double knn_acc = 0.0;
413 template<
typename NODE>
420 void Set( NODE *user_arg )
423 name = string(
"TreeView" );
424 label = to_string( arg->treelist_id );
435 auto &data = node->data;
436 auto *setup = node->setup;
439 auto &w = *(setup->u);
440 auto &u = *(setup->w);
443 auto &U = data.u_view;
444 auto &W = data.w_view;
457 auto &UL = node->lchild->data.u_view;
458 auto &UR = node->rchild->data.u_view;
459 auto &WL = node->lchild->data.w_view;
460 auto &WR = node->rchild->data.w_view;
466 UR, node->lchild->n, TOP );
468 WR, node->lchild->n, TOP );
492 template<
typename NODE>
500 deque<Statistic> rank;
502 deque<Statistic> skeletonize;
509 deque<Statistic> s2s_t;
510 deque<Statistic> s2s_gfp;
514 deque<Statistic> s2n_t;
515 deque<Statistic> s2n_gfp;
518 void operator() ( NODE *node )
520 if ( rank.size() <= node->l )
527 rank[ node->l ].Update( (
double)node->data.skels.size() );
528 skeletonize[ node->l ].Update( node->data.skeletonize.GetDuration() );
529 updateweight[ node->l ].Update( node->data.updateweight.GetDuration() );
531 #ifdef DUMP_ANALYSIS_DATA 534 auto *parent = node->parent;
536 printf(
"#%lu (s%lu), #%lu (s%lu), %lu, %lu\n",
537 node->treelist_id, node->data.skels.size(),
538 parent->treelist_id, parent->data.skels.size(),
539 node->data.skels.size(), node->l );
544 printf(
"#%lu (s%lu), , %lu, %lu\n",
545 node->treelist_id, node->data.skels.size(),
546 node->data.skels.size(), node->l );
553 for (
size_t l = 1; l < rank.size(); l ++ )
555 printf(
"@SUMMARY\n" );
556 printf(
"level %2lu, ", l ); rank[ l ].Print();
579 template<
typename SPDMATRIX,
int N_SPLIT,
typename T>
583 SPDMATRIX *Kptr = NULL;
585 DistanceMetric metric = ANGLE_DISTANCE;
587 size_t n_centroid_samples = 5;
594 vector<vector<size_t>> operator() ( vector<size_t>& gids )
const 597 assert( N_SPLIT == 2 );
601 SPDMATRIX &K = *Kptr;
602 vector<vector<size_t>> split( N_SPLIT );
603 size_t n = gids.size();
604 vector<T> temp( n, 0.0 );
607 auto column_samples = combinatorics::SampleWithoutReplacement(
608 n_centroid_samples, gids );
612 auto DIC = K.Distances( this->metric, gids, column_samples );
615 for (
auto & it : temp ) it = 0;
618 for (
size_t j = 0; j < DIC.col(); j ++ )
619 for (
size_t i = 0; i < DIC.row(); i ++ )
620 temp[ i ] += DIC( i, j );
623 auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );
626 vector<size_t> P( 1, gids[ idf2c ] );
629 auto DIP = K.Distances( this->metric, gids, P );
632 auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );
635 vector<size_t> Q( 1, gids[ idf2f ] );
638 auto DIQ = K.Distances( this->metric, gids, P );
640 for (
size_t i = 0; i < temp.size(); i ++ )
641 temp[ i ] = DIP[ i ] - DIQ[ i ];
643 return combinatorics::MedianSplit( temp );
657 template<
typename SPDMATRIX,
int N_SPLIT,
typename T>
661 SPDMATRIX *Kptr = NULL;
664 DistanceMetric metric = ANGLE_DISTANCE;
671 inline vector<vector<size_t> > operator() ( vector<size_t>& gids )
const 673 assert( Kptr && ( N_SPLIT == 2 ) );
675 SPDMATRIX &K = *Kptr;
676 size_t n = gids.size();
677 vector<vector<size_t> > split( N_SPLIT );
678 vector<T> temp( n, 0.0 );
681 size_t idf2c = std::rand() % n;
682 size_t idf2f = std::rand() % n;
683 while ( idf2c == idf2f ) idf2f = std::rand() % n;
686 vector<size_t> P( 1, gids[ idf2c ] );
687 vector<size_t> Q( 1, gids[ idf2f ] );
691 auto DIP = K.Distances( this->metric, gids, P );
692 auto DIQ = K.Distances( this->metric, gids, Q );
694 for (
size_t i = 0; i < temp.size(); i ++ )
695 temp[ i ] = DIP[ i ] - DIQ[ i ];
697 return combinatorics::MedianSplit( temp );
/**
 *  Neighbor search restricted to the indices owned by `node`.
 *  K.NeighborSearch is queried with node->gids as both index arguments
 *  and NN.row() as the neighbor count; the result is kept in `candidates`.
 *  NOTE(review): this listing is lossy -- the head of the call that
 *  consumes `candidates.columndata( j ), aux` inside the loop is not
 *  visible here, so what is done with each candidate column cannot be
 *  stated from this text.
 */
705 template<
typename NODE>
706 void FindNeighbors( NODE *node, DistanceMetric metric )
709 using T =
typename NODE::T;
710 auto & setup = *(node->setup);
711 auto & K = *(setup.K);
712 auto & NN = *(setup.NN);
713 auto & I = node->gids;
/** Row count of NN; passed to the search below. */
715 size_t kappa = NN.row();
/** Initial pair: largest representable T paired with index NN.col(). */
717 pair<T, size_t> init( numeric_limits<T>::max(), NN.col() );
719 auto candidates = K.NeighborSearch( metric, kappa, I, I, init );
/** Scratch buffer holding 2 * kappa pairs. */
723 vector<pair<T, size_t> > aux( 2 * kappa );
/** One iteration per index owned by the node. */
725 for (
size_t j = 0; j < I.size(); j ++ )
728 candidates.columndata( j ), aux );
737 template<
class NODE,
typename T>
745 DistanceMetric metric = ANGLE_DISTANCE;
747 void Set( NODE *user_arg )
750 name = string(
"Neighbors" );
751 label = to_string( arg->treelist_id );
753 metric = arg->setup->MetricType();
757 auto &gids = arg->gids;
758 auto &NN = *arg->setup->NN;
760 flops *= ( 4.0 * gids.size() );
762 mops = (size_t)std::log( NN.row() ) * gids.size();
766 event.Set( name + label, flops, mops );
773 void DependencyAnalysis() { arg->DependOnNoOne(
this ); };
775 void Execute(
Worker* user_worker ) { FindNeighbors( arg, metric ); };
782 template<
bool DOAPPROXIMATE,
bool SORTED,
typename T,
typename CSCMATRIX>
785 pair<T, size_t> initNN( numeric_limits<T>::max(), n );
788 printf(
"SparsePattern k %lu n %lu, NN.row %lu NN.col %lu ...",
789 k, n, NN.
row(), NN.
col() ); fflush( stdout );
791 #pragma omp parallel for schedule( dynamic ) 792 for (
size_t j = 0; j < n; j ++ )
794 std::set<size_t> NNset;
795 size_t nnz = K.ColPtr( j + 1 ) - K.ColPtr( j );
796 if ( DOAPPROXIMATE && nnz > 2 * k ) nnz = 2 * k;
800 for (
size_t i = 0; i < nnz; i ++ )
803 auto row_ind = K.RowInd( K.ColPtr( j ) + i );
804 auto val = K.Value( K.ColPtr( j ) + i );
806 if ( val ) val = 1.0 / std::abs( val );
807 else val = std::numeric_limits<T>::max() - 1.0;
809 NNset.insert( row_ind );
810 std::pair<T, std::size_t> query( val, row_ind );
813 NN[ j * k + i ] = query;
823 std::size_t row_ind = rand() % n;
824 if ( !NNset.count( row_ind ) )
826 T val = std::numeric_limits<T>::max() - 1.0;
827 std::pair<T, std::size_t> query( val, row_ind );
828 NNset.insert( row_ind );
829 NN[ j * k + nnz ] = query;
834 printf(
"Done.\n" ); fflush( stdout );
838 printf(
"Sorting ... " ); fflush( stdout );
841 bool operator () ( std::pair<T, size_t> a, std::pair<T, size_t> b )
843 return a.first < b.first;
849 #pragma omp parallel for 850 for (
size_t j = 0; j < NN.
col(); j ++ )
852 std::sort( NN.data() + j * NN.
row(), NN.data() + ( j + 1 ) * NN.
row(), ANNLess );
854 printf(
"Done.\n" ); fflush( stdout );
/**
 *  @brief Swap the two members of a pair: (first, second) -> (second, first).
 *
 *  @param p the pair to mirror.
 *  @return a new pair whose first is p.second and whose second is p.first.
 */
template<typename TA, typename TB>
std::pair<TB, TA> flip_pair( const std::pair<TA, TB> &p )
{
  return { p.second, p.first };
}
/**
 *  @brief Build a multimap keyed by the VALUES of `src`, mapping back to
 *         the original keys.
 *
 *  Entries are visited in the key order of `src` and inserted through
 *  std::inserter seeded with dst.begin(). Several keys of `src` may share
 *  one value, hence the multimap.
 *
 *  @param src the map to invert.
 *  @return multimap containing ( value, key ) for every ( key, value )
 *          of `src`.
 */
template<typename TA, typename TB>
std::multimap<TB, TA> flip_map( const std::map<TA, TB> &src )
{
  std::multimap<TB, TA> dst;
  std::transform( src.begin(), src.end(), std::inserter( dst, dst.begin() ),
      []( const std::pair<const TA, TB> &entry )
      {
        /** Mirror one entry: ( key, value ) -> ( value, key ). */
        return std::pair<TB, TA>( entry.second, entry.first );
      } );
  return dst;
}
/**
 *  Post-processes node->data.proj in place using a triangular solve and
 *  the pivot vector node->data.jpvt.
 *  NOTE(review): this listing is lossy -- the definitions of `s`, `n`,
 *  `R1` and `tmp` are not visible here; comments below only describe the
 *  statements that are.
 */
882 template<
typename NODE>
883 void Interpolate( NODE *node )
886 using T =
typename NODE::T;
890 auto &K = *node->setup->K;
891 auto &data = node->data;
892 auto &skels = data.skels;
893 auto &proj = data.proj;
894 auto &jpvt = data.jpvt;
/** Nothing to do when the node is not skeletonized or proj[ 0 ] is zero. */
899 if ( !data.isskel || proj[ 0 ] == 0 )
return;
903 assert( jpvt.size() == n );
/** Reserve skeleton weight/potential buffers for up to MAX_NRHS columns. */
908 data.w_skel.reserve( skels.size(), MAX_NRHS );
909 data.u_skel.reserve( skels.size(), MAX_NRHS );
/** Copy entries with i <= j of the leading s-by-s part of proj into R1. */
916 for (
int j = 0; j < s; j ++ )
918 for (
int i = 0; i < s; i ++ )
920 if ( i <= j ) R1[ j * s + i ] = proj[ j * s + i ];
/** Presumably a BLAS-style trsm wrapper (left, upper, no-trans, non-unit) -- TODO confirm. */
928 xtrsm(
"L",
"U",
"N",
"N", s, n, 1.0, R1.data(), s, tmp.data(), s );
/** Column j of tmp is written to column jpvt[ j ] of proj. */
931 for (
int j = 0; j < n; j ++ )
933 for (
int i = 0; i < s; i ++ )
935 proj[ jpvt[ j ] * s + i ] = tmp[ j * s + i ];
944 template<
typename NODE>
951 void Set( NODE *user_arg )
954 name = string(
"it" );
955 label = to_string( arg->treelist_id );
960 void DependencyAnalysis() { arg->DependOnNoOne(
this ); };
962 void Execute(
Worker* user_worker ) { Interpolate( arg ); };
/**
 *  Fills node->data.candidate_rows (aliased as `amap`) with up to
 *  `nsamples` row indices, drawing first from neighbor information kept
 *  in node->data.snids and then from K.ImportantSample / a linear scan.
 *  NOTE(review): this listing is lossy -- several branch heads and
 *  statements (for example the declaration of `is_near`) are not visible,
 *  so the exact control flow between the visible statements cannot be
 *  read off this text.
 */
972 template<
bool NNPRUNE,
typename NODE>
973 void RowSamples( NODE *node,
size_t nsamples )
976 using T =
typename NODE::T;
977 auto &setup = *(node->setup);
978 auto &data = node->data;
979 auto &K = *(setup.K);
982 auto &amap = data.candidate_rows;
992 auto &NN = *(setup.NN);
993 auto &gids = node->gids;
994 auto &snids = data.snids;
995 size_t knum = NN.row();
/** Gather all NN( i, gid ) pairs of the owned gids and sort them. */
1001 vector<pair<T, size_t>> tmp( knum * gids.size() );
1002 for (
size_t j = 0; j < gids.size(); j ++ )
1003 for (
size_t i = 0; i < knum; i ++ )
1004 tmp[ j * knum + i ] = NN( i, gids[ j ] );
1007 sort( tmp.begin(), tmp.end() );
1010 for (
auto it : tmp )
1012 size_t it_gid = it.second;
1013 size_t it_morton = setup.morton[ it_gid ];
/** Stop once snids holds nsamples entries. */
1015 if ( snids.size() >= nsamples )
break;
/** Nearness test: membership in the near-ID set (NNPRUNE) or same Morton ID. */
1019 if ( NNPRUNE ) is_near = node->NNNearNodeMortonIDs.count( it_morton );
1020 else is_near = (it_morton == node->morton );
/** snids is keyed by the pair's second member and stores its first member. */
1025 auto ret = snids.insert( make_pair( it.second, it.first ) );
1031 auto &lsnids = node->lchild->data.snids;
1032 auto &rsnids = node->rchild->data.snids;
1042 for (
auto it = rsnids.begin(); it != rsnids.end(); it ++ )
1044 auto ret = snids.insert( *it );
/**
 *  NOTE(review): this compares and overwrites the stored VALUE
 *  ( ret.first->second ) with the KEY of the incoming entry
 *  ( (*it).first ). Given the key/value layout established by the
 *  insert above, (*it).second looks like the intended operand -- confirm.
 */
1047 if ( ret.first->second > (*it).first )
1048 ret.first->second = (*it).first;
/** Indices owned by this node are removed from snids. */
1053 for (
auto gid : gids ) snids.erase( gid );
1057 if ( nsamples < K.col() - node->n )
/** Re-key snids by value via flip_map, then take entries in that order. */
1060 multimap<T, size_t> ordered_snids = flip_map( snids );
1062 amap.reserve( nsamples );
1065 for (
auto it : ordered_snids )
1067 if ( amap.size() >= nsamples )
break;
1069 amap.push_back( it.second );
/** Top up with K.ImportantSample( 0 ) draws that pass the IsMyParent filter. */
1073 while ( amap.size() < nsamples )
1076 auto important_sample = K.ImportantSample( 0 );
1077 size_t sample_gid = important_sample.second;
1078 size_t sample_morton = setup.morton[ sample_gid ];
1080 if ( !MortonHelper::IsMyParent( sample_morton, node->morton ) )
1082 amap.push_back( sample_gid );
/** Linear scan over all columns of K with the same IsMyParent filter. */
1088 for (
size_t sample = 0; sample < K.col(); sample ++ )
1090 size_t sample_morton = setup.morton[ sample ];
1091 if ( !MortonHelper::IsMyParent( sample_morton, node->morton ) )
1093 amap.push_back( sample );
/**
 *  Chooses candidate columns and rows for `node` and stores the
 *  corresponding submatrix K( candidate_rows, candidate_cols ) in
 *  node->data.KIJ.
 *  NOTE(review): this listing is lossy -- the condition selecting between
 *  the `node->gids` assignment and the children-skeleton path is not
 *  visible here.
 */
1108 template<
bool NNPRUNE,
typename NODE>
1109 void SkeletonKIJ( NODE *node )
1112 using T =
typename NODE::T;
1114 auto &K = *(node->setup->K);
1116 auto &data = node->data;
1117 auto &candidate_rows = data.candidate_rows;
1118 auto &candidate_cols = data.candidate_cols;
1119 auto &KIJ = data.KIJ;
1121 auto *lchild = node->lchild;
1122 auto *rchild = node->rchild;
/** Candidate columns taken directly from the node's own indices. */
1127 candidate_cols = node->gids;
1131 auto &lskels = lchild->data.skels;
1132 auto &rskels = rchild->data.skels;
/** Bail out if either child has an empty skeleton. */
1134 if ( !lskels.size() || !rskels.size() )
return;
/** Candidate columns = left skeleton followed by right skeleton. */
1136 candidate_cols = lskels;
1137 candidate_cols.insert( candidate_cols.end(),
1138 rskels.begin(), rskels.end() );
/** Row sample count: twice the column count, but at least 2 * LeafNodeSize(). */
1142 size_t nsamples = 2 * candidate_cols.size();
1145 if ( nsamples < 2 * node->setup->LeafNodeSize() )
1146 nsamples = 2 * node->setup->LeafNodeSize();
1149 RowSamples<NNPRUNE>( node, nsamples );
1152 KIJ = K( candidate_rows, candidate_cols );
1167 template<
bool NNPRUNE,
typename NODE,
typename T>
1177 name = string(
"par-gskm" );
1178 label = to_string( arg->treelist_id );
1185 void DependencyAnalysis() { arg->DependOnChildren(
this ); };
1187 void Execute(
Worker* user_worker ) { SkeletonKIJ<NNPRUNE>( arg ); };
/**
 *  Runs lowrank::id on node->data.KIJ and stores the outcome in
 *  node->data.skels / proj / jpvt; the root (no parent) is skipped.
 *  NOTE(review): this listing is lossy -- the definitions of `q` and `N`
 *  and some branch structure are not visible here.
 */
1193 template<
typename NODE>
1194 void Skeletonize( NODE *node )
1197 using T =
typename NODE::T;
/** The root node is never skeletonized. */
1199 if ( !node->parent )
return;
1202 auto &K = *(node->setup->K);
1203 auto &NN = *(node->setup->NN);
1204 auto maxs = node->setup->MaximumRank();
1205 auto stol = node->setup->Tolerance();
1206 bool secure_accuracy = node->setup->SecureAccuracy();
1207 bool use_adaptive_ranks = node->setup->UseAdaptiveRanks();
1210 auto &data = node->data;
1211 auto &skels = data.skels;
1212 auto &proj = data.proj;
1213 auto &jpvt = data.jpvt;
1214 auto &KIJ = data.KIJ;
1215 auto &candidate_cols = data.candidate_cols;
1219 size_t m = KIJ.row();
1220 size_t n = KIJ.col();
/** With secure_accuracy, an inner node with a non-skeletonized child is marked non-skeletonized. */
1223 if ( secure_accuracy )
1225 if ( !node->isleaf && ( !node->lchild->data.isskel || !node->rchild->data.isskel ) )
1228 proj.resize( 0, 0 );
1229 data.isskel =
false;
/** Tolerance rescaled from the sampled block dimensions (m, n) and q, N. */
1235 T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol;
1237 scaled_stol *= std::sqrt( (T)q / N );
1240 lowrank::id( use_adaptive_ranks, secure_accuracy,
1241 KIJ.row(), KIJ.col(), maxs, scaled_stol, KIJ, skels, proj, jpvt );
1247 if ( secure_accuracy )
/** isskel reflects whether any skeleton was selected. */
1250 data.isskel = (skels.size() != 0);
1254 assert( skels.size() && proj.size() && jpvt.size() );
/** Translate positions within candidate_cols into the indices stored there. */
1259 for (
size_t i = 0; i < skels.size(); i ++ ) skels[ i ] = candidate_cols[ skels[ i ] ];
1264 template<
typename NODE,
typename T>
1274 name = string(
"sk" );
1275 label = to_string( arg->treelist_id );
1284 double flops = 0.0, mops = 0.0;
1286 auto &K = *arg->setup->K;
1287 size_t n = arg->data.proj.col();
1289 size_t k = arg->data.proj.row();
1292 flops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
1293 mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
1296 flops += k * ( k - 1 ) * ( n + 1 );
1297 mops += 2.0 * ( k * k + k * n );
1306 event.Set( label + name, flops, mops );
1307 arg->data.skeletonize = event;
1310 void DependencyAnalysis() { arg->DependOnNoOne(
this ); };
1312 void Execute(
Worker* user_worker ) { Skeletonize( arg ); };
/**
 *  Computes node->data.w_skel ( skels.size() x nrhs ) from
 *  node->data.proj and either the node's own w_leaf / view or the two
 *  children's w_skel.
 *  NOTE(review): this listing is lossy -- the heads of the multiply calls
 *  and the leaf/inner branch conditions are not visible; only their
 *  argument lists remain.
 */
1349 template<
typename NODE>
1350 void UpdateWeights( NODE *node )
1353 using T =
typename NODE::T;
/** Skip the root and any node that was not skeletonized. */
1355 if ( !node->parent || !node->data.isskel )
return;
1358 auto &w = *node->setup->w;
1361 auto &data = node->data;
1362 auto &proj = data.proj;
1363 auto &skels = data.skels;
1364 auto &w_skel = data.w_skel;
1365 auto &w_leaf = data.w_leaf;
1366 auto *lchild = node->lchild;
1367 auto *rchild = node->rchild;
1370 size_t nrhs = w.col();
1373 w_skel.resize( skels.size(), nrhs );
/** Path taken when w_leaf is non-empty: arguments pair proj with w_leaf, beta 0.0. */
1379 if ( w_leaf.size() )
1388 w_skel.row(), w_skel.col(), w_leaf.row(),
1389 1.0, proj.data(), proj.row(),
1390 w_leaf.data(), w_leaf.row(),
1391 0.0, w_skel.data(), w_skel.row()
/** Alternative argument list sized by W.row() instead of w_leaf.row(). */
1405 w_skel.row(), w_skel.col(), W.
row(),
1406 1.0, proj.data(), proj.row(),
1408 0.0, w_skel.data(), w_skel.row()
1420 auto &w_lskel = lchild->data.w_skel;
1421 auto &w_rskel = rchild->data.w_skel;
1422 auto &lskel = lchild->data.skels;
1423 auto &rskel = rchild->data.skels;
/** NOTE(review): 6 is an unexplained treelist_id threshold selecting between the two code paths below. */
1426 if ( node->treelist_id > 6 )
/** Left child contribution: leading lskel.size() columns of proj, beta 0.0. */
1432 w_skel.row(), w_skel.col(), lskel.size(),
1433 1.0, proj.data(), proj.row(),
1434 w_lskel.data(), w_lskel.row(),
1435 0.0, w_skel.data(), w_skel.row()
/** Right child contribution: proj offset by lskel.size() columns, accumulated with beta 1.0. */
1440 w_skel.row(), w_skel.col(), rskel.size(),
1441 1.0, proj.data() + proj.row() * lskel.size(), proj.row(),
1442 w_rskel.data(), w_rskel.row(),
1443 1.0, w_skel.data(), w_skel.row()
/** View-based variant of the same two products via gemm::xgemm<GEMM_NB>. */
1451 View<T> W(
false, w_skel ), WL(
false, w_lskel ),
1452 WR(
false, w_rskel );
1454 P.Partition1x2( PL, PR, lskel.size(), LEFT );
1456 gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );
1457 W.DependencyCleanUp();
1459 gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)1.0, W );
1469 template<
typename NODE,
typename T>
1479 name = string(
"n2s" );
1480 label = to_string( arg->treelist_id );
1484 auto &gids = arg->gids;
1485 auto &skels = arg->data.skels;
1486 auto &w = *arg->setup->w;
1489 auto m = skels.size();
1491 auto k = gids.size();
1492 flops = 2.0 * m * n * k;
1493 mops = 2.0 * ( m * n + m * k + k * n );
1497 auto &lskels = arg->lchild->data.skels;
1498 auto &rskels = arg->rchild->data.skels;
1499 auto m = skels.size();
1501 auto k = lskels.size() + rskels.size();
1502 flops = 2.0 * m * n * k;
1503 mops = 2.0 * ( m * n + m * k + k * n );
1507 event.Set( label + name, flops, mops );
1509 cost = flops / 1E+9;
1514 void Prefetch(
Worker* user_worker )
1516 auto &proj = arg->data.proj;
1517 __builtin_prefetch( proj.data() );
1518 auto &w_skel = arg->data.w_skel;
1519 __builtin_prefetch( w_skel.data() );
1522 auto &w_leaf = arg->data.w_leaf;
1523 __builtin_prefetch( w_leaf.data() );
1527 auto &w_lskel = arg->lchild->data.w_skel;
1528 __builtin_prefetch( w_lskel.data() );
1529 auto &w_rskel = arg->rchild->data.w_skel;
1530 __builtin_prefetch( w_rskel.data() );
1532 #ifdef HMLP_USE_CUDA 1534 if ( user_worker ) device = user_worker->
GetDevice();
1537 proj.CacheD( device );
1538 proj.PrefetchH2D( device, 1 );
1541 auto &w_leaf = arg->data.w_leaf;
1542 w_leaf.CacheD( device );
1543 w_leaf.PrefetchH2D( device, 1 );
1547 auto &w_lskel = arg->lchild->data.w_skel;
1548 w_lskel.CacheD( device );
1549 w_lskel.PrefetchH2D( device, 1 );
1550 auto &w_rskel = arg->rchild->data.w_skel;
1551 w_rskel.CacheD( device );
1552 w_rskel.PrefetchH2D( device, 1 );
1558 void DependencyAnalysis() { arg->DependOnChildren(
this ); };
1560 void Execute(
Worker* user_worker )
1562 #ifdef HMLP_USE_CUDA 1564 if ( user_worker ) device = user_worker->
GetDevice();
1565 if ( device ) gpu::UpdateWeights( device, arg );
1566 else UpdateWeights<NODE, T>( arg );
1568 UpdateWeights( arg );
/**
 *  Accumulates into node->data.u_skel the products of kernel blocks with
 *  the w_skel of every node in node->NNFarNodes.
 *  A cached block ( node->data.FarKab ) is used when it is non-empty;
 *  otherwise K( amap, bmap ) is evaluated on the fly.
 *  NOTE(review): this listing is lossy -- the heads of some multiply
 *  calls and the declarations of `offset`, `W`, `U`, `Kab` (view) and
 *  `FarKab_v` are not visible here.
 */
1582 template<
typename NODE>
1583 void SkeletonsToSkeletons( NODE *node )
1586 using T =
typename NODE::T;
/** Skip the root and any node that was not skeletonized. */
1588 if ( !node->parent || !node->data.isskel )
return;
1590 double beg, u_skel_time, s2s_time;
/** The far list used here is always NNFarNodes. */
1592 auto *FarNodes = &node->NNFarNodes;
1594 auto &K = *node->setup->K;
1595 auto &data = node->data;
1596 auto &amap = node->data.skels;
1597 auto &u_skel = node->data.u_skel;
1598 auto &FarKab = node->data.FarKab;
1600 size_t nrhs = node->setup->w->col();
/** Reset u_skel to an amap.size() x nrhs buffer filled with 0.0, timing the step. */
1603 beg = omp_get_wtime();
1604 u_skel.resize( 0, 0 );
1605 u_skel.resize( amap.size(), nrhs, 0.0 );
1606 u_skel_time = omp_get_wtime() - beg;
1615 for (
auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
1617 auto &bmap = (*it)->data.skels;
1618 auto &w_skel = (*it)->data.w_skel;
/** The far node's skeleton weights must be bmap.size() x nrhs. */
1619 assert( w_skel.col() == nrhs );
1620 assert( w_skel.row() == bmap.size() );
1621 assert( w_skel.size() == nrhs * bmap.size() );
/** Cached path: the block for this far node starts `offset` columns into FarKab. */
1623 if ( FarKab.size() )
1628 assert( FarKab.row() == amap.size() );
1629 assert( u_skel.row() * offset <= FarKab.size() );
1637 u_skel.row(), u_skel.col(), w_skel.row(),
1638 1.0, FarKab.data() + u_skel.row() * offset, FarKab.row(),
1639 w_skel.data(), w_skel.row(),
1640 1.0, u_skel.data(), u_skel.row()
1650 assert( FarKab.col() >= W.
row() + offset );
1651 Kab.
Set( FarKab.row(), W.
row(), 0, offset, &FarKab_v );
1652 gemm::xgemm<GEMM_NB>( (T)1.0, Kab, W, (T)1.0, U );
/** Advance to the next far node's column range. */
1656 offset += w_skel.row();
/** Fallback path: report the cache miss and evaluate the block directly. */
1660 printf(
"Far Kab not cached treelist_id %lu, l %lu\n\n",
1661 node->treelist_id, node->l ); fflush( stdout );
1664 auto Kab = K( amap, bmap );
1665 xgemm(
"N",
"N", u_skel.row(), u_skel.col(), w_skel.row(),
1666 1.0, Kab.data(), Kab.row(),
1667 w_skel.data(), w_skel.row(),
1668 1.0, u_skel.data(), u_skel.row() );
1686 template<
bool NNPRUNE,
typename NODE,
typename T>
1696 name = string(
"s2s" );
1700 ss << arg->treelist_id;
1705 double flops = 0.0, mops = 0.0;
1706 auto &w = *arg->setup->w;
1707 size_t m = arg->data.skels.size();
1710 std::set<NODE*> *FarNodes;
1711 if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
1712 else FarNodes = &arg->FarNodes;
1714 for (
auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
1716 size_t k = (*it)->data.skels.size();
1717 flops += 2.0 * m * n * k;
1719 mops += 2.0 * ( m * n + n * k + k * n );
1723 event.Set( label + name, flops, mops );
1725 cost = flops / 1E+9;
1730 void DependencyAnalysis()
1732 for (
auto it : arg->NNFarNodes ) it->DependencyAnalysis( R,
this );
1733 arg->DependencyAnalysis( RW,
this );
1737 void Execute(
Worker* user_worker ) { SkeletonsToSkeletons( arg ); };
/**
 *  Pushes node->data.u_skel down the tree through node->data.proj:
 *  into u_leaf / an output view for one path, and into the two
 *  children's u_skel for the other.
 *  NOTE(review): this listing is lossy -- the heads of the multiply
 *  calls, the leaf/inner branch conditions and the declarations of
 *  `U`, `P`, `PL`, `PR` are not visible here.
 */
1746 template<
typename NODE>
1747 void SkeletonsToNodes( NODE *node )
1750 using T =
typename NODE::T;
1753 auto &K = *node->setup->K;
1754 auto &w = *node->setup->w;
1757 auto &gids = node->gids;
1758 auto &data = node->data;
1759 auto &proj = data.proj;
1760 auto &skels = data.skels;
1761 auto &u_skel = data.u_skel;
1762 auto *lchild = node->lchild;
1763 auto *rchild = node->rchild;
1765 size_t nrhs = w.col();
/** Path taken when the view U spans all nrhs columns. */
1776 if ( U.
col() == nrhs )
1782 "Transpose",
"Non-transpose",
1783 U.
row(), U.
col(), u_skel.row(),
1784 1.0, proj.data(), proj.row(),
1785 u_skel.data(), u_skel.row(),
/** Otherwise the result goes to u_leaf[ 0 ], reset to gids.size() x nrhs zeros first. */
1794 auto &u_leaf = node->data.u_leaf[ 0 ];
1797 u_leaf.resize( 0, 0 );
1798 u_leaf.resize( gids.size(), nrhs, 0.0 );
1807 u_leaf.row(), u_leaf.col(), u_skel.row(),
1808 1.0, proj.data(), proj.row(),
1809 u_skel.data(), u_skel.row(),
1810 1.0, u_leaf.data(), u_leaf.row()
/** Skip the root and any node that was not skeletonized. */
1817 if ( !node->parent || !node->data.isskel )
return;
1819 auto &u_lskel = lchild->data.u_skel;
1820 auto &u_rskel = rchild->data.u_skel;
1821 auto &lskel = lchild->data.skels;
1822 auto &rskel = rchild->data.skels;
/** NOTE(review): 6 is an unexplained treelist_id threshold selecting between the two code paths below. */
1825 if ( node->treelist_id > 6 )
/** Left child: transposed proj (leading columns) times u_skel, accumulated into u_lskel. */
1830 "Transpose",
"No transpose",
1831 u_lskel.row(), u_lskel.col(), proj.row(),
1832 1.0, proj.data(), proj.row(),
1833 u_skel.data(), u_skel.row(),
1834 1.0, u_lskel.data(), u_lskel.row()
/** Right child: same with proj offset by lskel.size() columns, accumulated into u_rskel. */
1838 "Transpose",
"No transpose",
1839 u_rskel.row(), u_rskel.col(), proj.row(),
1840 1.0, proj.data() + proj.row() * lskel.size(), proj.row(),
1841 u_skel.data(), u_skel.row(),
1842 1.0, u_rskel.data(), u_rskel.row()
/** View-based variant via gemm::xgemm<GEMM_NB>, accumulating into UL and UR. */
1850 View<T> U(
false, u_skel ), UL(
false, u_lskel ),
1851 UR(
false, u_rskel );
1854 PR, lskel.size(), TOP );
1856 gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
1858 gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );
1866 template<
bool NNPRUNE,
typename NODE,
typename T>
1876 name = string(
"s2n" );
1877 label = to_string( arg->treelist_id );
1880 double flops = 0.0, mops = 0.0;
1881 auto &gids = arg->gids;
1882 auto &data = arg->data;
1883 auto &proj = data.proj;
1884 auto &skels = data.skels;
1885 auto &w = *arg->setup->w;
1889 size_t m = proj.col();
1891 size_t k = proj.row();
1892 flops += 2.0 * m * n * k;
1893 mops += 2.0 * ( m * n + n * k + m * k );
1897 if ( !arg->parent || !arg->data.isskel )
1903 size_t m = proj.col();
1905 size_t k = proj.row();
1906 flops += 2.0 * m * n * k;
1907 mops += 2.0 * ( m * n + n * k + m * k );
1912 event.Set( label + name, flops, mops );
1914 cost = flops / 1E+9;
1919 void Prefetch(
Worker* user_worker )
1921 auto &proj = arg->data.proj;
1922 __builtin_prefetch( proj.data() );
1923 auto &u_skel = arg->data.u_skel;
1924 __builtin_prefetch( u_skel.data() );
1934 auto &u_lskel = arg->lchild->data.u_skel;
1935 __builtin_prefetch( u_lskel.data() );
1936 auto &u_rskel = arg->rchild->data.u_skel;
1937 __builtin_prefetch( u_rskel.data() );
1939 #ifdef HMLP_USE_CUDA 1941 if ( user_worker ) device = user_worker->
GetDevice();
1944 int stream_id = arg->treelist_id % 8;
1945 proj.CacheD( device );
1946 proj.PrefetchH2D( device, stream_id );
1947 u_skel.CacheD( device );
1948 u_skel.PrefetchH2D( device, stream_id );
1954 auto &u_lskel = arg->lchild->data.u_skel;
1955 u_lskel.CacheD( device );
1956 u_lskel.PrefetchH2D( device, stream_id );
1957 auto &u_rskel = arg->rchild->data.u_skel;
1958 u_rskel.CacheD( device );
1959 u_rskel.PrefetchH2D( device, stream_id );
1965 void DependencyAnalysis() { arg->DependOnParent(
this ); };
1967 void Execute(
Worker* user_worker )
1969 #ifdef HMLP_USE_CUDA 1971 if ( user_worker ) device = user_worker->
GetDevice();
1972 if ( device ) gpu::SkeletonsToNodes<NNPRUNE, NODE, T>( device, arg );
1973 else SkeletonsToNodes<NNPRUNE, NODE, T>( arg );
1975 SkeletonsToNodes( arg );
/**
 *  Accumulates into node->data.u_leaf[ SUBTASKID ] the products of
 *  near-field kernel blocks with the weights of the near nodes whose
 *  position in the near list falls in [ itbeg, itend ).
 *  A cached block ( node->data.NearKab ) is used when it is non-empty;
 *  otherwise K( amap, bmap ) is evaluated per near node.
 *  NOTE(review): this listing is lossy -- the heads of some multiply
 *  calls, the declarations/updates of `itptr` and `offset`, and the
 *  conditions choosing between `wb` and the view `W` are not visible.
 */
1983 template<
int SUBTASKID,
bool NNPRUNE,
typename NODE,
typename T>
1984 void LeavesToLeaves( NODE *node,
size_t itbeg,
size_t itend )
1986 assert( node->isleaf );
1988 double beg, u_leaf_time, before_writeback_time, after_writeback_time;
1991 auto &K = *node->setup->K;
1992 auto &w = *node->setup->w;
1994 auto &gids = node->gids;
1995 auto &data = node->data;
1996 auto &amap = node->gids;
1997 auto &NearKab = data.NearKab;
1999 size_t nrhs = w.col();
/** Near list: NNNearNodes with NNPRUNE, NearNodes otherwise. */
2001 set<NODE*> *NearNodes;
2002 if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2003 else NearNodes = &node->NearNodes;
/** Output buffer for this subtask; emptied before anything else happens. */
2008 auto &u_leaf = data.u_leaf[ SUBTASKID ];
2009 u_leaf.resize( 0, 0 );
/** Empty iteration range. */
2012 if ( itbeg == itend )
2018 u_leaf.resize( gids.size(), nrhs, 0.0 );
/** Cached path: blocks are read from NearKab at a running column offset. */
2021 if ( NearKab.size() )
2026 for (
auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2028 if ( itbeg <= itptr && itptr < itend )
/** NOTE(review): `auto wb` (no reference) copies the near node's w_leaf on every iteration. */
2031 auto wb = (*it)->data.w_leaf;
2039 u_leaf.row(), u_leaf.col(), wb.row(),
2040 1.0, NearKab.data() + offset * NearKab.row(), NearKab.row(),
2041 wb.data(), wb.row(),
2042 1.0, u_leaf.data(), u_leaf.row()
2047 View<T> W = (*it)->data.w_view;
2051 u_leaf.row(), u_leaf.col(), W.
row(),
2052 1.0, NearKab.data() + offset * NearKab.row(), NearKab.row(),
2054 1.0, u_leaf.data(), u_leaf.row()
/** Advance past this near node's columns in NearKab. */
2058 offset += (*it)->gids.size();
/** Uncached path: evaluate K( amap, bmap ) for each selected near node. */
2065 for (
auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2067 if ( itbeg <= itptr && itptr < itend )
2069 auto &bmap = (*it)->gids;
/** NOTE(review): by-value copy of w_leaf, as above. */
2070 auto wb = (*it)->data.w_leaf;
2073 auto Kab = K( amap, bmap );
2078 xgemm(
"N",
"N", u_leaf.row(), u_leaf.col(), wb.row(),
2079 1.0, Kab.data(), Kab.row(),
2080 wb.data(), wb.row(),
2081 1.0, u_leaf.data(), u_leaf.row());
2085 View<T> W = (*it)->data.w_view;
2086 xgemm(
"N",
"N", u_leaf.row(), u_leaf.col(), W.
row(),
2087 1.0, Kab.data(), Kab.row(),
2089 1.0, u_leaf.data(), u_leaf.row() );
2095 before_writeback_time = omp_get_wtime() - beg;
2100 template<
int SUBTASKID,
bool NNPRUNE,
typename NODE,
typename T>
2114 name = string(
"l2l" );
2115 label = to_string( arg->treelist_id );
2119 double flops = 0.0, mops = 0.0;
2120 auto &gids = arg->gids;
2121 auto &data = arg->data;
2122 auto &proj = data.proj;
2123 auto &skels = data.skels;
2124 auto &w = *arg->setup->w;
2125 auto &K = *arg->setup->K;
2126 auto &NearKab = data.NearKab;
2128 assert( arg->isleaf );
2130 size_t m = gids.size();
2133 set<NODE*> *NearNodes;
2134 if ( NNPRUNE ) NearNodes = &arg->NNNearNodes;
2135 else NearNodes = &arg->NearNodes;
2139 size_t itrange = ( NearNodes->size() + 3 ) / 4;
2140 if ( itrange < 1 ) itrange = 1;
2141 itbeg = ( SUBTASKID - 1 ) * itrange;
2142 itend = ( SUBTASKID + 0 ) * itrange;
2143 if ( itbeg > NearNodes->size() ) itbeg = NearNodes->size();
2144 if ( itend > NearNodes->size() ) itend = NearNodes->size();
2145 if ( SUBTASKID == 4 ) itend = NearNodes->size();
2147 for (
auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2149 if ( itbeg <= itptr && itptr < itend )
2151 size_t k = (*it)->gids.size();
2152 flops += 2.0 * m * n * k;
2154 mops += 2.0 * ( m * n + n * k + m * k );
2160 event.Set( label + name, flops, mops );
2163 cost = flops / 1E+9;
2166 void Prefetch(
Worker* user_worker )
2168 auto &u_leaf = arg->data.u_leaf[ SUBTASKID ];
2169 __builtin_prefetch( u_leaf.data() );
2180 assert( arg->isleaf );
2189 void Execute(
Worker* user_worker )
2191 LeavesToLeaves<SUBTASKID, NNPRUNE, NODE, T>( arg, itbeg, itend );
/**
 *  Debug helper: prints the treelist_id of every node pointer in the
 *  given set to stdout, each followed by ", ".
 *  Note that the parameter is itself named `set`, shadowing the
 *  container template name inside the function body.
 */
2199 template<
typename NODE>
2200 void PrintSet( set<NODE*> &
set )
2202 for (
auto it =
set.begin(); it !=
set.end(); it ++ )
2204 printf(
"%lu, ", (*it)->treelist_id );
/**
 *  Tallies, for a leaf node, a weighted vote per Morton ID over the
 *  neighbors ( NN ) of the indices it owns, and returns the tally
 *  re-keyed by vote count via flip_map.
 *  NOTE(review): this listing is lossy -- the `else` lines and braces
 *  are not visible, so which branch owns the `=` assignment and the
 *  HasMissingNeighbors increment is inferred from position only.
 */
2216 template<
typename NODE>
2217 multimap<size_t, size_t> NearNodeBallots( NODE *node )
2220 assert( node->isleaf );
2222 auto &setup = *(node->setup);
2223 auto &NN = *(setup.NN);
2224 auto &gids = node->gids;
/** Morton ID -> accumulated ballot. */
2227 map<size_t, size_t> ballot;
2229 size_t HasMissingNeighbors = 0;
2233 for (
size_t j = 0; j < gids.size(); j ++ )
2235 for (
size_t i = 0; i < NN.row(); i ++ )
2237 auto value = NN( i, gids[ j ] ).first;
2238 size_t neighbor_gid = NN( i, gids[ j ] ).second;
/** neighbor_gid is unsigned, so only the upper-bound test can fail. */
2240 if ( neighbor_gid >= 0 && neighbor_gid < NN.col() )
2242 size_t neighbor_morton = setup.morton[ neighbor_gid ];
/** The floating-point quotient is truncated when stored in a size_t. */
2243 size_t weighted_ballot = 1.0 / ( value + 1E-3 );
/** Only the first half of the neighbor rows cast a ballot. */
2247 if ( i < NN.row() / 2 )
2249 if ( ballot.find( neighbor_morton ) != ballot.end() )
2252 ballot[ neighbor_morton ] += weighted_ballot;
2257 ballot[ neighbor_morton ] = weighted_ballot;
2263 HasMissingNeighbors ++;
2268 if ( HasMissingNeighbors )
2270 printf(
"Missing %lu neighbor pairs\n", HasMissingNeighbors );
2275 return flip_map( ballot );
/**
 *  Builds the near-interaction lists of `node`: the node itself is always
 *  added, then candidates from NearNodeBallots are added in descending
 *  ballot order until NNNearNodes reaches ( 1 << node->l ) * Budget().
 *  NOTE(review): this listing is lossy -- lines between the `NN` alias
 *  and the `gids` alias are not visible, so any guard condition around
 *  the body cannot be confirmed from this text.
 */
2282 template<
typename NODE,
typename T>
2283 void NearSamples( NODE *node )
2285 auto &setup = *(node->setup);
2286 auto &NN = *(setup.NN);
2290 auto &gids = node->gids;
2292 double budget = setup.Budget();
/** 2^l, with l the node's level field. */
2293 size_t n_nodes = ( 1 << node->l );
/** A node is always near itself. */
2296 node->NearNodes.insert( node );
2297 node->NNNearNodes.insert( node );
2298 node->NNNearNodeMortonIDs.insert( node->morton );
2301 multimap<size_t, size_t> sorted_ballot = NearNodeBallots( node );
/** Reverse iteration visits the largest ballot counts first. */
2304 for (
auto it = sorted_ballot.rbegin(); it != sorted_ballot.rend(); it ++ )
/** Stop once the budgeted number of near nodes is reached. */
2307 if ( node->NNNearNodes.size() >= n_nodes * budget )
break;
/** Resolve the Morton ID to a node and record both the ID and the node. */
2309 auto *target = (*node->morton2node)[ (*it).second ];
2310 node->NNNearNodeMortonIDs.insert( (*it).second );
2311 node->NNNearNodes.insert( target );
2319 template<
typename NODE,
typename T>
2329 name = string(
"near" );
2332 double flops = 0.0, mops = 0.0;
2335 event.Set( label + name, flops, mops );
2342 void DependencyAnalysis() { this->TryEnqueue(); };
2344 void Execute(
Worker* user_worker )
2346 NearSamples<NODE, T>( arg );
/**
 *  @brief Make the near-interaction lists symmetric.
 *
 *  For each of the 2^depth nodes stored from position 2^depth - 1 onward
 *  in tree.treelist (presumably the leaf level -- confirm against the
 *  tree layout), and for every Morton ID in that node's
 *  NNNearNodeMortonIDs, the node is registered in the target's
 *  NNNearNodes AND the node's own Morton ID is registered in the target's
 *  NNNearNodeMortonIDs, so both containers describe the same relation.
 *
 *  @param tree provides depth, treelist and morton2node.
 */
template<typename TREE>
void SymmetrizeNearInteractions( TREE & tree )
{
  int n_nodes = 1 << tree.depth;
  auto level_beg = tree.treelist.begin() + n_nodes - 1;

  for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
  {
    auto *node = *(level_beg + node_ind);

    /**
     *  Iterate over a snapshot: when a node lists itself, the insertion
     *  below targets the very container being walked.
     */
    const auto NearMortonIDs = node->NNNearNodeMortonIDs;

    for ( const auto & target_morton : NearMortonIDs )
    {
      auto *target = tree.morton2node[ target_morton ];
      target->NNNearNodes.insert( node );
      /** Record the SOURCE node's ID; the target's own ID adds nothing. */
      target->NNNearNodeMortonIDs.insert( node->morton );
    }
  }
};
/**
 *  Fragments of the task class that caches the near-interaction block of K
 *  for one node. The class header and several member signatures are not
 *  visible in this listing; ``arg'' and ``node'' refer to the node the task
 *  was created for (TODO confirm against the full header).
 */
2373 template<
bool NNPRUNE,
typename NODE>
/** Task name and label; the label is the node's index in the tree list. */
2383 name = string(
"c-n" );
2384 label = to_string( arg->treelist_id );
/** Event record: flops / mops start at zero; any update of them is not visible here. */
2391 double flops = 0.0, mops = 0.0;
/** NNPRUNE selects the neighbor-based near list instead of the plain one. */
2394 auto *NearNodes = &node->NearNodes;
2395 if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2396 auto &K = *node->setup->K;
2398 size_t m = node->gids.size();
/** Accumulate the total number of columns; ``n'' is declared outside the visible lines. */
2400 for (
auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2402 n += (*it)->gids.size();
2405 event.Set( label + name, flops, mops );
/** This task declares that it depends on no other task. */
2408 void DependencyAnalysis() { arg->DependOnNoOne(
this ); };
/** Gather K( gids of node, gids of all near nodes ) into data.NearKab. */
2410 void Execute(
Worker* user_worker )
2415 auto *NearNodes = &node->NearNodes;
2416 if ( NNPRUNE ) NearNodes = &node->NNNearNodes;
2417 auto &K = *node->setup->K;
2418 auto &data = node->data;
2419 auto &amap = node->gids;
/** bmap is the concatenation of the gids of every near node. */
2420 vector<size_t> bmap;
2421 for (
auto it = NearNodes->begin(); it != NearNodes->end(); it ++ )
2423 bmap.insert( bmap.end(), (*it)->gids.begin(), (*it)->gids.end() );
2425 data.NearKab = K( amap, bmap );
/** Keep a copy of the column map as a bmap.size()-by-1 Data<size_t>. */
2428 data.Nearbmap.resize( bmap.size(), 1 );
2429 for (
size_t i = 0; i < bmap.size(); i ++ )
2430 data.Nearbmap[ i ] = bmap[ i ];
/** With CUDA enabled, prefetch the cached blocks to device 0. */
2432 #ifdef HMLP_USE_CUDA
2433 auto *device = hmlp_get_device( 0 );
2435 node->data.Nearbmap.PrefetchH2D( device, 8 );
/** Bytes to leave free on the device (3e9); the rationale is not documented here. */
2437 size_t preserve_size = 3000000000;
/** Only prefetch NearKab when it is small enough and enough device memory remains. */
2441 if ( data.NearKab.col() * MAX_NRHS < 1200000000 &&
2442 data.NearKab.size() * 8 + preserve_size < device->get_memory_left() )
2445 data.NearKab.PrefetchH2D( device, 8 );
/** Presumably the else branch: report that the block was not cached on the device. */
2449 printf(
"Kab %lu %lu not cache\n", data.NearKab.row(), data.NearKab.col() );
/**
 *  @brief Recursively collect the far nodes of leaf ``target'', starting the
 *         descent at ``node''.
 *
 *  Two lists are built in the same call: target->FarNodes (using
 *  target->NearNodes) and target->NNFarNodes (using target->NNNearNodes).
 *  A visited node is descended into when it is not skeletonized or when it
 *  contains any near node of the target.
 *
 *  NOTE(review): the ``else'' lines are missing from this listing; the
 *  insert statements below are presumably the else branches of the
 *  preceding conditions -- confirm against the original header.
 *
 *  @param node    node currently visited.
 *  @param target  leaf node whose far lists are being built.
 */
2468 template<
typename NODE>
2469 void FindFarNodes( NODE *node, NODE *target )
/** The target must be a leaf. */
2472 assert( target->isleaf );
2475 set<NODE*> *NearNodes;
2476 auto &data = node->data;
2477 auto *lchild = node->lchild;
2478 auto *rchild = node->rchild;
/** First pass: near list without neighbor information. */
2487 NearNodes = &target->NearNodes;
2490 if ( !data.isskel || node->ContainAny( *NearNodes ) )
2492 if ( !node->isleaf )
2495 FindFarNodes( lchild, target );
2496 FindFarNodes( rchild, target );
/** Otherwise (presumably) ``node'' is far from the target. */
2502 target->FarNodes.insert( node );
/** Second pass: neighbor-based near list. */
2512 NearNodes = &target->NNNearNodes;
2515 if ( !data.isskel || node->ContainAny( *NearNodes ) )
2517 if ( !node->isleaf )
2520 FindFarNodes( lchild, target );
2521 FindFarNodes( rchild, target );
/** Symmetric case with a smaller Morton ID on ``node''; the body of this branch is not visible. */
2526 if ( node->setup->IsSymmetric() && ( node->morton < target->morton ) )
2535 target->NNFarNodes.insert( node );
/**
 *  @brief Build the far-interaction lists of every skeletonized node by a
 *         bottom-up sweep over the tree levels.
 *
 *  FindFarNodes() is called from the root for a node; for a node with
 *  children, entries present in both children's lists are moved into the
 *  parent's list. When the setup is symmetric, NNFarNodes is symmetrized
 *  afterwards.
 *
 *  NOTE(review): the leaf / non-leaf branching lines are missing from this
 *  listing; FindFarNodes() is presumably called for leaves only and the
 *  merge for interior nodes -- confirm against the original header.
 *
 *  @param tree  tree providing depth, treelist and setup.
 */
2552 template<
typename TREE>
2553 void MergeFarNodes( TREE &tree )
/** Sweep from the deepest level up to the root. */
2555 for (
int l = tree.depth; l >= 0; l -- )
2557 size_t n_nodes = ( 1 << l );
2558 auto level_beg = tree.treelist.begin() + n_nodes - 1;
/** NOTE(review): node_ind is int while n_nodes is size_t (mixed-sign comparison). */
2560 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2562 auto *node = *(level_beg + node_ind);
/** Nodes without a skeleton have no far lists. */
2565 if ( !node->data.isskel )
continue;
2569 FindFarNodes( tree.treelist[ 0 ] , node );
2574 auto *lchild = node->lchild;
2575 auto *rchild = node->rchild;
2578 auto &pFarNodes = node->FarNodes;
2579 auto &lFarNodes = lchild->FarNodes;
2580 auto &rFarNodes = rchild->FarNodes;
/** Nodes far from both children become far from the parent ... */
2582 for (
auto it = lFarNodes.begin(); it != lFarNodes.end(); ++ it )
2584 if ( rFarNodes.count( *it ) ) pFarNodes.insert( *it );
/** ... and are removed from both children's lists. */
2587 for (
auto it = pFarNodes.begin(); it != pFarNodes.end(); it ++ )
2589 lFarNodes.erase( *it ); rFarNodes.erase( *it );
/** Same merge for the neighbor-based far lists. */
2594 auto &pNNFarNodes = node->NNFarNodes;
2595 auto &lNNFarNodes = lchild->NNFarNodes;
2596 auto &rNNFarNodes = rchild->NNFarNodes;
2605 for (
auto it = lNNFarNodes.begin(); it != lNNFarNodes.end(); ++ it )
2607 if ( rNNFarNodes.count( *it ) ) pNNFarNodes.insert( *it );
2610 for (
auto it = pNNFarNodes.begin(); it != pNNFarNodes.end(); it ++ )
2612 lNNFarNodes.erase( *it );
2613 rNNFarNodes.erase( *it );
/** Symmetric setup: if B is far from A, also record A as far from B. */
2623 if ( tree.setup.IsSymmetric() )
2626 for (
int l = tree.depth; l >= 0; l -- )
2628 std::size_t n_nodes = 1 << l;
2629 auto level_beg = tree.treelist.begin() + n_nodes - 1;
2631 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2633 auto *node = *(level_beg + node_ind);
2634 auto &pFarNodes = node->NNFarNodes;
2635 for (
auto it = pFarNodes.begin(); it != pFarNodes.end(); it ++ )
2637 (*it)->NNFarNodes.insert( node );
/** Debug build only: report asymmetric NNFarNodes pairs and print non-empty lists. */
2643 #ifdef DEBUG_SPDASKIT 2644 for (
int l = tree.depth; l >= 0; l -- )
2646 std::size_t n_nodes = 1 << l;
2647 auto level_beg = tree.treelist.begin() + n_nodes - 1;
2649 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2651 auto *node = *(level_beg + node_ind);
2652 auto &pFarNodes = node->NNFarNodes;
2653 for (
auto it = pFarNodes.begin(); it != pFarNodes.end(); it ++ )
2655 if ( !( (*it)->NNFarNodes.count( node ) ) )
2657 printf(
"Unsymmetric FarNodes %lu, %lu\n", node->treelist_id, (*it)->treelist_id );
2659 PrintSet( node->NNNearNodes );
2660 PrintSet( (*it)->NNNearNodes );
2662 PrintSet( node->NNFarNodes );
2663 PrintSet( (*it)->NNFarNodes );
2664 printf(
"======\n" );
2668 if ( pFarNodes.size() )
2670 printf(
"l %2lu FarNodes(%lu) ", node->l, node->treelist_id );
2671 PrintSet( pFarNodes );
/**
 *  @brief Reserve per-node work buffers and cache the far-interaction block
 *         of K for every node in the tree.
 *
 *  The first parallel loop reserves w_leaf ( gids-by-MAX_NRHS ) and
 *  u_leaf[ 0 ] ( MAX_NRHS-by-gids ). The second gathers
 *  K( skels of node, skels of all far nodes ) into data.FarKab.
 *
 *  NOTE(review): the CACHE parameter is not referenced in the visible
 *  lines; it presumably guards the second loop -- confirm.
 *
 *  @tparam NNPRUNE  use NNFarNodes instead of FarNodes.
 *  @param  tree     tree whose treelist is traversed.
 */
2686 template<
bool NNPRUNE,
bool CACHE = true,
typename TREE>
2687 void CacheFarNodes( TREE &tree )
/** Reserve buffers for each node. */
2690 #pragma omp parallel for schedule( dynamic ) 2691 for (
size_t i = 0; i < tree.treelist.size(); i ++ )
2693 auto *node = tree.treelist[ i ];
2696 node->data.w_leaf.reserve( node->gids.size(), MAX_NRHS );
2697 node->data.u_leaf[ 0 ].reserve( MAX_NRHS, node->gids.size() );
/** Gather the skeleton-to-skeleton far block for each node. */
2705 #pragma omp parallel for schedule( dynamic ) 2706 for (
size_t i = 0; i < tree.treelist.size(); i ++ )
2708 auto *node = tree.treelist[ i ];
2709 auto *FarNodes = &node->FarNodes;
2710 if ( NNPRUNE ) FarNodes = &node->NNFarNodes;
2711 auto &K = *node->setup->K;
2712 auto &data = node->data;
2713 auto &amap = data.skels;
/** Column indices: concatenated skeletons of all far nodes. */
2714 std::vector<size_t> bmap;
2715 for (
auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
2717 bmap.insert( bmap.end(), (*it)->data.skels.begin(),
2718 (*it)->data.skels.end() );
2720 data.FarKab = K( amap, bmap );
/**
 *  @brief Write a MATLAB script ( "interaction.m" ) that draws the near and
 *         far interaction blocks, and return the fraction of the matrix
 *         covered by near (directly evaluated) blocks.
 *
 *  NOTE(review): the declaration of pFile, the tail of the first
 *  rectangle fprintf argument list and any fclose() call are not visible in
 *  this listing; the fopen() result is not checked in the visible lines.
 *
 *  @param  tree  tree providing depth, treelist and n.
 *  @return sum over near pairs of ( rows * columns ) divided by n * n.
 */
2729 template<
bool NNPRUNE,
typename TREE>
2730 double DrawInteraction( TREE &tree )
2732 double exact_ratio = 0.0;
2737 pFile = fopen (
"interaction.m",
"w" );
/** Figure preamble. */
2739 fprintf( pFile,
"figure('Position',[100,100,800,800]);" );
2740 fprintf( pFile,
"hold on;" );
2741 fprintf( pFile,
"axis square;" );
2742 fprintf( pFile,
"axis ij;" );
/** Visit every node, deepest level first. */
2744 for (
int l = tree.depth; l >= 0; l -- )
2746 std::size_t n_nodes = 1 << l;
2747 auto level_beg = tree.treelist.begin() + n_nodes - 1;
2749 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2751 auto *node = *(level_beg + node_ind);
/** The neighbor-based lists are used here regardless of NNPRUNE in the visible lines. */
2755 auto &pNearNodes = node->NNNearNodes;
2756 auto &pFarNodes = node->NNFarNodes;
/** Far blocks: color channel derived from the shallower level of the pair. */
2757 for (
auto it = pFarNodes.begin(); it != pFarNodes.end(); it ++ )
2759 double gb = (double)std::min( node->l, (*it)->l ) / tree.depth;
2761 fprintf( pFile,
"rectangle('position',[%lu %lu %lu %lu],'facecolor',[1.0,%lf,%lf]);\n",
2762 node->offset, (*it)->offset,
2763 node->gids.size(), (*it)->gids.size(),
/** Near blocks: fixed color; their area counts as exactly evaluated entries. */
2766 for (
auto it = pNearNodes.begin(); it != pNearNodes.end(); it ++ )
2768 fprintf( pFile,
"rectangle('position',[%lu %lu %lu %lu],'facecolor',[0.2,0.4,1.0]);\n",
2769 node->offset, (*it)->offset,
2770 node->gids.size(), (*it)->gids.size() );
2773 exact_ratio += node->gids.size() * (*it)->gids.size();
2781 fprintf( pFile,
"hold off;" );
2784 return exact_ratio / ( tree.n * tree.n );
/**
 *  Recursive evaluation of the potentials of a single index ``gid'' by a
 *  top-down tree traversal. Most of the signature is missing from this
 *  listing; judging by the recursive calls below the parameters are
 *  ( node, gid, nnandi, potentials ).
 *
 *  A node that is not skeletonized, or that contains any index in
 *  ``nnandi'', is handled directly (leaf) or by recursion (interior);
 *  otherwise the node's skeleton weights are used.
 *
 *  NOTE(review): the leaf / else branching lines and the name of the
 *  matrix-multiply routine receiving the argument lists below are not
 *  visible -- confirm against the original header.
 */
2796 template<
bool SYMBOLIC,
bool NNPRUNE,
typename NODE,
typename T>
2801 vector<size_t> &nnandi,
2805 auto &K = *node->setup->K;
2806 auto &w = *node->setup->w;
2807 auto &gids = node->gids;
2808 auto &data = node->data;
2809 auto *lchild = node->lchild;
2810 auto *rchild = node->rchild;
2812 size_t nrhs = w.col();
/** Single-row index map; presumably set to gid in a line not visible here. */
2814 auto amap = std::vector<size_t>( 1 );
2819 assert( potentials.size() == amap.size() * nrhs );
/** Not prunable: no skeleton, or the node holds gid or one of its neighbors. */
2822 if ( !data.isskel || node->ContainAny( nnandi ) )
/** Record gid in the node's near ID set under the node lock. */
2829 data.lock.Acquire();
2831 if ( NNPRUNE ) node->NNNearIDs.insert( gid );
2832 else node->NearIDs.insert( gid );
2834 data.lock.Release();
2838 #ifdef DEBUG_SPDASKIT 2839 printf(
"level %lu direct evaluation\n", node->l );
/** Direct evaluation with K( gid, gids ) and the matching rows of w. */
2842 auto Kab = K( amap, gids );
2845 std::vector<size_t> bmap( nrhs );
2846 for (
size_t j = 0; j < bmap.size(); j ++ )
2850 auto wb = w( gids, bmap );
/** Arguments of an accumulate-into-potentials multiply (alpha = beta = 1.0). */
2855 Kab.row(), wb.col(), wb.row(),
2856 1.0, Kab.data(), Kab.row(),
2857 wb.data(), wb.row(),
2858 1.0, potentials.data(), potentials.
row()
/** Interior node: recurse into both children. */
2864 Evaluate<SYMBOLIC, NNPRUNE>( lchild, gid, nnandi, potentials );
2865 Evaluate<SYMBOLIC, NNPRUNE>( rchild, gid, nnandi, potentials );
/** Prunable node: record gid in the far ID set under the node lock. */
2873 data.lock.Acquire();
/**
 *  NOTE(review): this selection looks inverted relative to the near branch
 *  above, where NNPRUNE picks the NN-prefixed set. Here NNPRUNE picks
 *  FarIDs and !NNPRUNE picks NNFarIDs -- verify the intent.
 */
2876 if ( NNPRUNE ) node->FarIDs.insert( gid );
2877 else node->NNFarIDs.insert( gid );
2879 data.lock.Release();
2883 #ifdef DEBUG_SPDASKIT 2884 printf(
"level %lu is prunable\n", node->l );
/** Approximate evaluation with K( gid, skels ) and the skeleton weights. */
2886 auto Kab = K( amap, node->data.skels );
2887 auto &w_skel = node->data.w_skel;
2891 Kab.row(), w_skel.col(), w_skel.row(),
2892 1.0, Kab.data(), Kab.row(),
2893 w_skel.data(), w_skel.row(),
2894 1.0, potentials.data(), potentials.
row()
/**
 *  Tree-level entry point for evaluating the potentials of one index
 *  ``gid''. The signature is not visible in this listing; the body uses
 *  ``tree'', ``gid'' and ``potentials''.
 *
 *  ``nnandi'' holds gid followed by its neighbors when neighbor pruning is
 *  used, or gid alone otherwise (the selecting condition is not visible
 *  here; presumably NNPRUNE -- confirm).
 */
2908 template<
bool SYMBOLIC,
bool NNPRUNE,
typename TREE,
typename T>
2916 vector<size_t> nnandi;
2917 auto &w = *tree.setup.w;
/** One row of potentials, one column per right-hand side, zero-initialized. */
2920 potentials.resize( 1, w.col(), 0.0 );
/** Neighbor case: gid plus the second member of every NN( i, gid ) pair. */
2924 auto &NN = *tree.setup.NN;
2925 nnandi.reserve( NN.row() + 1 );
2926 nnandi.push_back( gid );
2927 for (
size_t i = 0; i < NN.row(); i ++ )
2929 nnandi.push_back( NN( i, gid ).second );
2931 #ifdef DEBUG_SPDASKIT 2932 printf(
"nnandi.size() %lu\n", nnandi.size() );
/** No-neighbor case: gid only. */
2937 nnandi.reserve( 1 );
2938 nnandi.push_back( gid );
/** Start the recursive traversal at the root. */
2941 Evaluate<SYMBOLIC, NNPRUNE>( tree.treelist[ 0 ], gid, nnandi, potentials );
/**
 *  Full matrix-vector evaluation u = K w over the compressed tree. The
 *  start of the template parameter list and the function signature are not
 *  visible in this listing; the body uses ``tree'', ``weights'' and
 *  ``potentials''.
 *
 *  Visible phases: (1) bind weights / potentials to the tree setup,
 *  (2) forward-permute the weights into each leaf's w_leaf, (3) run the
 *  leaf-to-leaf, node-to-skeleton, skeleton-to-skeleton and
 *  skeleton-to-node tasks, (4) sum the u_leaf buffers, (5) add each leaf's
 *  result back into ``potentials'', (6) print a timing report.
 */
2950 bool USE_RUNTIME =
true,
2951 bool USE_OMP_TASK =
false,
2952 bool NNPRUNE =
true,
2962 const bool AUTO_DEPENDENCY =
true;
2965 using NODE =
typename TREE::NODE;
/** Timers for the individual phases. */
2968 double beg, time_ratio, evaluation_time = 0.0;
2969 double allocate_time, computeall_time;
2970 double forward_permute_time, backward_permute_time;
2973 tree.DependencyCleanUp();
2976 size_t n = weights.
row();
2977 size_t nrhs = weights.
col();
2979 beg = omp_get_wtime();
/** Make the input and output visible to all tasks through the shared setup. */
2981 tree.setup.w = &weights;
2982 tree.setup.u = &potentials;
2983 allocate_time = omp_get_wtime() - beg;
2986 if ( REPORT_EVALUATE_STATUS )
2988 printf(
"Forward permute ...\n" ); fflush( stdout );
/** Forward permutation: copy the rows of ``weights'' owned by each leaf into w_leaf. */
2990 beg = omp_get_wtime();
2991 int n_nodes = ( 1 << tree.depth );
2992 auto level_beg = tree.treelist.begin() + n_nodes - 1;
2993 #pragma omp parallel for 2994 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
2996 auto *node = *(level_beg + node_ind);
2999 auto &gids = node->gids;
3000 auto &w_leaf = node->data.w_leaf;
/** Resize only when the shape differs from the current right-hand sides. */
3002 if ( w_leaf.row() != gids.size() || w_leaf.col() != weights.
col() )
3004 w_leaf.resize( gids.size(), weights.
col() );
3007 for (
size_t j = 0; j < w_leaf.col(); j ++ )
3009 for (
size_t i = 0; i < w_leaf.row(); i ++ )
3011 w_leaf( i, j ) = weights( gids[ i ], j );
3015 forward_permute_time = omp_get_wtime() - beg;
3020 if ( REPORT_EVALUATE_STATUS )
3022 printf(
"N2S, S2S, S2N, L2L (HMLP Runtime) ...\n" ); fflush( stdout );
/** Only the symmetric case is implemented (see the message further below). */
3024 if ( tree.setup.IsSymmetric() )
3026 beg = omp_get_wtime();
3027 #ifdef HMLP_USE_CUDA 3028 potentials.AllocateD( hmlp_get_device( 0 ) );
/** Task prototypes; their types are declared outside the visible lines. */
3030 LEAFTOLEAFVER2TASK leaftoleafver2task;
3041 LEAFTOLEAFTASK1 leaftoleaftask1;
3042 LEAFTOLEAFTASK2 leaftoleaftask2;
3043 LEAFTOLEAFTASK3 leaftoleaftask3;
3044 LEAFTOLEAFTASK4 leaftoleaftask4;
3046 NODETOSKELTASK nodetoskeltask;
3047 SKELTOSKELTASK skeltoskeltask;
3048 SKELTONODETASK skeltonodetask;
/** Create all tasks, then execute them. */
3098 #ifdef HMLP_USE_CUDA 3099 tree.TraverseLeafs( leaftoleafver2task );
3101 tree.TraverseLeafs( leaftoleaftask1 );
3102 tree.TraverseLeafs( leaftoleaftask2 );
3103 tree.TraverseLeafs( leaftoleaftask3 );
3104 tree.TraverseLeafs( leaftoleaftask4 );
3106 tree.TraverseUp( nodetoskeltask );
3107 tree.TraverseUnOrdered( skeltoskeltask );
3108 tree.TraverseDown( skeltonodetask );
3109 tree.ExecuteAllTasks();
/** With CUDA: wait on stream ids 0..9, then fetch the result; ``device'' is declared outside the visible lines. */
3113 double d2h_beg_t = omp_get_wtime();
3114 #ifdef HMLP_USE_CUDA 3116 for (
int stream_id = 0; stream_id < 10; stream_id ++ )
3117 device->wait( stream_id );
3119 potentials.FetchD2H( device );
3121 double d2h_t = omp_get_wtime() - d2h_beg_t;
3122 printf(
"d2h_t %lfs\n", d2h_t );
/** Reduce the u_leaf[ 1 ] .. u_leaf[ 19 ] buffers into u_leaf[ 0 ] for every leaf. */
3125 double aggregate_beg_t = omp_get_wtime();
3127 #pragma omp parallel for 3128 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
3130 auto *node = *(level_beg + node_ind);
3131 auto &u_leaf = node->data.u_leaf[ 0 ];
3133 for (
size_t p = 1; p < 20; p ++ )
3135 for (
size_t i = 0; i < node->data.u_leaf[ p ].size(); i ++ )
3136 u_leaf[ i ] += node->data.u_leaf[ p ][ i ];
3139 double aggregate_t = omp_get_wtime() - aggregate_beg_t;
/** NOTE(review): the label says aggregate_t but the value printed is d2h_t -- looks like a copy/paste slip. */
3140 printf(
"aggregate_t %lfs\n", d2h_t );
3142 #ifdef HMLP_USE_CUDA 3145 computeall_time = omp_get_wtime() - beg;
/** Presumably the else branch of the IsSymmetric() test above. */
3150 printf(
"Non symmetric ComputeAll is not yet implemented\n" );
3157 if ( REPORT_EVALUATE_STATUS )
3159 printf(
"Backward permute ...\n" ); fflush( stdout );
/** Backward permutation: add each leaf's u_leaf[ 0 ] into the rows of ``potentials'' it owns. */
3161 beg = omp_get_wtime();
3162 #pragma omp parallel for 3163 for (
int node_ind = 0; node_ind < n_nodes; node_ind ++ )
3165 auto *node = *(level_beg + node_ind);
3166 auto &amap = node->gids;
3167 auto &u_leaf = node->data.u_leaf[ 0 ];
3177 for (
size_t j = 0; j < potentials.
col(); j ++ )
3178 for (
size_t i = 0; i < amap.size(); i ++ )
3179 potentials( amap[ i ], j ) += u_leaf( i, j );
3183 backward_permute_time = omp_get_wtime() - beg;
/** Total time and the factor that converts a phase time into a percentage. */
3185 evaluation_time += allocate_time;
3186 evaluation_time += forward_permute_time;
3187 evaluation_time += computeall_time;
3188 evaluation_time += backward_permute_time;
3189 time_ratio = 100 / evaluation_time;
3191 if ( REPORT_EVALUATE_STATUS )
3193 printf(
"========================================================\n");
3194 printf(
"GOFMM evaluation phase\n" );
3195 printf(
"========================================================\n");
3196 printf(
"Allocate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3197 allocate_time, allocate_time * time_ratio );
3198 printf(
"Forward permute ----------------------- %5.2lfs (%5.1lf%%)\n",
3199 forward_permute_time, forward_permute_time * time_ratio );
3200 printf(
"N2S, S2S, S2N, L2L -------------------- %5.2lfs (%5.1lf%%)\n",
3201 computeall_time, computeall_time * time_ratio );
3202 printf(
"Backward permute ---------------------- %5.2lfs (%5.1lf%%)\n",
3203 backward_permute_time, backward_permute_time * time_ratio );
3204 printf(
"========================================================\n");
3205 printf(
"Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
3206 evaluation_time, evaluation_time * time_ratio );
3207 printf(
"========================================================\n\n");
/** Clear task dependencies so the tree can be reused. */
3212 tree.DependencyCleanUp();
/**
 *  Neighbor search over a randomized tree (called as
 *  gofmm::FindNeighbors( K, rkdtsplitter, config ) elsewhere in this file).
 *  The signature and the declarations of ``rkdt'', ``n_iter'' and
 *  ``NEIGHBORStask'' are not visible in this listing.
 */
3221 template<
typename SPLITTER,
typename T,
typename SPDMATRIX>
3235 using NODE =
typename TREE::NODE;
/** Problem parameters read from the configuration. */
3237 DistanceMetric metric = config.MetricType();
3238 size_t n = config.ProblemSize();
3239 size_t k = config.NeighborSize();
/** Initial neighbor entry: largest representable distance, index n. */
3241 pair<T, size_t> init( numeric_limits<T>::max(), n );
/** The randomized tree is set up without a neighbor table (NULL). */
3244 rkdt.setup.FromConfiguration( config, K, splitter, NULL );
3245 return rkdt.AllNearestNeighbor( n_iter, k, n_iter, init, NEIGHBORStask );
/**
 *  Main compression routine: neighbor search, tree partitioning, near-list
 *  construction, skeletonization, far-list construction and caching,
 *  followed by a timing report. Most of the signature and the task object
 *  declarations are not visible in this listing; the body uses ``K'',
 *  ``NN'', ``splitter'', ``rkdtsplitter'' and ``config''.
 *
 *  The tree is allocated with ``new'' below; the return statement is not
 *  visible here, but callers in this file receive a tree pointer and
 *  delete it themselves.
 */
3252 template<
typename SPLITTER,
typename RKDTSPLITTER,
typename T,
typename SPDMATRIX>
3257 Data<pair<T, size_t>> &NN,
3259 RKDTSPLITTER rkdtsplitter,
/** Parameters read from the configuration. */
3264 DistanceMetric metric = config.MetricType();
3265 size_t n = config.ProblemSize();
3266 size_t m = config.LeafNodeSize();
3267 size_t k = config.NeighborSize();
3268 size_t s = config.MaximumRank();
3269 T stol = config.Tolerance();
3270 T budget = config.Budget();
3273 const bool NNPRUNE =
true;
3274 const bool CACHE =
true;
3281 using NODE =
typename TREE::NODE;
/** Timers for the individual phases. */
3285 double beg, omptask45_time, omptask_time, ref_time;
3286 double time_ratio, compress_time = 0.0, other_time = 0.0;
3287 double ann_time, tree_time, skel_time, mergefarnodes_time, cachefarnodes_time;
3288 double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
/** Neighbor search, skipped when NN already has n * k entries. */
3291 beg = omp_get_wtime();
3292 if ( NN.size() != n * k )
3294 NN = gofmm::FindNeighbors( K, rkdtsplitter, config );
3296 ann_time = omp_get_wtime() - beg;
/** Allocate the tree and bind configuration, matrix, splitter and neighbors. */
3300 auto *tree_ptr =
new TREE();
3301 auto &tree = *tree_ptr;
3302 tree.setup.FromConfiguration( config, K, splitter, &NN );
3305 if ( REPORT_COMPRESS_STATUS )
3307 printf(
"TreePartitioning ...\n" ); fflush( stdout );
3309 beg = omp_get_wtime();
3310 tree.TreePartition();
3311 tree_time = omp_get_wtime() - beg;
/** NOTE(review): hard-coded thread / worker counts; presumably inside a platform-specific #ifdef that is not visible here -- confirm. */
3316 assert( omp_get_max_threads() == 68 );
3319 hmlp_set_num_workers( 17 );
3325 if ( REPORT_COMPRESS_STATUS )
3327 printf(
"omp_get_max_threads() %d\n", omp_get_max_threads() );
/** Build the near lists on the leaves, then symmetrize them. */
3336 tree.DependencyCleanUp();
3337 printf(
"Dependency clean up\n" ); fflush( stdout );
3338 tree.TraverseLeafs( NEARSAMPLEStask );
3339 tree.ExecuteAllTasks();
3341 printf(
"Finish NearSamplesTask\n" ); fflush( stdout );
3342 SymmetrizeNearInteractions( tree );
3343 printf(
"Finish SymmetrizeNearInteractions\n" ); fflush( stdout );
3348 if ( REPORT_COMPRESS_STATUS )
3350 printf(
"Skeletonization (HMLP Runtime) ...\n" ); fflush( stdout );
/** Skeletonization tasks; the call that executes them is not visible in this listing. */
3352 beg = omp_get_wtime();
3356 tree.DependencyCleanUp();
3357 tree.TraverseUp( GETMTXtask, SKELtask );
3358 tree.TraverseUnOrdered( PROJtask );
3362 tree.template TraverseLeafs( KIJtask );
3364 other_time += omp_get_wtime() - beg;
3366 skel_time = omp_get_wtime() - beg;
/** Far-interaction lists. */
3372 beg = omp_get_wtime();
3373 if ( REPORT_COMPRESS_STATUS )
3375 printf(
"MergeFarNodes ...\n" ); fflush( stdout );
3377 gofmm::MergeFarNodes( tree );
3378 mergefarnodes_time = omp_get_wtime() - beg;
/** Cache the far blocks of K. */
3381 beg = omp_get_wtime();
3382 if ( REPORT_COMPRESS_STATUS )
3384 printf(
"CacheFarNodes ...\n" ); fflush( stdout );
3386 gofmm::CacheFarNodes<NNPRUNE, CACHE>( tree );
3387 cachefarnodes_time = omp_get_wtime() - beg;
/** Writes "interaction.m" and yields the directly evaluated fraction of the matrix. */
3390 auto exact_ratio = hmlp::gofmm::DrawInteraction<true>( tree );
/** Total time and the factor that converts a phase time into a percentage. */
3392 compress_time += ann_time;
3393 compress_time += tree_time;
3394 compress_time += skel_time;
3395 compress_time += mergefarnodes_time;
3396 compress_time += cachefarnodes_time;
3397 time_ratio = 100.0 / compress_time;
3398 if ( REPORT_COMPRESS_STATUS )
3400 printf(
"========================================================\n");
3401 printf(
"GOFMM compression phase\n" );
3402 printf(
"========================================================\n");
3403 printf(
"NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
3404 printf(
"TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
3405 printf(
"Skeletonization ----------------------- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
3406 printf(
"MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
3407 printf(
"CacheFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
3408 printf(
"========================================================\n");
3409 printf(
"Compress (%4.2lf not compressed) -------- %5.2lfs (%5.1lf%%)\n",
3410 exact_ratio, compress_time, compress_time * time_ratio );
3411 printf(
"========================================================\n\n");
/** Clear task dependencies before the tree is handed back. */
3415 tree_ptr->DependencyCleanUp();
/**
 *  Convenience overload of Compress() taking the tolerance, budget, leaf
 *  size m, neighbor count k and maximum rank s directly. The return type,
 *  the SPLITTER / RKDTSPLITTER aliases and the construction of the
 *  configuration are not visible in this listing.
 *
 *  Both splitters are set to ANGLE_DISTANCE before forwarding to the
 *  general Compress<SPLITTER, RKDTSPLITTER>().
 */
3465 template<
typename T,
typename SPDMATRIX>
3469 *Compress( SPDMATRIX &K, T stol, T budget,
size_t m,
size_t k,
size_t s )
/** Splitter used for the main tree. */
3475 SPLITTER splitter( K );
3477 splitter.metric = ANGLE_DISTANCE;
/** Splitter used for the randomized trees of the neighbor search. */
3479 RKDTSPLITTER rkdtsplitter( K );
3480 rkdtsplitter.Kptr = &K;
3481 rkdtsplitter.metric = ANGLE_DISTANCE;
3488 return Compress<SPLITTER, RKDTSPLITTER>
3490 splitter, rkdtsplitter,
/**
 *  Convenience overload of Compress() taking only the matrix, tolerance
 *  and budget. The remaining parameters (and the ``config'' object passed
 *  on below) are chosen in lines that are not visible in this listing.
 *
 *  Both splitters are set to ANGLE_DISTANCE before forwarding to the
 *  general Compress<SPLITTER, RKDTSPLITTER>().
 */
3504 template<
typename T,
typename SPDMATRIX>
3506 gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
3508 *Compress( SPDMATRIX &K, T stol, T budget )
/** Splitter used for the main tree. */
3514 SPLITTER splitter( K );
3516 splitter.metric = ANGLE_DISTANCE;
/** Splitter used for the randomized trees of the neighbor search. */
3518 RKDTSPLITTER rkdtsplitter( K );
3519 rkdtsplitter.Kptr = &K;
3520 rkdtsplitter.metric = ANGLE_DISTANCE;
3552 return Compress<SPLITTER, RKDTSPLITTER>
3554 splitter, rkdtsplitter, config );
/**
 *  Overload for SPDMatrix<T>: forwards to Compress<T, SPDMatrix<T>>() with
 *  the same matrix, tolerance and budget. The signature is not visible in
 *  this listing.
 */
3561 template<
typename T>
3567 return Compress<T, SPDMatrix<T>>( K, stol, budget );
/**
 *  @brief Print the relative error of the potentials of one node against a
 *         direct evaluation with the rows of K owned by that node.
 *
 *  NOTE(review): ``w'' is bound to node->setup->w without a dereference,
 *  whereas other functions in this file use *node->setup->w; the w.row() /
 *  w.col() calls below would then be applied to a pointer -- verify. The
 *  name of the multiply routine and one of its argument lines are not
 *  visible in this listing.
 *
 *  @param node        node whose gids select the rows of K.
 *  @param potentials  approximate potentials (taken by value and overwritten
 *                     with the residual).
 */
3576 template<
typename NODE,
typename T>
3577 void ComputeError( NODE *node,
Data<T> potentials )
3579 auto &K = *node->setup->K;
3580 auto &w = node->setup->w;
3582 auto &amap = node->gids;
/** Column map covering every column of K. */
3583 std::vector<size_t> bmap = std::vector<size_t>( K.
col() );
3585 for (
size_t j = 0; j < bmap.size(); j ++ ) bmap[ j ] = j;
3587 auto Kab = K( amap, bmap );
/** Norm of the approximate potentials, used as the denominator below. */
3589 auto nrm2 = hmlp_norm( potentials.
row(), potentials.
col(),
3590 potentials.data(), potentials.
row() );
/** Arguments of a multiply that subtracts the exact product (alpha = -1.0, beta = 1.0). */
3595 Kab.row(), w.row(), w.col(),
3596 -1.0, Kab.data(), Kab.row(),
3598 1.0, potentials.data(), potentials.
row()
/** Norm of the residual. */
3601 auto err = hmlp_norm( potentials.
row(), potentials.
col(),
3602 potentials.data(), potentials.
row() );
3604 printf(
"node relative error %E, nrm2 %E\n", err / nrm2, nrm2 );
/**
 *  @brief Compare the approximate potentials of one index ``gid'' with a
 *         direct evaluation using the full row K( gid, : ).
 *
 *  NOTE(review): the name of the multiply routine, one of its argument
 *  lines and the return statement are not visible in this listing; the
 *  function presumably returns err / nrm2 -- confirm.
 *
 *  @param tree        tree providing setup.K and setup.w.
 *  @param gid         row index to test.
 *  @param potentials  approximate potentials (taken by value and overwritten
 *                     with the residual).
 */
3619 template<
typename TREE,
typename T>
3620 T ComputeError( TREE &tree,
size_t gid,
Data<T> potentials )
3622 auto &K = *tree.setup.K;
3623 auto &w = *tree.setup.w;
/** Single row gid against every column of K. */
3625 auto amap = std::vector<size_t>( 1, gid );
3626 auto bmap = std::vector<size_t>( K.
col() );
3627 for (
size_t j = 0; j < bmap.size(); j ++ ) bmap[ j ] = j;
3629 auto Kab = K( amap, bmap );
/** ``exact'' starts as a copy of potentials and is overwritten (beta = 0.0). */
3630 auto exact = potentials;
3635 Kab.row(), w.col(), w.row(),
3636 1.0, Kab.data(), Kab.row(),
3638 0.0, exact.data(), exact.row()
/** Norm of the exact result. */
3642 auto nrm2 = hmlp_norm( exact.row(), exact.col(),
3643 exact.data(), exact.row() );
/** Subtract the exact product from the approximation (alpha = -1.0, beta = 1.0). */
3648 Kab.row(), w.col(), w.row(),
3649 -1.0, Kab.data(), Kab.row(),
3651 1.0, potentials.data(), potentials.
row()
/** Norm of the residual. */
3654 auto err = hmlp_norm( potentials.
row(), potentials.
col(),
3655 potentials.data(), potentials.
row() );
/**
 *  @brief Accuracy self test: evaluate random weights with the compressed
 *         tree and report per-index and average errors for ``ntest''
 *         sampled indices.
 *
 *  For each tested index three errors are printed under the labels ASKIT
 *  (Evaluate<false, true>), HODLR (Evaluate<false, false>) and GOFMM (the
 *  row of the full evaluation u).
 *
 *  NOTE(review): the declarations of n, potentials, nnerr_avg, fmmerr_avg
 *  and lambda, and the accumulation of nnerr_avg, are not visible in this
 *  listing.
 *
 *  @param tree   compressed tree.
 *  @param ntest  number of indices to test (clamped to n).
 *  @param nrhs   number of right-hand sides.
 */
3662 template<
typename TREE>
3663 void SelfTesting( TREE &tree,
size_t ntest,
size_t nrhs )
3666 using T =
typename TREE::T;
3670 if ( ntest > n ) ntest = n;
/** Index list 0 .. nrhs-1; not used in the visible lines. */
3672 vector<size_t> all_rhs( nrhs );
3673 for (
size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;
/** Random weights and the full evaluation. */
3678 Data<T> w( n, nrhs ); w.rand();
3679 auto u = Evaluate<true, false, true, true>( tree, w );
3683 T nonnerr_avg = 0.0;
3685 printf(
"========================================================\n");
3686 printf(
"Accuracy report\n" );
3687 printf(
"========================================================\n");
3688 for (
size_t i = 0; i < ntest; i ++ )
/** Tested indices are spread evenly over [ 0, n ). */
3690 size_t tar = i * n / ntest;
/** Single-index evaluation with neighbor pruning. */
3693 Evaluate<false, true>( tree, tar, potentials );
3694 auto nnerr = ComputeError( tree, tar, potentials );
/** Single-index evaluation without neighbor pruning. */
3696 Evaluate<false, false>( tree, tar, potentials );
3697 auto nonnerr = ComputeError( tree, tar, potentials );
/** Row ``tar'' of the full evaluation. */
3700 for (
size_t p = 0; p < potentials.
col(); p ++ )
3702 potentials[ p ] = u( tar, p );
3704 auto fmmerr = ComputeError( tree, tar, potentials );
3709 printf(
"gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
3710 tar, nnerr, nonnerr, fmmerr );
3713 nonnerr_avg += nonnerr;
3714 fmmerr_avg += fmmerr;
3716 printf(
"========================================================\n");
3717 printf(
" ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
3718 nnerr_avg / ntest , nonnerr_avg / ntest, fmmerr_avg / ntest );
3719 printf(
"========================================================\n");
/** When SecureAccuracy() is false, also factorize and check that error. */
3721 if ( !tree.setup.SecureAccuracy() )
3725 gofmm::Factorize( tree, lambda );
3727 gofmm::ComputeError( tree, lambda, w, u );
/**
 *  Driver that builds splitters and a configuration from the parsed
 *  command line ``cmd'', compresses K and runs the self test. The
 *  signature, the SPLITTER / RKDTSPLITTER aliases and the declarations of
 *  ``config'' and ``NN'' are not visible in this listing.
 *
 *  NOTE(review): the tree returned by Compress() is not deleted in the
 *  visible lines -- confirm whether it is released elsewhere.
 */
3734 template<
typename SPDMATRIX>
3737 using T =
typename SPDMATRIX::T;
3739 const int N_CHILDREN = 2;
/** Splitter for the main tree, using the metric given on the command line. */
3744 SPLITTER splitter( K );
3746 splitter.metric = cmd.
metric;
/** Splitter for the randomized trees of the neighbor search. */
3748 RKDTSPLITTER rkdtsplitter( K );
3749 rkdtsplitter.Kptr = &K;
3750 rkdtsplitter.metric = cmd.
metric;
/** Arguments forwarded to the configuration: n, m, k, s, stol, budget. */
3753 cmd.
n, cmd.m, cmd.k, cmd.s, cmd.
stol, cmd.budget );
3758 auto *tree_ptr = gofmm::Compress( K, NN, splitter, rkdtsplitter, config );
3759 auto &tree = *tree_ptr;
/** Test 100 indices with the requested number of right-hand sides. */
3761 gofmm::SelfTesting( tree, 100, cmd.nrhs );
/**
 *  Fragments of a small wrapper class that owns a compressed tree. The
 *  class header and member signatures are not visible in this listing
 *  (presumably constructor, destructor and a multiply method -- TODO
 *  confirm).
 */
3778 template<
typename T,
typename SPDMATRIX>
/** Compress K with the given tolerance and budget and keep the tree pointer. */
3785 tree_ptr = Compress( K, stol, budget );
/** Release the owned tree if one was created. */
3790 if ( tree_ptr )
delete tree_ptr;
/** Evaluate the compressed operator on x. */
3802 y = gofmm::Evaluate( *tree_ptr, x );
/** Fragment of a tree type alias; the surrounding declaration is not visible here. */
3815 gofmm::Setup<SPDMATRIX, centersplit<SPDMATRIX, 2, T>, T>,
/**
 *  Non-template Compress() prototypes for double and single precision.
 *  dTree_t / sTree_t and dSPDMatrix_t / sSPDMatrix_t are declared outside
 *  the visible lines. Each returns a pointer to a tree.
 */
/** Tolerance and budget only. */
3860 dTree_t *Compress( dSPDMatrix_t *K,
double stol,
double budget );
3861 sTree_t *Compress( sSPDMatrix_t *K,
float stol,
float budget );
/** Additional size_t parameters m, k and s. */
3863 dTree_t *Compress( dSPDMatrix_t *K,
double stol,
double budget,
3864 size_t m,
size_t k,
size_t s );
3865 sTree_t *Compress( sSPDMatrix_t *K,
float stol,
float budget,
3866 size_t m,
size_t k,
size_t s );
deque< Statistic > s2s_kij_t
Definition: gofmm.hpp:508
Configuration contains all user-defined parameters.
Definition: gofmm.hpp:212
void Set(NODE *user_arg)
Definition: gofmm.hpp:1271
Data< T > w_leaf
Definition: gofmm.hpp:382
deque< Statistic > s2n_kij_t
Definition: gofmm.hpp:513
Definition: gofmm_gpu.hpp:497
Event skeletonize
Definition: gofmm.hpp:396
NodeData()
Definition: gofmm.hpp:350
void FromConfiguration(Configuration< T > &config, SPDMATRIX &K, SPLITTER &splitter, Data< pair< T, size_t >> *NN)
Definition: gofmm.hpp:309
void Set(NODE *user_arg)
Definition: gofmm.hpp:1476
void GetEventRecord()
Definition: gofmm.hpp:2389
Definition: gofmm.hpp:3779
vector< size_t > candidate_rows
Definition: gofmm.hpp:371
Definition: gofmm.hpp:1265
Definition: igofmm.hpp:70
Data< T > w_skel
Definition: gofmm.hpp:378
void Set(NODE *user_arg)
Definition: gofmm.hpp:747
This task creates an hierarchical tree view for w<RIDS> and u<RIDS>.
Definition: gofmm.hpp:414
void Set(NODE *user_arg)
Definition: gofmm.hpp:2111
vector< size_t > skels
Definition: gofmm.hpp:359
size_t n
Definition: tree.hpp:971
Definition: gofmm.hpp:1867
size_t row()
Definition: View.hpp:345
map< size_t, T > snids
Definition: gofmm.hpp:368
void GetEventRecord()
Definition: gofmm.hpp:1282
CommandLineHelper(int argc, char *argv[])
Definition: gofmm.hpp:94
void xgemm(const char *transA, const char *transB, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
DGEMM wrapper.
Definition: blas_lapack.cpp:130
void DependencyAnalysis()
Definition: gofmm.hpp:2178
Data and setup that are shared with all nodes.
Definition: tree.hpp:876
void Set(NODE *user_arg)
Definition: gofmm.hpp:1174
Definition: gofmm.hpp:1168
size_t col()
Definition: View.hpp:348
deque< Statistic > updateweight
Definition: gofmm.hpp:505
View< T > w_view
Definition: gofmm.hpp:386
T * data()
Definition: View.hpp:354
class Device * GetDevice()
Definition: thread.cpp:554
Data< T > KIJ
Definition: gofmm.hpp:375
void DependencyAnalysis()
Definition: gofmm.hpp:429
void Set(NODE *user_arg)
Definition: gofmm.hpp:1693
This class does not need to inherit hmlp::Data<T>, but it should support two interfaces for data fetc...
Definition: SPDMatrix.hpp:21
This is the splitter used in the randomized tree.
Definition: gofmm.hpp:658
Data< size_t > Nearbmap
Definition: gofmm.hpp:390
void Set(NODE *user_arg)
Definition: gofmm.hpp:2380
void Set(NODE *user_arg)
Definition: gofmm.hpp:2326
size_t col() const noexcept
Definition: Data.hpp:281
void GetEventRecord()
Definition: gofmm.hpp:2172
Wrapper for omp or pthread mutex.
Definition: tci.hpp:50
Provide statistics summary for the execution section.
Definition: gofmm.hpp:493
void HeapSelect(size_t n, size_t k, std::pair< T, size_t > *Query, std::pair< T, size_t > *NN)
Definition: util.hpp:520
void MergeNeighbors(size_t k, pair< T, size_t > *A, pair< T, size_t > *B, vector< pair< T, size_t >> &aux)
Definition: tree.hpp:212
Task wrapper for CacheNearNodes().
Definition: gofmm.hpp:2374
size_t row() const noexcept
Definition: Data.hpp:278
These are data that shared by the whole local tree.
Definition: gofmm.hpp:301
Data< T > proj
Definition: gofmm.hpp:365
void Set(bool TRANS, Data< T > &buff)
Definition: View.hpp:60
void Set(NODE *user_arg)
Definition: gofmm.hpp:1873
Definition: gofmm.hpp:2101
Definition: gofmm.hpp:738
size_t ld()
Definition: View.hpp:351
There is no dependency between each task. However there are raw (read after write) dependencies: ...
Definition: gofmm.hpp:1687
The corresponding task of Interpolate().
Definition: gofmm.hpp:945
DistanceMetric metric
Definition: gofmm.hpp:193
This is a helper class that parses the arguments from command lines.
Definition: gofmm.hpp:89
This class describes devices or accelerators that require a master thread to control. A device can accept tasks from multiple workers. All received tasks are expected to be executed independently in a time-sharing fashion. Whether these tasks are executed in parallel, sequential or with some built-in context switching scheme does not matter.
Definition: device.hpp:125
void Execute(Worker *user_worker)
Definition: gofmm.hpp:431
size_t n
Definition: gofmm.hpp:185
vector< int > jpvt
Definition: gofmm.hpp:362
void xtrsm(const char *side, const char *uplo, const char *transA, const char *diag, int m, int n, double alpha, double *A, int lda, double *B, int ldb)
DTRSM wrapper.
Definition: blas_lapack.cpp:315
Definition: gofmm.hpp:1470
double stol
Definition: gofmm.hpp:190
size_t d
Definition: gofmm.hpp:196
This is the main splitter used to build the Spd-Askit tree. First compute the approximate center using s...
Definition: gofmm.hpp:580
Wrapper for omp or pthread mutex.
Definition: runtime.hpp:113
This class contains all GOFMM related data carried by a tree node.
Definition: gofmm.hpp:345
Definition: runtime.hpp:174
size_t col()
Definition: VirtualMatrix.hpp:85
Definition: gofmm.hpp:2320
Definition: thread.hpp:166