#include <tree_mpi.hpp>
#include <igofmm_mpi.hpp>

template<typename SPDMATRIX, typename SPLITTER, typename T>
/* ... */
    SPDMATRIX &K, SPLITTER &splitter,
    DistData<STAR, CBLK, pair<T, size_t>>* NN_cblk )
  /* ... */
  this->CopyFrom( config );
  /* ... */
  this->splitter = splitter;
  this->NN_cblk = NN_cblk;
  /* ... */
  bool do_ulv_factorization = true;
template<typename NODE>
/* ... */
  void Set( NODE *user_arg )
    /* ... */
    name = string( "TreeView" );
    label = to_string( arg->treelist_id );
  /* ... */
    auto &w = *(node->setup->w);
    auto &u = *(node->setup->u);
    /* ... */
    auto &U = node->data.u_view;
    auto &W = node->data.w_view;
    /* ... */
    if ( !node->isleaf && !node->child )
      /* ... */
      assert( node->lchild && node->rchild );
      auto &UL = node->lchild->data.u_view;
      auto &UR = node->rchild->data.u_view;
      auto &WL = node->lchild->data.w_view;
      auto &WR = node->rchild->data.w_view;
      /* ... */
          UR, node->lchild->n, TOP );
      /* ... */
          WR, node->lchild->n, TOP );
vector<vector<size_t>> DistMedianSplit( vector<T> &values, mpi::Comm comm )
  /* ... */
  int num_points_owned = values.size();
  /* ... */
  mpi::Allreduce( &num_points_owned, &n, 1, MPI_SUM, comm );
  T median = combinatorics::Select( n / 2, values, comm );
  /* ... */
  vector<vector<size_t>> split( 2 );
  vector<size_t> middle;
  /* ... */
  if ( n == 0 ) return split;
  /* ... */
  for ( size_t i = 0; i < values.size(); i ++ )
  {
    auto v = values[ i ];
    if ( std::fabs( v - median ) < 1E-6 ) middle.push_back( i );
    else if ( v < median ) split[ 0 ].push_back( i );
    else split[ 1 ].push_back( i );
  }
  /* ... */
  int num_mid_owned = middle.size();
  int num_lhs_owned = split[ 0 ].size();
  int num_rhs_owned = split[ 1 ].size();
  /* ... */
  mpi::Allreduce( &num_mid_owned, &nmid, 1, MPI_SUM, comm );
  mpi::Allreduce( &num_lhs_owned, &nlhs, 1, MPI_SUM, comm );
  mpi::Allreduce( &num_rhs_owned, &nrhs, 1, MPI_SUM, comm );
  /* ... */
  int nlhs_required, nrhs_required;
  /* ... */
    nlhs_required = ( n - 1 ) / 2 + 1 - nlhs;
    nrhs_required = nmid - nlhs_required;
  /* ... */
    nrhs_required = ( n - 1 ) / 2 + 1 - nrhs;
    nlhs_required = nmid - nrhs_required;
  /* ... */
  assert( nlhs_required >= 0 && nrhs_required >= 0 );
  /* ... */
  double lhs_ratio = ( (double)nlhs_required ) / nmid;
  int nlhs_required_owned = num_mid_owned * lhs_ratio;
  int nrhs_required_owned = num_mid_owned - nlhs_required_owned;
  /* ... */
  assert( nlhs_required_owned >= 0 && nrhs_required_owned >= 0 );
  /* ... */
  for ( size_t i = 0; i < middle.size(); i ++ )
    if ( i < nlhs_required_owned )
      split[ 0 ].push_back( middle[ i ] );
    else
      split[ 1 ].push_back( middle[ i ] );
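/*
 * [Editor's note / illustrative sketch] DistMedianSplit selects the global
 * median with combinatorics::Select() and then distributes the "middle"
 * bucket (values within 1E-6 of the median) so that the two partitions end
 * up globally balanced. A minimal single-process sketch of that balancing
 * step, assuming plain std containers and no MPI (the function and variable
 * names here are hypothetical, not part of this library):
 *
 *   #include <cmath>
 *   #include <vector>
 *   std::vector<std::vector<size_t>> median_split(
 *       const std::vector<double>& values, double median, size_t lhs_target_size )
 *   {
 *     std::vector<std::vector<size_t>> split( 2 );
 *     std::vector<size_t> middle;
 *     for ( size_t i = 0; i < values.size(); i ++ )
 *     {
 *       if ( std::fabs( values[ i ] - median ) < 1E-6 ) middle.push_back( i );
 *       else if ( values[ i ] < median ) split[ 0 ].push_back( i );
 *       else split[ 1 ].push_back( i );
 *     }
 *     // Assign just enough "middle" indices to the left partition to reach
 *     // the requested size; the remainder goes to the right partition.
 *     for ( size_t i = 0; i < middle.size(); i ++ )
 *     {
 *       if ( split[ 0 ].size() < lhs_target_size ) split[ 0 ].push_back( middle[ i ] );
 *       else split[ 1 ].push_back( middle[ i ] );
 *     }
 *     return split;
 *   }
 */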
template<typename SPDMATRIX, int N_SPLIT, typename T>
/* ... */
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
  /* ... */

  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
    /* ... */
    assert( N_SPLIT == 2 );
    assert( this->Kptr );
    /* ... */
    int size; mpi::Comm_size( comm, &size );
    int rank; mpi::Comm_rank( comm, &rank );
    auto &K = *(this->Kptr);
    /* ... */
    vector<T> temp( gids.size(), 0.0 );

    /** Sample a subset of columns to approximate the centroid. */
    auto column_samples = combinatorics::SampleWithoutReplacement(
        this->n_centroid_samples, gids );
    /* ... */
    mpi::Bcast( column_samples.data(), column_samples.size(), 0, comm );
    K.BcastIndices( column_samples, 0, comm );

    /** Distances from all local gids to the sampled columns. */
    auto DIC = K.Distances( this->metric, gids, column_samples );
    /* ... */
    for ( auto & it : temp ) it = 0;
    /* ... */
    for ( size_t j = 0; j < DIC.col(); j ++ )
      for ( size_t i = 0; i < DIC.row(); i ++ )
        temp[ i ] += DIC( i, j );

    /** Local point farthest from the (sampled) centroid. */
    auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );
    /* ... */
    local_max_pair.val = temp[ idf2c ];
    local_max_pair.key = rank;
    /* ... */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
    /* ... */
    int gidf2c = gids[ idf2c ];
    mpi::Bcast( &gidf2c, 1, MPI_INT, max_pair.key, comm );
    /* ... */
    vector<size_t> P( 1, gidf2c );
    K.BcastIndices( P, max_pair.key, comm );

    /** Distances to pivot P; find the point farthest from P. */
    auto DIP = K.Distances( this->metric, gids, P );
    /* ... */
    auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );
    /* ... */
    local_max_pair.val = DIP[ idf2f ];
    local_max_pair.key = rank;
    /* ... */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
    /* ... */
    int gidf2f = gids[ idf2f ];
    mpi::Bcast( &gidf2f, 1, MPI_INT, max_pair.key, comm );
    /* ... */
    vector<size_t> Q( 1, gidf2f );
    K.BcastIndices( Q, max_pair.key, comm );

    /** Distances to the second pivot Q (the original text read "P" here,
     *  which would only recompute DIP; Q is clearly intended). */
    auto DIQ = K.Distances( this->metric, gids, Q );
    /* ... */
    for ( size_t i = 0; i < temp.size(); i ++ )
      temp[ i ] = DIP[ i ] - DIQ[ i ];
    /* ... */
    auto split = DistMedianSplit( temp, comm );

    /** Exchange gids (and the corresponding matrix indices) with the partner rank. */
    vector<size_t> sent_gids;
    int partner = ( rank + size / 2 ) % size;
    if ( rank < size / 2 )
    {
      for ( auto it : split[ 1 ] ) sent_gids.push_back( gids[ it ] );
      K.SendIndices( sent_gids, partner, comm );
      K.RecvIndices( partner, comm, &status );
    }
    else
    {
      for ( auto it : split[ 0 ] ) sent_gids.push_back( gids[ it ] );
      K.RecvIndices( partner, comm, &status );
      K.SendIndices( sent_gids, partner, comm );
    }
template<typename SPDMATRIX, int N_SPLIT, typename T>
/* ... */
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
  /* ... */

  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
    /* ... */
    assert( N_SPLIT == 2 );
    assert( this->Kptr );
    /* ... */
    int size, rank, global_rank, global_size;
    mpi::Comm_size( comm, &size );
    mpi::Comm_rank( comm, &rank );
    mpi::Comm_rank( MPI_COMM_WORLD, &global_rank );
    mpi::Comm_size( MPI_COMM_WORLD, &global_size );
    SPDMATRIX &K = *(this->Kptr);
    /* ... */
    if ( size == global_size )
      /* ... */
      for ( size_t i = 0; i < gids.size(); i ++ )
        assert( gids[ i ] == i * size + rank );
    /* ... */
    int num_points_owned = gids.size();
    vector<T> temp( gids.size(), 0.0 );
    /* ... */
    mpi::Allreduce( &num_points_owned, &n, 1, MPI_INT, MPI_SUM, comm );
    /* ... */
    size_t gidf2c, gidf2f;
    /* ... */
    gidf2c = gids[ std::rand() % gids.size() ];
    gidf2f = gids[ std::rand() % gids.size() ];
    /* ... */
    local_max_pair.val = gids.size();
    local_max_pair.key = rank;
    /* ... */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
    /* ... */
    mpi::Bcast( &gidf2c, 1, max_pair.key, comm );
    vector<size_t> P( 1, gidf2c );
    K.BcastIndices( P, max_pair.key, comm );
    /* ... */
    if ( rank == max_pair.key ) local_max_pair.val = 0;
    /* ... */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );
    /* ... */
    mpi::Bcast( &gidf2f, 1, max_pair.key, comm );
    vector<size_t> Q( 1, gidf2f );
    K.BcastIndices( Q, max_pair.key, comm );
    /* ... */
    auto DIP = K.Distances( this->metric, gids, P );
    auto DIQ = K.Distances( this->metric, gids, Q );
    /* ... */
    for ( size_t i = 0; i < temp.size(); i ++ )
      temp[ i ] = DIP[ i ] - DIQ[ i ];
    /* ... */
    auto split = DistMedianSplit( temp, comm );
    /* ... */
    vector<size_t> sent_gids;
    int partner = ( rank + size / 2 ) % size;
    if ( rank < size / 2 )
    {
      for ( auto it : split[ 1 ] ) sent_gids.push_back( gids[ it ] );
      K.SendIndices( sent_gids, partner, comm );
      K.RecvIndices( partner, comm, &status );
    }
    else
    {
      for ( auto it : split[ 0 ] ) sent_gids.push_back( gids[ it ] );
      K.RecvIndices( partner, comm, &status );
      K.SendIndices( sent_gids, partner, comm );
    }
template<typename NODE>
void DistUpdateWeights( NODE *node )
  /* ... */
  using T = typename NODE::T;
  /* ... */
  auto comm = node->GetComm();
  int size = node->GetCommSize();
  int rank = node->GetCommRank();
  /* ... */
  if ( !node->parent || !node->data.isskel ) return;
  /* ... */
    gofmm::UpdateWeights( node );
  /* ... */
  auto &w = *node->setup->w;
  size_t nrhs = w.col();
  /* ... */
  auto &data = node->data;
  auto &proj = data.proj;
  auto &w_skel = data.w_skel;
  /* ... */
    size_t s = proj.row();
    size_t sl = node->child->data.skels.size();
    size_t sr = proj.col() - sl;
    /* ... */
    w_skel.resize( s, nrhs );
    /* ... */
    View<T> P( false, proj ), PL, PR;
    View<T> W( false, w_skel ), WL( false, node->child->data.w_skel );
    /* ... */
    P.Partition1x2( PL, PR, sl, LEFT );
    /* ... */
    gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );
    /* ... */
    mpi::ExchangeVector( w_skel, size / 2, 0, w_skel_sib, size / 2, 0, comm, &status );
    /* ... */
    #pragma omp parallel for
    for ( size_t i = 0; i < w_skel.size(); i ++ )
      w_skel[ i ] += w_skel_sib[ i ];
  /* ... */
  if ( rank == size / 2 )
    /* ... */
    size_t s = proj.row();
    size_t sr = node->child->data.skels.size();
    size_t sl = proj.col() - sr;
    /* ... */
    w_skel.resize( s, nrhs );
    /* ... */
    View<T> P( false, proj ), PL, PR;
    View<T> W( false, w_skel ), WR( false, node->child->data.w_skel );
    /* ... */
    P.Partition1x2( PL, PR, sl, LEFT );
    /* ... */
    gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)0.0, W );
    /* ... */
    mpi::ExchangeVector( w_skel, 0, 0, w_skel_sib, 0, 0, comm, &status );
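/*
 * [Editor's note] In the distributed upward pass above, the parent's skeleton
 * weights are the interpolative projection of its children's skeleton weights.
 * With proj partitioned as P = [ P_L  P_R ] (views PL, PR), rank 0 of the
 * node communicator computes P_L * w_skel(left), rank size/2 computes
 * P_R * w_skel(right), and the two partial results are exchanged via
 * mpi::ExchangeVector and summed, so that
 *
 *     w_skel(parent) = P_L * w_skel(left) + P_R * w_skel(right)
 *
 * (notation follows the views PL, PR, WL, WR used in the code).
 */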
template<typename NODE, typename T>
/* ... */
  void Set( NODE *user_arg )
    /* ... */
    name = string( "DistN2S" );
    label = to_string( arg->treelist_id );
    /* ... */
    double flops = 0.0, mops = 0.0;
    auto &gids = arg->gids;
    auto &skels = arg->data.skels;
    auto &w = *arg->setup->w;
    /* ... */
      auto m = skels.size();
      /* ... */
      auto k = gids.size();
      flops = 2.0 * m * n * k;
      mops = 2.0 * ( m * n + m * k + k * n );
    /* ... */
      auto &lskels = arg->lchild->data.skels;
      auto &rskels = arg->rchild->data.skels;
      auto m = skels.size();
      /* ... */
      auto k = lskels.size() + rskels.size();
      flops = 2.0 * m * n * k;
      mops = 2.0 * ( m * n + m * k + k * n );
    /* ... */
      if ( arg->GetCommRank() == 0 )
        /* ... */
        auto &lskels = arg->child->data.skels;
        auto m = skels.size();
        /* ... */
        auto k = lskels.size();
        flops = 2.0 * m * n * k;
        mops = 2.0 * ( m * n + m * k + k * n );
      /* ... */
      if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
        /* ... */
        auto &rskels = arg->child->data.skels;
        auto m = skels.size();
        /* ... */
        auto k = rskels.size();
        flops = 2.0 * m * n * k;
        mops = 2.0 * ( m * n + m * k + k * n );
    /* ... */
    event.Set( label + name, flops, mops );
  /* ... */
  void DependencyAnalysis() { arg->DependOnChildren( this ); };
  /* ... */
  void Execute( Worker* user_worker ) { DistUpdateWeights( arg ); };

template<typename NODE, typename LETNODE, typename T>
/* ... */
  vector<LETNODE*> Sources;
  /* ... */
  int *num_arrived_subtasks;
  /* ... */
  void Set( NODE *user_arg, vector<LETNODE*> user_src, int user_p, Lock *user_lock,
      int *user_num_arrived_subtasks )
    /* ... */
    num_arrived_subtasks = user_num_arrived_subtasks;
    name = string( "S2S" );
    label = to_string( arg->treelist_id );
    /* ... */
    double flops = 0.0, mops = 0.0;
    size_t nrhs = arg->setup->w->col();
    size_t m = arg->data.skels.size();
    for ( auto src : Sources )
      /* ... */
      size_t k = src->data.skels.size();
      flops += 2 * m * k * nrhs;
      mops += 2 * ( m * k + ( m + k ) * nrhs );
      flops += 2 * m * nrhs;
      flops += m * k * ( 2 * 18 + 100 );
    /* ... */
    event.Set( label + name, flops, mops );
    /* ... */
    if ( arg->treelist_id == 0 ) priority = true;
  /* ... */
  void DependencyAnalysis()
    /* ... */
    if ( p == hmlp_get_mpi_rank() )
      /* ... */
      for ( auto src : Sources ) src->DependencyAnalysis( R, this );
  /* ... */
    if ( !node->parent || !node->data.isskel ) return;
    size_t nrhs = node->setup->w->col();
    auto &K = *node->setup->K;
    auto &I = node->data.skels;
    /* ... */
    Data<T> u( I.size(), nrhs, 0.0 );
    /* ... */
    for ( auto src : Sources )
      /* ... */
      auto &J = src->data.skels;
      auto &w = src->data.w_skel;
      bool is_cached = true;
      /* ... */
      auto &KIJ = node->DistFar[ p ][ src->morton ];
      if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
        /* ... */
      assert( w.col() == nrhs );
      assert( w.row() == J.size() );
      /* ... */
      gemm::xgemm( (T)1.0, KIJ, w, (T)1.0, u );
      /* ... */
        KIJ.shrink_to_fit();
    /* ... */
    auto &u_skel = node->data.u_skel;
    for ( int i = 0; i < u.size(); i ++ )
      u_skel[ i ] += u[ i ];
    /* ... */
    #pragma omp atomic update
    *num_arrived_subtasks += 1;
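/*
 * [Editor's note] Each S2S subtask above accumulates the far-field (skeleton)
 * contribution for its target node I from a batch of LET sources J owned by
 * rank p:
 *
 *     u_skel(I) += sum over sources J of  K(I, J) * w_skel(J)
 *
 * which is what the gemm::xgemm( 1.0, KIJ, w, 1.0, u ) call computes per
 * source; the final loop then adds the batch-local buffer u into
 * node->data.u_skel (the task is handed a Lock for coordinating this),
 * and num_arrived_subtasks is bumped atomically so the reduce task can
 * verify that all batches completed.
 */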
template<typename NODE, typename LETNODE, typename T>
/* ... */
  vector<S2STask2<NODE, LETNODE, T>*> subtasks;
  /* ... */
  int num_arrived_subtasks = 0;
  /* ... */
  const size_t batch_size = 2;
  /* ... */
    name = string( "S2SR" );
    label = to_string( arg->treelist_id );
    /* ... */
    size_t nrhs = arg->setup->w->col();
    auto &I = arg->data.skels;
    arg->data.u_skel.resize( 0, 0 );
    arg->data.u_skel.resize( I.size(), nrhs, 0 );
    /* ... */
    for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
      /* ... */
      vector<LETNODE*> Sources;
      for ( auto &it : arg->DistFar[ p ] )
        /* ... */
        Sources.push_back( (*arg->morton2node)[ it.first ] );
        if ( Sources.size() == batch_size )
          /* ... */
          subtasks.back()->Submit();
          subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
          subtasks.back()->DependencyAnalysis();
      /* ... */
      if ( Sources.size() )
        /* ... */
        subtasks.back()->Submit();
        subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
        subtasks.back()->DependencyAnalysis();
    /* ... */
    double flops = 0, mops = 0;
    /* ... */
    event.Set( label + name, flops, mops );
  /* ... */
  void DependencyAnalysis()
    /* ... */
    for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
    arg->DependencyAnalysis( RW, this );
  /* ... */
    assert( num_arrived_subtasks == subtasks.size() );

template<bool NNPRUNE, typename NODE, typename T>
void DistSkeletonsToNodes( NODE *node )
  /* ... */
  auto comm = node->GetComm();
  auto size = node->GetCommSize();
  auto rank = node->GetCommRank();
  /* ... */
  auto &K = *node->setup->K;
  auto &w = *node->setup->w;
  /* ... */
  size_t nrhs = w.col();
  /* ... */
  if ( !node->parent || !node->data.isskel ) return;
  /* ... */
    gofmm::SkeletonsToNodes( node );
  /* ... */
  auto &data = node->data;
  auto &proj = data.proj;
  auto &u_skel = data.u_skel;
  /* ... */
    size_t sl = node->child->data.skels.size();
    size_t sr = proj.col() - sl;
    /* ... */
    mpi::SendVector( u_skel, size / 2, 0, comm );
    /* ... */
    View<T> P( true, proj ), PL, PR;
    View<T> U( false, u_skel ), UL( false, node->child->data.u_skel );
    /* ... */
    gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
  /* ... */
  if ( rank == size / 2 )
    /* ... */
    size_t s = proj.row();
    size_t sr = node->child->data.skels.size();
    size_t sl = proj.col() - sr;
    /* ... */
    mpi::RecvVector( u_skel, 0, 0, comm, &status );
    u_skel.resize( s, nrhs );
    /* ... */
    View<T> P( true, proj ), PL, PR;
    View<T> U( false, u_skel ), UR( false, node->child->data.u_skel );
    /* ... */
    gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );

template<bool NNPRUNE, typename NODE, typename T>
/* ... */
    name = string( "PS2N" );
    label = to_string( arg->l );
    /* ... */
    double flops = 0.0, mops = 0.0;
    auto &gids = arg->gids;
    auto &skels = arg->data.skels;
    auto &w = *arg->setup->w;
    /* ... */
      auto m = skels.size();
      /* ... */
      auto k = gids.size();
      flops = 2.0 * m * n * k;
      mops = 2.0 * ( m * n + m * k + k * n );
    /* ... */
      auto &lskels = arg->lchild->data.skels;
      auto &rskels = arg->rchild->data.skels;
      auto m = skels.size();
      /* ... */
      auto k = lskels.size() + rskels.size();
      flops = 2.0 * m * n * k;
      mops = 2.0 * ( m * n + m * k + k * n );
    /* ... */
      if ( arg->GetCommRank() == 0 )
        /* ... */
        auto &lskels = arg->child->data.skels;
        auto m = skels.size();
        /* ... */
        auto k = lskels.size();
        flops = 2.0 * m * n * k;
        mops = 2.0 * ( m * n + m * k + k * n );
      /* ... */
      if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
        /* ... */
        auto &rskels = arg->child->data.skels;
        auto m = skels.size();
        /* ... */
        auto k = rskels.size();
        flops = 2.0 * m * n * k;
        mops = 2.0 * ( m * n + m * k + k * n );
    /* ... */
    event.Set( label + name, flops, mops );
    /* ... */
    cost = flops / 1E+9;
  /* ... */
  void DependencyAnalysis() { arg->DependOnParent( this ); };
  /* ... */
  void Execute( Worker* user_worker ) { DistSkeletonsToNodes<NNPRUNE, NODE, T>( arg ); };

template<typename NODE, typename T>
/* ... */
  int *num_arrived_subtasks;
  /* ... */
  void Set( NODE *user_arg, vector<NODE*> user_src, int user_p, Lock *user_lock,
      int* user_num_arrived_subtasks )
    /* ... */
    num_arrived_subtasks = user_num_arrived_subtasks;
    name = string( "L2L" );
    label = to_string( arg->treelist_id );
    /* ... */
    double flops = 0.0, mops = 0.0;
    size_t nrhs = arg->setup->w->col();
    size_t m = arg->gids.size();
    for ( auto src : Sources )
      /* ... */
      size_t k = src->gids.size();
      flops += 2 * m * k * nrhs;
      mops += 2 * ( m * k + ( m + k ) * nrhs );
      flops += 2 * m * nrhs;
      flops += m * k * ( 2 * 18 + 100 );
    /* ... */
    event.Set( label + name, flops, mops );
    /* ... */
    cost = flops / 1E+9;
  /* ... */
    if ( p != hmlp_get_mpi_rank() )
      /* ... */
    size_t nrhs = node->setup->w->col();
    auto &K = *node->setup->K;
    auto &I = node->gids;
    /* ... */
    double beg = omp_get_wtime();
    /* ... */
    Data<T> u( I.size(), nrhs, 0.0 );
    /* ... */
    for ( auto src : Sources )
      /* ... */
      View<T> &W = src->data.w_view;
      Data<T> &w = src->data.w_leaf;
      /* ... */
      bool is_cached = true;
      auto &J = src->gids;
      auto &KIJ = node->DistNear[ p ][ src->morton ];
      if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
        /* ... */
      if ( W.col() == nrhs && W.row() == J.size() )
        /* ... */
            "N", "N", u.row(), u.col(), W.row(),
            1.0, KIJ.data(), KIJ.row(),
            /* ... */
            1.0, u.data(), u.row()
        /* ... */
            "N", "N", u.row(), u.col(), w.row(),
            1.0, KIJ.data(), KIJ.row(),
            /* ... */
            1.0, u.data(), u.row()
      /* ... */
        KIJ.shrink_to_fit();
    /* ... */
    double lock_beg = omp_get_wtime();
    /* ... */
    for ( int j = 0; j < u.col(); j ++ )
      for ( int i = 0; i < u.row(); i ++ )
        U( i, j ) += u( i, j );
    /* ... */
    double lock_time = omp_get_wtime() - lock_beg;
    /* ... */
    double gemm_time = omp_get_wtime() - beg;
    double GFLOPS = 2.0 * u.row() * u.col() * k / ( 1E+9 * gemm_time );
    /* ... */
    #pragma omp atomic update
    *num_arrived_subtasks += 1;

template<typename NODE, typename T>
/* ... */
  vector<L2LTask2<NODE, T>*> subtasks;
  /* ... */
  int num_arrived_subtasks = 0;
  /* ... */
  const size_t batch_size = 2;
  /* ... */
    name = string( "L2LR" );
    label = to_string( arg->treelist_id );
    /* ... */
    for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
      /* ... */
      vector<NODE*> Sources;
      for ( auto &it : arg->DistNear[ p ] )
        /* ... */
        Sources.push_back( (*arg->morton2node)[ it.first ] );
        if ( Sources.size() == batch_size )
          /* ... */
          subtasks.back()->Submit();
          subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
          subtasks.back()->DependencyAnalysis();
      /* ... */
      if ( Sources.size() )
        /* ... */
        subtasks.back()->Submit();
        subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
        subtasks.back()->DependencyAnalysis();
    /* ... */
    double flops = 0, mops = 0;
    /* ... */
    event.Set( label + name, flops, mops );
  /* ... */
  void DependencyAnalysis()
    /* ... */
    for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
    arg->DependencyAnalysis( RW, this );
  /* ... */
  void Execute( Worker* user_worker )
    /* ... */
    assert( num_arrived_subtasks == subtasks.size() );
template<typename TREE>
void FindNearInteractions( TREE &tree )
  /* ... */
  mpi::PrintProgress( "[BEG] Finish FindNearInteractions ...", tree.GetComm() );
  /* ... */
  using NODE = typename TREE::NODE;
  auto &setup = tree.setup;
  auto &NN = *setup.NN;
  double budget = setup.Budget();
  size_t n_leafs = ( 1 << tree.depth );
  /* ... */
  auto level_beg = tree.treelist.begin() + n_leafs - 1;
  /* ... */
  #pragma omp parallel for
  for ( size_t node_ind = 0; node_ind < n_leafs; node_ind ++ )
    /* ... */
    auto *node = *(level_beg + node_ind);
    auto &data = node->data;
    size_t n_nodes = ( 1 << node->l );
    /* ... */
    node->NNNearNodes.insert( node );
    node->NNNearNodeMortonIDs.insert( node->morton );
    /* ... */
    multimap<size_t, size_t> sorted_ballot = gofmm::NearNodeBallots( node );
    /* ... */
    for ( auto it = sorted_ballot.rbegin(); it != sorted_ballot.rend(); it ++ )
      /* ... */
      if ( node->NNNearNodes.size() >= n_nodes * budget ) break;
      /* ... */
      #pragma omp critical
        /* ... */
        if ( !(*node->morton2node).count( (*it).second ) )
          /* ... */
          (*node->morton2node)[ (*it).second ] = new NODE( (*it).second );
        /* ... */
        auto *target = (*node->morton2node)[ (*it).second ];
        node->NNNearNodeMortonIDs.insert( (*it).second );
        node->NNNearNodes.insert( target );
  /* ... */
  mpi::PrintProgress( "[END] Finish FindNearInteractions ...", tree.GetComm() );

template<typename NODE>
void FindFarNodes( MortonHelper::Recursor r, NODE *target )
  /* ... */
  if ( r.second > target->l ) return;
  /* ... */
  size_t node_morton = MortonHelper::MortonID( r );
  /* ... */
  auto & NearMortonIDs = target->NNNearNodeMortonIDs;
  /* ... */
  if ( MortonHelper::ContainAny( node_morton, NearMortonIDs ) )
    /* ... */
    FindFarNodes( MortonHelper::RecurLeft( r ), target );
    FindFarNodes( MortonHelper::RecurRight( r ), target );
  /* ... */
    if ( node_morton >= target->morton )
      target->NNFarNodeMortonIDs.insert( node_morton );
template<typename TREE>
void SymmetrizeNearInteractions( TREE & tree )
  /* ... */
  mpi::PrintProgress( "[BEG] SymmetrizeNearInteractions ...", tree.GetComm() );
  /* ... */
  using NODE = typename TREE::NODE;
  /* ... */
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  /* ... */
  vector<vector<pair<size_t, size_t>>> sendlist( comm_size );
  vector<vector<pair<size_t, size_t>>> recvlist( comm_size );
  /* ... */
  int n_nodes = 1 << tree.depth;
  auto level_beg = tree.treelist.begin() + n_nodes - 1;
  /* ... */
  #pragma omp parallel
    /* ... */
    vector<vector<pair<size_t, size_t>>> list( comm_size );
    /* ... */
    for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
      /* ... */
      auto *node = *(level_beg + node_ind);
      /* ... */
      for ( auto it : node->NNNearNodeMortonIDs )
        /* ... */
        int dest = tree.Morton2Rank( it );
        if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest );
        list[ dest ].push_back( make_pair( it, node->morton ) );
    /* ... */
    #pragma omp critical
      /* ... */
      for ( int p = 0; p < comm_size; p ++ )
        /* ... */
        sendlist[ p ].insert( sendlist[ p ].end(),
            list[ p ].begin(), list[ p ].end() );
  /* ... */
  mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() );
  /* ... */
  for ( int p = 0; p < comm_size; p ++ )
    /* ... */
    for ( auto & query : recvlist[ p ] )
      /* ... */
      #pragma omp critical
        /* ... */
        auto* node = tree.morton2node[ query.first ];
        if ( !tree.morton2node.count( query.second ) )
          /* ... */
          tree.morton2node[ query.second ] = new NODE( query.second );
        /* ... */
        node->data.lock.Acquire();
        /* ... */
        node->NNNearNodes.insert( tree.morton2node[ query.second ] );
        node->NNNearNodeMortonIDs.insert( query.second );
        /* ... */
        node->data.lock.Release();
  /* ... */
  mpi::Barrier( tree.GetComm() );
  mpi::PrintProgress( "[END] SymmetrizeNearInteractions ...", tree.GetComm() );

template<typename TREE>
void SymmetrizeFarInteractions( TREE & tree )
  /* ... */
  mpi::PrintProgress( "[BEG] SymmetrizeFarInteractions ...", tree.GetComm() );
  /* ... */
  using NODE = typename TREE::NODE;
  /* ... */
  vector<vector<pair<size_t, size_t>>> sendlist( tree.GetCommSize() );
  vector<vector<pair<size_t, size_t>>> recvlist( tree.GetCommSize() );
  /* ... */
  #pragma omp parallel
    /* ... */
    vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() );
    /* ... */
    for ( size_t i = 1; i < tree.treelist.size(); i ++ )
      /* ... */
      auto *node = tree.treelist[ i ];
      for ( auto it = node->NNFarNodeMortonIDs.begin();
            it != node->NNFarNodeMortonIDs.end(); it ++ )
        /* ... */
        #pragma omp critical
          /* ... */
          if ( !tree.morton2node.count( *it ) )
            /* ... */
            tree.morton2node[ *it ] = new NODE( *it );
          /* ... */
          node->NNFarNodes.insert( tree.morton2node[ *it ] );
        /* ... */
        int dest = tree.Morton2Rank( *it );
        if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest );
        list[ dest ].push_back( make_pair( *it, node->morton ) );
    /* ... */
    #pragma omp critical
      /* ... */
      for ( int p = 0; p < tree.GetCommSize(); p ++ )
        /* ... */
        sendlist[ p ].insert( sendlist[ p ].end(),
            list[ p ].begin(), list[ p ].end() );
  /* ... */
  #pragma omp parallel
    /* ... */
    vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() );
    /* ... */
    for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ )
      /* ... */
      auto *node = tree.mpitreelists[ i ];
      for ( auto it = node->NNFarNodeMortonIDs.begin();
            it != node->NNFarNodeMortonIDs.end(); it ++ )
        /* ... */
        #pragma omp critical
          /* ... */
          if ( !tree.morton2node.count( *it ) )
            /* ... */
            tree.morton2node[ *it ] = new NODE( *it );
          /* ... */
          node->NNFarNodes.insert( tree.morton2node[ *it ] );
        /* ... */
        int dest = tree.Morton2Rank( *it );
        if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest ); fflush( stdout );
        list[ dest ].push_back( make_pair( *it, node->morton ) );
    /* ... */
    #pragma omp critical
      /* ... */
      for ( int p = 0; p < tree.GetCommSize(); p ++ )
        /* ... */
        sendlist[ p ].insert( sendlist[ p ].end(),
            list[ p ].begin(), list[ p ].end() );
  /* ... */
  mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() );
  /* ... */
  for ( int p = 0; p < tree.GetCommSize(); p ++ )
    /* ... */
    for ( auto & query : recvlist[ p ] )
      /* ... */
      #pragma omp critical
        /* ... */
        if ( !tree.morton2node.count( query.second ) )
          /* ... */
          tree.morton2node[ query.second ] = new NODE( query.second );
        /* ... */
        auto* node = tree.morton2node[ query.first ];
        node->data.lock.Acquire();
        /* ... */
        node->NNFarNodes.insert( tree.morton2node[ query.second ] );
        node->NNFarNodeMortonIDs.insert( query.second );
        /* ... */
        node->data.lock.Release();
        /* ... */
        assert( tree.Morton2Rank( node->morton ) == tree.GetCommRank() );
  /* ... */
  mpi::Barrier( tree.GetComm() );
  mpi::PrintProgress( "[END] SymmetrizeFarInteractions ...", tree.GetComm() );
template<typename TREE>
void BuildInteractionListPerRank( TREE &tree, bool is_near )
  /* ... */
  using T = typename TREE::T;
  /* ... */
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  /* ... */
  vector<set<size_t>> lists( comm_size );
  /* ... */
    int n_nodes = 1 << tree.depth;
    auto level_beg = tree.treelist.begin() + n_nodes - 1;
    /* ... */
    #pragma omp parallel
      /* ... */
      vector<set<size_t>> list( comm_size );
      /* ... */
      for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ )
        /* ... */
        auto *node = *(level_beg + node_ind);
        auto & NearMortonIDs = node->NNNearNodeMortonIDs;
        node->DistNear.resize( comm_size );
        for ( auto it : NearMortonIDs )
          /* ... */
          int dest = tree.Morton2Rank( it );
          if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest );
          if ( dest != comm_rank ) list[ dest ].insert( node->morton );
          node->DistNear[ dest ][ it ] = Data<T>();
      /* ... */
      #pragma omp critical
        /* ... */
        for ( int p = 0; p < comm_size; p ++ )
          lists[ p ].insert( list[ p ].begin(), list[ p ].end() );
    /* ... */
    vector<vector<size_t>> recvlist( comm_size );
    if ( !tree.NearSentToRank.size() ) tree.NearSentToRank.resize( comm_size );
    if ( !tree.NearRecvFromRank.size() ) tree.NearRecvFromRank.resize( comm_size );
    #pragma omp parallel for
    for ( int p = 0; p < comm_size; p ++ )
      /* ... */
      tree.NearSentToRank[ p ].insert( tree.NearSentToRank[ p ].end(),
          lists[ p ].begin(), lists[ p ].end() );
    /* ... */
    mpi::AlltoallVector( tree.NearSentToRank, recvlist, tree.GetComm() );
    /* ... */
    #pragma omp parallel for
    for ( int p = 0; p < comm_size; p ++ )
      for ( int i = 0; i < recvlist[ p ].size(); i ++ )
        tree.NearRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i;
  /* ... */
    #pragma omp parallel
      /* ... */
      vector<set<size_t>> list( comm_size );
      /* ... */
      for ( size_t i = 1; i < tree.treelist.size(); i ++ )
        /* ... */
        auto *node = tree.treelist[ i ];
        node->DistFar.resize( comm_size );
        for ( auto it = node->NNFarNodeMortonIDs.begin();
              it != node->NNFarNodeMortonIDs.end(); it ++ )
          /* ... */
          int dest = tree.Morton2Rank( *it );
          if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest );
          if ( dest != comm_rank )
            /* ... */
            list[ dest ].insert( node->morton );
          /* ... */
          node->DistFar[ dest ][ *it ] = Data<T>();
      /* ... */
      for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ )
        /* ... */
        auto *node = tree.mpitreelists[ i ];
        node->DistFar.resize( comm_size );
        /* ... */
        if ( tree.Morton2Rank( node->morton ) == comm_rank )
          /* ... */
          for ( auto it = node->NNFarNodeMortonIDs.begin();
                it != node->NNFarNodeMortonIDs.end(); it ++ )
            /* ... */
            int dest = tree.Morton2Rank( *it );
            if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest );
            if ( dest != comm_rank )
              /* ... */
              list[ dest ].insert( node->morton );
            /* ... */
            node->DistFar[ dest ][ *it ] = Data<T>();
      /* ... */
      #pragma omp critical
        /* ... */
        for ( int p = 0; p < comm_size; p ++ )
          lists[ p ].insert( list[ p ].begin(), list[ p ].end() );
    /* ... */
    vector<vector<size_t>> recvlist( comm_size );
    if ( !tree.FarSentToRank.size() ) tree.FarSentToRank.resize( comm_size );
    if ( !tree.FarRecvFromRank.size() ) tree.FarRecvFromRank.resize( comm_size );
    #pragma omp parallel for
    for ( int p = 0; p < comm_size; p ++ )
      /* ... */
      tree.FarSentToRank[ p ].insert( tree.FarSentToRank[ p ].end(),
          lists[ p ].begin(), lists[ p ].end() );
    /* ... */
    mpi::AlltoallVector( tree.FarSentToRank, recvlist, tree.GetComm() );
    /* ... */
    #pragma omp parallel for
    for ( int p = 0; p < comm_size; p ++ )
      for ( int i = 0; i < recvlist[ p ].size(); i ++ )
        tree.FarRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i;
  /* ... */
  mpi::Barrier( tree.GetComm() );

template<typename TREE>
pair<double, double> NonCompressedRatio( TREE &tree )
  /* ... */
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  /* ... */
  double ratio_n = 0.0;
  double ratio_f = 0.0;
  /* ... */
  for ( auto &tar : tree.treelist )
    /* ... */
    for ( auto nearID : tar->NNNearNodeMortonIDs )
      /* ... */
      auto *src = tree.morton2node[ nearID ];
      /* ... */
      double m = tar->gids.size();
      double n = src->gids.size();
      /* ... */
      ratio_n += ( m / N ) * ( n / N );
    /* ... */
    for ( auto farID : tar->NNFarNodeMortonIDs )
      /* ... */
      auto *src = tree.morton2node[ farID ];
      /* ... */
      double m = tar->data.skels.size();
      double n = src->data.skels.size();
      /* ... */
      ratio_f += ( m / N ) * ( n / N );
  /* ... */
  for ( auto &tar : tree.mpitreelists )
    /* ... */
    if ( !tar->child || tar->GetCommRank() ) continue;
    for ( auto farID : tar->NNFarNodeMortonIDs )
      /* ... */
      auto *src = tree.morton2node[ farID ];
      /* ... */
      double m = tar->data.skels.size();
      double n = src->data.skels.size();
      /* ... */
      ratio_f += ( m / N ) * ( n / N );
  /* ... */
  pair<double, double> ret( 0, 0 );
  mpi::Allreduce( &ratio_n, &(ret.first), 1, MPI_SUM, tree.GetComm() );
  mpi::Allreduce( &ratio_f, &(ret.second), 1, MPI_SUM, tree.GetComm() );
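/*
 * [Editor's note] In the terms used above (N is defined on a line elided
 * here, presumably the global problem size), the two ratios are
 *
 *     ratio_n = sum over near pairs (tar, src) of ( m / N ) * ( n / N ),  m = |tar->gids|,       n = |src->gids|
 *     ratio_f = sum over far  pairs (tar, src) of ( m / N ) * ( n / N ),  m = |tar->data.skels|, n = |src->data.skels|
 *
 * i.e. the fraction of the N-by-N matrix stored directly through near blocks
 * versus through skeleton (far) blocks; the two Allreduce calls sum the
 * per-rank contributions into the returned pair.
 */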
template<typename T, typename TREE>
void PackNear( TREE &tree, string option, int p,
    vector<size_t> &sendsizes,
    vector<size_t> &sendskels,
    vector<T> &sendbuffs )
  /* ... */
  vector<size_t> offsets( 1, 0 );
  /* ... */
  for ( auto it : tree.NearSentToRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it ];
    auto &gids = node->gids;
    if ( !option.compare( string( "leafgids" ) ) )
      /* ... */
      sendsizes.push_back( gids.size() );
      sendskels.insert( sendskels.end(), gids.begin(), gids.end() );
    /* ... */
      auto &w_view = node->data.w_view;
      sendsizes.push_back( gids.size() * w_view.col() );
      offsets.push_back( sendsizes.back() + offsets.back() );
  /* ... */
  if ( offsets.size() ) sendbuffs.resize( offsets.back() );
  /* ... */
  if ( !option.compare( string( "leafweights" ) ) )
    /* ... */
    #pragma omp parallel for
    for ( size_t i = 0; i < tree.NearSentToRank[ p ].size(); i ++ )
      /* ... */
      auto *node = tree.morton2node[ tree.NearSentToRank[ p ][ i ] ];
      auto &gids = node->gids;
      auto &w_view = node->data.w_view;
      auto w_leaf = w_view.toData();
      size_t offset = offsets[ i ];
      for ( size_t j = 0; j < w_leaf.size(); j ++ )
        sendbuffs[ offset + j ] = w_leaf[ j ];

template<typename T, typename TREE>
void UnpackLeaf( TREE &tree, string option, int p,
    const vector<size_t> &recvsizes,
    const vector<size_t> &recvskels,
    const vector<T> &recvbuffs )
  /* ... */
  vector<size_t> offsets( 1, 0 );
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
  /* ... */
  for ( auto it : tree.NearRecvFromRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it.first ];
    if ( !option.compare( string( "leafgids" ) ) )
      /* ... */
      auto &gids = node->gids;
      size_t i = it.second;
      gids.reserve( recvsizes[ i ] );
      for ( uint64_t j = offsets[ i + 0 ]; j < offsets[ i + 1 ]; j ++ )
        gids.push_back( recvskels[ j ] );
    /* ... */
      size_t nrhs = tree.setup.w->col();
      auto &w_leaf = node->data.w_leaf;
      size_t i = it.second;
      w_leaf.resize( recvsizes[ i ] / nrhs, nrhs );
      /* ... */
      for ( uint64_t j = offsets[ i + 0 ], jj = 0; j < offsets[ i + 1 ]; j ++, jj ++ )
        w_leaf[ jj ] = recvbuffs[ j ];

template<typename T, typename TREE>
void PackFar( TREE &tree, string option, int p,
    vector<size_t> &sendsizes,
    vector<size_t> &sendskels,
    vector<T> &sendbuffs )
  /* ... */
  for ( auto it : tree.FarSentToRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it ];
    auto &skels = node->data.skels;
    if ( !option.compare( string( "skelgids" ) ) )
      /* ... */
      sendsizes.push_back( skels.size() );
      sendskels.insert( sendskels.end(), skels.begin(), skels.end() );
    /* ... */
      auto &w_skel = node->data.w_skel;
      sendsizes.push_back( w_skel.size() );
      sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() );
template<typename TREE, typename T>
void PackWeights( TREE &tree, int p,
    vector<T> &sendbuffs, vector<size_t> &sendsizes )
  /* ... */
  for ( auto it : tree.NearSentToRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it ];
    auto w_leaf = node->data.w_view.toData();
    sendbuffs.insert( sendbuffs.end(), w_leaf.begin(), w_leaf.end() );
    sendsizes.push_back( w_leaf.size() );

template<typename TREE, typename T>
void UnpackWeights( TREE &tree, int p,
    const vector<T> recvbuffs, const vector<size_t> &recvsizes )
  /* ... */
  vector<size_t> offsets( 1, 0 );
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
  /* ... */
  for ( auto it : tree.NearRecvFromRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it.first ];
    /* ... */
    size_t nrhs = tree.setup.w->col();
    auto &w_leaf = node->data.w_leaf;
    size_t i = it.second;
    w_leaf.resize( recvsizes[ i ] / nrhs, nrhs );
    for ( uint64_t j = offsets[ i + 0 ], jj = 0; j < offsets[ i + 1 ]; j ++, jj ++ )
      w_leaf[ jj ] = recvbuffs[ j ];

template<typename TREE>
void PackSkeletons( TREE &tree, int p,
    vector<size_t> &sendbuffs, vector<size_t> &sendsizes )
  /* ... */
  for ( auto it : tree.FarSentToRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it ];
    auto &skels = node->data.skels;
    sendbuffs.insert( sendbuffs.end(), skels.begin(), skels.end() );
    sendsizes.push_back( skels.size() );

template<typename TREE>
void UnpackSkeletons( TREE &tree, int p,
    const vector<size_t> recvbuffs, const vector<size_t> &recvsizes )
  /* ... */
  vector<size_t> offsets( 1, 0 );
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
  /* ... */
  for ( auto it : tree.FarRecvFromRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it.first ];
    auto &skels = node->data.skels;
    size_t i = it.second;
    /* ... */
    skels.reserve( recvsizes[ i ] );
    for ( uint64_t j = offsets[ i + 0 ]; j < offsets[ i + 1 ]; j ++ )
      skels.push_back( recvbuffs[ j ] );

template<typename TREE, typename T>
void PackSkeletonWeights( TREE &tree, int p,
    vector<T> &sendbuffs, vector<size_t> &sendsizes )
  /* ... */
  for ( auto it : tree.FarSentToRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it ];
    auto &w_skel = node->data.w_skel;
    sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() );
    sendsizes.push_back( w_skel.size() );

template<typename TREE, typename T>
void UnpackSkeletonWeights( TREE &tree, int p,
    const vector<T> recvbuffs, const vector<size_t> &recvsizes )
  /* ... */
  vector<size_t> offsets( 1, 0 );
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
  /* ... */
  for ( auto it : tree.FarRecvFromRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it.first ];
    /* ... */
    size_t nrhs = tree.setup.w->col();
    auto &w_skel = node->data.w_skel;
    size_t i = it.second;
    w_skel.resize( recvsizes[ i ] / nrhs, nrhs );
    for ( uint64_t j = offsets[ i + 0 ], jj = 0; j < offsets[ i + 1 ]; j ++, jj ++ )
      w_skel[ jj ] = recvbuffs[ j ];
template<typename T, typename TREE>
void UnpackFar( TREE &tree, string option, int p,
    const vector<size_t> &recvsizes,
    const vector<size_t> &recvskels,
    const vector<T> &recvbuffs )
  /* ... */
  vector<size_t> offsets( 1, 0 );
  for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it );
  /* ... */
  for ( auto it : tree.FarRecvFromRank[ p ] )
    /* ... */
    auto *node = tree.morton2node[ it.first ];
    if ( !option.compare( string( "skelgids" ) ) )
      /* ... */
      auto &skels = node->data.skels;
      size_t i = it.second;
      /* ... */
      skels.reserve( recvsizes[ i ] );
      for ( uint64_t j = offsets[ i + 0 ]; j < offsets[ i + 1 ]; j ++ )
        skels.push_back( recvskels[ j ] );
    /* ... */
      size_t nrhs = tree.setup.w->col();
      auto &w_skel = node->data.w_skel;
      size_t i = it.second;
      w_skel.resize( recvsizes[ i ] / nrhs, nrhs );
      /* ... */
      for ( uint64_t j = offsets[ i + 0 ], jj = 0; j < offsets[ i + 1 ]; j ++, jj ++ )
        w_skel[ jj ] = recvbuffs[ j ];
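/*
 * [Editor's note] All of the Pack*/Unpack* helpers above share one layout
 * convention: recvsizes[ i ] is the element count for the i-th node listed in
 * tree.*RecvFromRank[ p ], and an exclusive prefix sum of recvsizes gives each
 * node's offset inside the flat receive buffer. A minimal stand-alone sketch
 * of that convention, assuming std containers only (names here are
 * hypothetical, not part of this library):
 *
 *   #include <cstddef>
 *   #include <vector>
 *   // Split one flat buffer into per-node slices using a size list.
 *   std::vector<std::vector<double>> unpack_slices(
 *       const std::vector<double>& buffer, const std::vector<size_t>& sizes )
 *   {
 *     std::vector<size_t> offsets( 1, 0 );
 *     for ( auto s : sizes ) offsets.push_back( offsets.back() + s );
 *     std::vector<std::vector<double>> slices( sizes.size() );
 *     for ( size_t i = 0; i < sizes.size(); i ++ )
 *       slices[ i ].assign( buffer.begin() + offsets[ i ],
 *                           buffer.begin() + offsets[ i + 1 ] );
 *     return slices;
 *   }
 */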
template<typename T, typename TREE>
/* ... */
    : SendTask<T, TREE>( tree, src, tar, key )
    /* ... */
    this->DependencyAnalysis();
  /* ... */
  void DependencyAnalysis()
    /* ... */
    TREE &tree = *(this->arg);
    tree.DependOnNearInteractions( this->tar, this );
  /* ... */
    PackWeights( *this->arg, this->tar,
        this->send_buffs, this->send_sizes );

template<typename T, typename TREE>
/* ... */
    : RecvTask<T, TREE>( tree, src, tar, key )
    /* ... */
    this->DependencyAnalysis();
  /* ... */
    UnpackWeights( *this->arg, this->src,
        this->recv_buffs, this->recv_sizes );

template<typename T, typename TREE>
/* ... */
    : SendTask<T, TREE>( tree, src, tar, key )
    /* ... */
    this->DependencyAnalysis();
  /* ... */
  void DependencyAnalysis()
    /* ... */
    TREE &tree = *(this->arg);
    tree.DependOnFarInteractions( this->tar, this );
  /* ... */
    PackSkeletonWeights( *this->arg, this->tar,
        this->send_buffs, this->send_sizes );

template<typename T, typename TREE>
/* ... */
    : RecvTask<T, TREE>( tree, src, tar, key )
    /* ... */
    this->DependencyAnalysis();
  /* ... */
    UnpackSkeletonWeights( *this->arg, this->src,
        this->recv_buffs, this->recv_sizes );
template<typename TREE>
void ExchangeLET( TREE &tree, string option )
  /* ... */
  using T = typename TREE::T;
  /* ... */
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  /* ... */
  vector<vector<size_t>> sendsizes( comm_size );
  vector<vector<size_t>> recvsizes( comm_size );
  vector<vector<size_t>> sendskels( comm_size );
  vector<vector<size_t>> recvskels( comm_size );
  vector<vector<T>> sendbuffs( comm_size );
  vector<vector<T>> recvbuffs( comm_size );
  /* ... */
  #pragma omp parallel for
  for ( int p = 0; p < comm_size; p ++ )
    /* ... */
    if ( !option.compare( 0, 4, "leaf" ) )
      /* ... */
      PackNear( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] );
    /* ... */
    else if ( !option.compare( 0, 4, "skel" ) )
      /* ... */
      PackFar( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] );
    else
      /* ... */
      printf( "ExchangeLET: option <%s> not available.\n", option.data() );
  /* ... */
  mpi::AlltoallVector( sendsizes, recvsizes, tree.GetComm() );
  if ( !option.compare( string( "skelgids" ) ) ||
       !option.compare( string( "leafgids" ) ) )
    /* ... */
    auto &K = *tree.setup.K;
    mpi::AlltoallVector( sendskels, recvskels, tree.GetComm() );
    K.RequestIndices( recvskels );
  /* ... */
  double beg = omp_get_wtime();
  mpi::AlltoallVector( sendbuffs, recvbuffs, tree.GetComm() );
  double a2av_time = omp_get_wtime() - beg;
  if ( comm_rank == 0 ) printf( "a2av_time %lfs\n", a2av_time );
  /* ... */
  #pragma omp parallel for
  for ( int p = 0; p < comm_size; p ++ )
    /* ... */
    if ( !option.compare( 0, 4, "leaf" ) )
      /* ... */
      UnpackLeaf( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] );
    /* ... */
    else if ( !option.compare( 0, 4, "skel" ) )
      /* ... */
      UnpackFar( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] );
    else
      /* ... */
      printf( "ExchangeLET: option <%s> not available.\n", option.data() );
template<typename T, typename TREE>
void AsyncExchangeLET( TREE &tree, string option )
  /* ... */
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  /* ... */
  for ( int p = 0; p < comm_size; p ++ )
    /* ... */
    if ( !option.compare( 0, 4, "leaf" ) )
      /* ... */
    else if ( !option.compare( 0, 4, "skel" ) )
      /* ... */
    else
      /* ... */
      printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() );
  /* ... */
  for ( int p = 0; p < comm_size; p ++ )
    /* ... */
    if ( !option.compare( 0, 4, "leaf" ) )
      /* ... */
    else if ( !option.compare( 0, 4, "skel" ) )
      /* ... */
    else
      /* ... */
      printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() );

template<typename T, typename TREE>
void ExchangeNeighbors( TREE &tree )
  /* ... */
  mpi::PrintProgress( "[BEG] ExchangeNeighbors ...", tree.GetComm() );
  /* ... */
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
  /* ... */
  vector<vector<size_t>> send_buff( comm_size );
  vector<vector<size_t>> recv_buff( comm_size );
  /* ... */
  unordered_set<size_t> requested_gids;
  auto &NN = *tree.setup.NN;
  /* ... */
  for ( auto & it : NN )
    /* ... */
    if ( it.second >= 0 && it.second < tree.n )
      requested_gids.insert( it.second );
  /* ... */
  for ( auto it : tree.treelist[ 0 ]->gids ) requested_gids.erase( it );
  /* ... */
  for ( auto it : requested_gids )
    /* ... */
    int p = it % comm_size;
    if ( p != comm_rank ) send_buff[ p ].push_back( it );
  /* ... */
  auto &K = *tree.setup.K;
  K.RequestIndices( send_buff );
  /* ... */
  mpi::PrintProgress( "[END] ExchangeNeighbors ...", tree.GetComm() );
template<bool SYMMETRIC, typename NODE, typename T>
void MergeFarNodes( NODE *node )
  /* ... */
    assert( !node->FarNodeMortonIDs.size() );
    assert( !node->FarNodes.size() );
    node->FarNodeMortonIDs.insert( node->sibling->morton );
    node->FarNodes.insert( node->sibling );
  /* ... */
    FindFarNodes( MortonHelper::Root(), node );
  /* ... */
    auto *lchild = node->lchild;
    auto *rchild = node->rchild;
    /* ... */
    auto &pNNFarNodes = node->NNFarNodeMortonIDs;
    auto &lNNFarNodes = lchild->NNFarNodeMortonIDs;
    auto &rNNFarNodes = rchild->NNFarNodeMortonIDs;
    /* ... */
    for ( auto it = lNNFarNodes.begin(); it != lNNFarNodes.end(); it ++ )
      /* ... */
      if ( rNNFarNodes.count( *it ) )
        /* ... */
        pNNFarNodes.insert( *it );
    /* ... */
    for ( auto it = pNNFarNodes.begin(); it != pNNFarNodes.end(); it ++ )
      /* ... */
      lNNFarNodes.erase( *it );
      rNNFarNodes.erase( *it );

template<bool SYMMETRIC, typename NODE, typename T>
/* ... */
    name = string( "merge" );
    label = to_string( arg->treelist_id );
  /* ... */
    arg->DependencyAnalysis( RW, this );
    /* ... */
      arg->lchild->DependencyAnalysis( RW, this );
      arg->rchild->DependencyAnalysis( RW, this );
  /* ... */
  void Execute( Worker* user_worker )
    /* ... */
    MergeFarNodes<SYMMETRIC, NODE, T>( arg );

template<bool SYMMETRIC, typename NODE, typename T>
void DistMergeFarNodes( NODE *node )
  /* ... */
  mpi::Comm comm = node->GetComm();
  int comm_size = node->GetCommSize();
  int comm_rank = node->GetCommRank();
  /* ... */
  if ( !node->parent ) return;
  /* ... */
  if ( node->GetCommSize() < 2 )
    /* ... */
    MergeFarNodes<SYMMETRIC, NODE, T>( node );
  /* ... */
    auto *child = node->child;
    /* ... */
    if ( comm_rank == 0 )
      /* ... */
      auto &pNNFarNodes = node->NNFarNodeMortonIDs;
      auto &lNNFarNodes = child->NNFarNodeMortonIDs;
      vector<size_t> recvFarNodes;
      /* ... */
      mpi::RecvVector( recvFarNodes, comm_size / 2, 0, comm, &status );
      /* ... */
      for ( auto it : recvFarNodes )
        /* ... */
        if ( lNNFarNodes.count( it ) )
          /* ... */
          pNNFarNodes.insert( it );
      /* ... */
      recvFarNodes.clear();
      recvFarNodes.reserve( pNNFarNodes.size() );
      /* ... */
      for ( auto it : pNNFarNodes )
        /* ... */
        lNNFarNodes.erase( it );
        recvFarNodes.push_back( it );
      /* ... */
      mpi::SendVector( recvFarNodes, comm_size / 2, 0, comm );
    /* ... */
    if ( comm_rank == comm_size / 2 )
      /* ... */
      auto &rNNFarNodes = child->NNFarNodeMortonIDs;
      vector<size_t> sendFarNodes( rNNFarNodes.begin(), rNNFarNodes.end() );
      /* ... */
      mpi::SendVector( sendFarNodes, 0, 0, comm );
      /* ... */
      mpi::RecvVector( sendFarNodes, 0, 0, comm, &status );
      /* ... */
      for ( auto it : sendFarNodes ) rNNFarNodes.erase( it );

template<bool SYMMETRIC, typename NODE, typename T>
/* ... */
    name = string( "dist-merge" );
    label = to_string( arg->treelist_id );
  /* ... */
    arg->DependencyAnalysis( RW, this );
    /* ... */
    if ( arg->GetCommSize() > 1 )
      /* ... */
      arg->child->DependencyAnalysis( RW, this );
    /* ... */
      arg->lchild->DependencyAnalysis( RW, this );
      arg->rchild->DependencyAnalysis( RW, this );
  /* ... */
  void Execute( Worker* user_worker )
    /* ... */
    DistMergeFarNodes<SYMMETRIC, NODE, T>( arg );
template<bool NNPRUNE, typename NODE>
/* ... */
    name = string( "FKIJ" );
    label = to_string( arg->treelist_id );
    /* ... */
    double flops = 0, mops = 0;
  /* ... */
  void DependencyAnalysis()
    /* ... */
    arg->DependencyAnalysis( RW, this );
  /* ... */
  void Execute( Worker* user_worker )
    /* ... */
    auto &K = *node->setup->K;
    /* ... */
    for ( int p = 0; p < node->DistFar.size(); p ++ )
      /* ... */
      for ( auto &it : node->DistFar[ p ] )
        /* ... */
        auto *src = (*node->morton2node)[ it.first ];
        auto &I = node->data.skels;
        auto &J = src->data.skels;
        it.second = K( I, J );

template<bool NNPRUNE, typename NODE>
/* ... */
    name = string( "NKIJ" );
    label = to_string( arg->treelist_id );
  /* ... */
  void DependencyAnalysis()
    /* ... */
    arg->DependencyAnalysis( RW, this );
  /* ... */
  void Execute( Worker* user_worker )
    /* ... */
    auto &K = *node->setup->K;
    /* ... */
    for ( int p = 0; p < node->DistNear.size(); p ++ )
      /* ... */
      for ( auto &it : node->DistNear[ p ] )
        /* ... */
        auto *src = (*node->morton2node)[ it.first ];
        auto &I = node->gids;
        auto &J = src->gids;
        it.second = K( I, J );
template<typename NODE, typename T>
void DistRowSamples( NODE *node, size_t nsamples )
  /* ... */
  mpi::Comm comm = node->GetComm();
  int size = node->GetCommSize();
  int rank = node->GetCommRank();
  /* ... */
  auto &K = *node->setup->K;
  /* ... */
  vector<size_t> &I = node->data.candidate_rows;
  /* ... */
    I.reserve( nsamples );
    /* ... */
    auto &snids = node->data.snids;
    multimap<T, size_t> ordered_snids = gofmm::flip_map( snids );
    /* ... */
    for ( auto it = ordered_snids.begin(); it != ordered_snids.end(); it++ )
      /* ... */
      I.push_back( (*it).second );
      if ( I.size() >= nsamples ) break;
  /* ... */
  vector<size_t> candidates( nsamples );
  /* ... */
  size_t n_required = nsamples - I.size();
  /* ... */
  mpi::Bcast( &n_required, 1, 0, comm );
  /* ... */
  while ( n_required )
    /* ... */
      for ( size_t i = 0; i < nsamples; i ++ )
        /* ... */
        auto important_sample = K.ImportantSample( 0 );
        candidates[ i ] = important_sample.second;
    /* ... */
    mpi::Bcast( candidates.data(), candidates.size(), 0, comm );
    /* ... */
    vector<size_t> vconsensus( nsamples, 0 );
    vector<size_t> validation = node->setup->ContainAny( candidates, node->morton );
    /* ... */
    mpi::Reduce( validation.data(), vconsensus.data(), nsamples, MPI_SUM, 0, comm );
    /* ... */
      for ( size_t i = 0; i < nsamples; i ++ )
        /* ... */
        if ( I.size() >= nsamples )
          /* ... */
          I.resize( nsamples );
        /* ... */
        if ( !vconsensus[ i ] )
          /* ... */
          if ( find( I.begin(), I.end(), candidates[ i ] ) == I.end() )
            I.push_back( candidates[ i ] );
      /* ... */
      n_required = nsamples - I.size();
    /* ... */
    mpi::Bcast( &n_required, 1, 0, comm );
template<bool NNPRUNE, typename NODE>
void DistSkeletonKIJ( NODE *node )
  /* ... */
  using T = typename NODE::T;
  /* ... */
  if ( !node->parent ) return;
  /* ... */
  auto &K = *(node->setup->K);
  /* ... */
  auto &data = node->data;
  auto &candidate_rows = data.candidate_rows;
  auto &candidate_cols = data.candidate_cols;
  auto &KIJ = data.KIJ;
  /* ... */
  auto comm = node->GetComm();
  auto size = node->GetCommSize();
  auto rank = node->GetCommRank();
  /* ... */
    gofmm::SkeletonKIJ<NNPRUNE>( node );
  /* ... */
    NODE *child = node->child;
    size_t nsamples = 0;
    /* ... */
    int child_isskel = child->data.isskel;
    mpi::Bcast( &child_isskel, 1, 0, child->GetComm() );
    child->data.isskel = child_isskel;
    /* ... */
      candidate_cols = child->data.skels;
      vector<size_t> rskel;
      /* ... */
      mpi::RecvVector( rskel, size / 2, 10, comm, &status );
      /* ... */
      K.RecvIndices( size / 2, comm, &status );
      /* ... */
      candidate_cols.insert( candidate_cols.end(), rskel.begin(), rskel.end() );
      /* ... */
      nsamples = 2 * candidate_cols.size();
      /* ... */
      if ( nsamples < 2 * node->setup->LeafNodeSize() )
        nsamples = 2 * node->setup->LeafNodeSize();
      /* ... */
      auto &lsnids = node->child->data.snids;
      vector<T> recv_rsdist;
      vector<size_t> recv_rsnids;
      /* ... */
      mpi::RecvVector( recv_rsdist, size / 2, 20, comm, &status );
      mpi::RecvVector( recv_rsnids, size / 2, 30, comm, &status );
      /* ... */
      K.RecvIndices( size / 2, comm, &status );
      /* ... */
      auto &snids = node->data.snids;
      /* ... */
      for ( size_t i = 0; i < recv_rsdist.size(); i ++ )
        /* ... */
        pair<size_t, T> query( recv_rsnids[ i ], recv_rsdist[ i ] );
        auto ret = snids.insert( query );
        /* ... */
        if ( ret.first->second > recv_rsdist[ i ] )
          ret.first->second = recv_rsdist[ i ];
      /* ... */
      for ( auto gid : node->gids ) snids.erase( gid );
    /* ... */
    if ( rank == size / 2 )
      /* ... */
      mpi::SendVector( child->data.skels, 0, 10, comm );
      /* ... */
      K.SendIndices( child->data.skels, 0, comm );
      /* ... */
      auto &rsnids = node->child->data.snids;
      vector<T> send_rsdist;
      vector<size_t> send_rsnids;
      /* ... */
      send_rsdist.reserve( rsnids.size() );
      send_rsnids.reserve( rsnids.size() );
      /* ... */
      for ( auto it = rsnids.begin(); it != rsnids.end(); it ++ )
        /* ... */
        send_rsnids.push_back( (*it).first );
        send_rsdist.push_back( (*it).second );
      /* ... */
      mpi::SendVector( send_rsdist, 0, 20, comm );
      mpi::SendVector( send_rsnids, 0, 30, comm );
      /* ... */
      K.SendIndices( send_rsnids, 0, comm );
    /* ... */
    mpi::Bcast( &nsamples, 1, 0, comm );
    /* ... */
    DistRowSamples<NODE, T>( node, nsamples );
    /* ... */
      assert( !candidate_rows.size() );
      assert( !candidate_cols.size() );
    /* ... */
    KIJ = K( candidate_rows, candidate_cols );
template<bool NNPRUNE, typename NODE, typename T>
/* ... */
    name = string( "par-gskm" );
    label = to_string( arg->treelist_id );
  /* ... */
  void DependencyAnalysis() { arg->DependOnChildren( this ); };
  /* ... */
  void Execute( Worker* user_worker ) { DistSkeletonKIJ<NNPRUNE>( arg ); };

template<typename NODE, typename T>
void DistSkeletonize( NODE *node )
  /* ... */
  if ( !node->parent ) return;
  /* ... */
  auto &K = *(node->setup->K);
  auto &NN = *(node->setup->NN);
  auto maxs = node->setup->MaximumRank();
  auto stol = node->setup->Tolerance();
  bool secure_accuracy = node->setup->SecureAccuracy();
  bool use_adaptive_ranks = node->setup->UseAdaptiveRanks();
  /* ... */
  auto &data = node->data;
  auto &skels = data.skels;
  auto &proj = data.proj;
  auto &jpvt = data.jpvt;
  auto &KIJ = data.KIJ;
  auto &candidate_cols = data.candidate_cols;
  /* ... */
  size_t m = KIJ.row();
  size_t n = KIJ.col();
  /* ... */
  if ( secure_accuracy )
    /* ... */
  T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol;
  /* ... */
  scaled_stol *= std::sqrt( (T)q / N );
  /* ... */
      use_adaptive_ranks, secure_accuracy,
      KIJ.row(), KIJ.col(), maxs, scaled_stol,
      KIJ, skels, proj, jpvt
  /* ... */
  KIJ.shrink_to_fit();
  /* ... */
  if ( secure_accuracy )
    /* ... */
    data.isskel = (skels.size() != 0);
  /* ... */
    assert( skels.size() );
    assert( proj.size() );
    assert( jpvt.size() );
  /* ... */
  for ( size_t i = 0; i < skels.size(); i ++ )
    /* ... */
    skels[ i ] = candidate_cols[ skels[ i ] ];
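/*
 * [Editor's note] With the symbols used above (m = KIJ.row(), n = KIJ.col();
 * N and q are defined on lines elided here), the tolerance handed to the
 * rank-revealing decomposition is
 *
 *     scaled_stol  = sqrt( n / q ) * sqrt( m / ( N - q ) ) * stol
 *     scaled_stol *= sqrt( q / N )
 *
 * i.e. the user tolerance stol is rescaled to reflect that only m sampled
 * rows and n candidate columns of the full problem enter the decomposition.
 */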
template<typename NODE, typename T>
/* ... */
    name = string( "SK" );
    label = to_string( arg->treelist_id );
    /* ... */
    double flops = 0.0, mops = 0.0;
    /* ... */
    auto &K = *arg->setup->K;
    size_t n = arg->data.proj.col();
    /* ... */
    size_t k = arg->data.proj.row();
    /* ... */
    flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
    mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
    /* ... */
    flops += k * ( k - 1 ) * ( n + 1 );
    mops += 2.0 * ( k * k + k * n );
    /* ... */
    event.Set( label + name, flops, mops );
    arg->data.skeletonize = event;
  /* ... */
  void DependencyAnalysis()
    /* ... */
    arg->DependencyAnalysis( RW, this );
  /* ... */
  void Execute( Worker* user_worker )
    /* ... */
    DistSkeletonize<NODE, T>( arg );

template<typename NODE, typename T>
/* ... */
    name = string( "PSK" );
    label = to_string( arg->treelist_id );
    /* ... */
    double flops = 0.0, mops = 0.0;
    /* ... */
    auto &K = *arg->setup->K;
    size_t n = arg->data.proj.col();
    /* ... */
    size_t k = arg->data.proj.row();
    /* ... */
    if ( arg->GetCommRank() == 0 )
      /* ... */
      flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
      mops += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );
      /* ... */
      flops += k * ( k - 1 ) * ( n + 1 );
      mops += 2.0 * ( k * k + k * n );
    /* ... */
    event.Set( label + name, flops, mops );
    arg->data.skeletonize = event;
  /* ... */
  void DependencyAnalysis()
    /* ... */
    arg->DependencyAnalysis( RW, this );
  /* ... */
    mpi::Comm comm = arg->GetComm();
    /* ... */
    double beg = omp_get_wtime();
    if ( arg->GetCommRank() == 0 )
      /* ... */
      DistSkeletonize<NODE, T>( arg );
    /* ... */
    double skel_t = omp_get_wtime() - beg;
    /* ... */
    int isskel = arg->data.isskel;
    mpi::Bcast( &isskel, 1, 0, comm );
    arg->data.isskel = isskel;
    /* ... */
    auto &skels = arg->data.skels;
    size_t nskels = skels.size();
    mpi::Bcast( &nskels, 1, 0, comm );
    if ( skels.size() != nskels ) skels.resize( nskels );
    mpi::Bcast( skels.data(), skels.size(), 0, comm );

template<typename NODE>
/* ... */
  void Set( NODE *user_arg )
    /* ... */
    name = string( "PROJ" );
    label = to_string( arg->treelist_id );
  /* ... */
  void DependencyAnalysis() { arg->DependOnNoOne( this ); };
  /* ... */
    auto comm = arg->GetComm();
    /* ... */
    if ( arg->GetCommRank() == 0 ) gofmm::Interpolate( arg );
    /* ... */
    auto &proj = arg->data.proj;
    size_t nrow = proj.row();
    size_t ncol = proj.col();
    mpi::Bcast( &nrow, 1, 0, comm );
    mpi::Bcast( &ncol, 1, 0, comm );
    if ( proj.row() != nrow || proj.col() != ncol ) proj.resize( nrow, ncol );
    mpi::Bcast( proj.data(), proj.size(), 0, comm );
template<bool NNPRUNE = true, typename TREE, typename T>
/* ... */
  int size; mpi::Comm_size( tree.GetComm(), &size );
  int rank; mpi::Comm_rank( tree.GetComm(), &rank );
  /* ... */
  using NODE = typename TREE::NODE;
  using MPINODE = typename TREE::MPINODE;
  /* ... */
  double beg, time_ratio, evaluation_time = 0.0;
  double direct_evaluation_time = 0.0, computeall_time, telescope_time, let_exchange_time, async_time;
  double overhead_time;
  double forward_permute_time, backward_permute_time;
  /* ... */
  tree.DependencyCleanUp();
  /* ... */
  size_t n = weights.row();
  size_t nrhs = weights.col();
  /* ... */
  auto &gids_owned = tree.treelist[ 0 ]->gids;
  /* ... */
  potentials.setvalue( 0.0 );
  /* ... */
  tree.setup.w = &weights;
  tree.setup.u = &potentials;
  /* ... */
  mpi::Barrier( tree.GetComm() );
  /* ... */
  potentials.setvalue( 0.0 );
  mpi::Barrier( tree.GetComm() );
  /* ... */
  beg = omp_get_wtime();
  tree.DependencyCleanUp();
  tree.DistTraverseDown( mpiVIEWtask );
  tree.LocaTraverseDown( seqVIEWtask );
  tree.ExecuteAllTasks();
  /* ... */
  AsyncExchangeLET<T>( tree, string( "leafweights" ) );
  /* ... */
  tree.LocaTraverseUp( seqN2Stask );
  tree.DistTraverseUp( mpiN2Stask );
  /* ... */
  AsyncExchangeLET<T>( tree, string( "skelweights" ) );
  /* ... */
  tree.LocaTraverseLeafs( seqL2LReducetask2 );
  /* ... */
  tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
  tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
  /* ... */
  tree.DistTraverseDown( mpiS2Ntask );
  tree.LocaTraverseDown( seqS2Ntask );
  overhead_time = omp_get_wtime() - beg;
  tree.ExecuteAllTasks();
  async_time = omp_get_wtime() - beg;
  /* ... */
  evaluation_time += direct_evaluation_time;
  evaluation_time += telescope_time;
  evaluation_time += let_exchange_time;
  evaluation_time += computeall_time;
  time_ratio = 100 / evaluation_time;
  /* ... */
  if ( rank == 0 && REPORT_EVALUATE_STATUS )
    /* ... */
    printf( "========================================================\n");
    printf( "GOFMM evaluation phase\n" );
    printf( "========================================================\n");
    /* ... */
    printf( "Upward telescope ---------------------- %5.2lfs (%5.1lf%%)\n",
        telescope_time, telescope_time * time_ratio );
    printf( "LET exchange -------------------------- %5.2lfs (%5.1lf%%)\n",
        let_exchange_time, let_exchange_time * time_ratio );
    printf( "L2L ----------------------------------- %5.2lfs (%5.1lf%%)\n",
        direct_evaluation_time, direct_evaluation_time * time_ratio );
    printf( "S2S, S2N ------------------------------ %5.2lfs (%5.1lf%%)\n",
        computeall_time, computeall_time * time_ratio );
    /* ... */
    printf( "========================================================\n");
    printf( "Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
        evaluation_time, evaluation_time * time_ratio );
    printf( "Evaluate (Async) ---------------------- %5.2lfs (%5.2lfs)\n",
        async_time, overhead_time );
    printf( "========================================================\n\n");
  /* ... */
  catch ( const exception & e )
    /* ... */
    cout << e.what() << endl;

template<bool NNPRUNE = true, typename TREE, typename T>
/* ... */
  size_t n = w_rblk.row();
  size_t nrhs = w_rblk.col();
  /* ... */
  auto u_rids = Evaluate<NNPRUNE>( tree, w_rids );
  mpi::Barrier( tree.GetComm() );

template<typename SPLITTER, typename T, typename SPDMATRIX>
/* ... */
    mpi::Comm CommGOFMM,
    /* ... */
  using NODE = typename TREE::NODE;
  /* ... */
  DistanceMetric metric = config.MetricType();
  size_t n = config.ProblemSize();
  size_t k = config.NeighborSize();
  /* ... */
  pair<T, size_t> init( numeric_limits<T>::max(), n );
  /* ... */
  TREE rkdt( CommGOFMM );
  rkdt.setup.FromConfiguration( config, K, splitter, NULL );
  return rkdt.AllNearestNeighbor( n_iter, n, k, init, NEIGHBORStask );
3872 template<
typename SPLITTER,
typename RKDTSPLITTER,
typename T,
typename SPDMATRIX>
3879 RKDTSPLITTER rkdtsplitter,
3887 int size; mpi::Comm_size( CommGOFMM, &size );
3888 int rank; mpi::Comm_rank( CommGOFMM, &rank );
3891 DistanceMetric metric = config.MetricType();
3892 size_t n = config.ProblemSize();
3893 size_t m = config.LeafNodeSize();
3894 size_t k = config.NeighborSize();
3895 size_t s = config.MaximumRank();
3898 const bool SYMMETRIC =
true;
3899 const bool NNPRUNE =
true;
3900 const bool CACHE =
true;
3907 using NODE =
typename TREE::NODE;
3908 using MPINODE =
typename TREE::MPINODE;
double beg, omptask45_time, omptask_time, ref_time;
double time_ratio, compress_time = 0.0, other_time = 0.0;
double ann_time, tree_time, skel_time, mpi_skel_time, mergefarnodes_time, cachefarnodes_time;
double local_skel_time, dist_skel_time, let_time;
double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
double exchange_neighbor_time, symmetrize_time;
beg = omp_get_wtime();
if ( k && NN_cblk.row() * NN_cblk.col() != k * n )
  NN_cblk = mpigofmm::FindNeighbors( K, rkdtsplitter, config, CommGOFMM );
ann_time = omp_get_wtime() - beg;
auto *tree_ptr = new TREE( CommGOFMM );
auto &tree = *tree_ptr;
tree.setup.FromConfiguration( config, K, splitter, &NN_cblk );
beg = omp_get_wtime();
tree.TreePartition();
tree_time = omp_get_wtime() - beg;
vector<size_t> perm = tree.GetPermutation();
ofstream perm_file( "perm.txt" );
for ( auto &id : perm ) perm_file << id << " ";
tree.setup.NN = &NN;
beg = omp_get_wtime();
ExchangeNeighbors<T>( tree );
exchange_neighbor_time = omp_get_wtime() - beg;
beg = omp_get_wtime();
FindNearInteractions( tree );
mpigofmm::SymmetrizeNearInteractions( tree );
BuildInteractionListPerRank( tree, true );
ExchangeLET( tree, string( "leafgids" ) );
symmetrize_time = omp_get_wtime() - beg;
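/*
 * The four calls above assemble the near-interaction part of the local
 * essential tree (LET): near interactions are found from the neighbor lists,
 * symmetrized across ranks, gathered into per-rank interaction lists, and the
 * corresponding leaf gids are exchanged so that remote leaves can be accessed
 * locally. (This is a reading of the call names, not an authoritative
 * description.)
 */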
mpi::PrintProgress( "[BEG] MergeFarNodes ...", tree.GetComm() );
beg = omp_get_wtime();
tree.DependencyCleanUp();
tree.LocaTraverseUp( seqMERGEtask );
tree.DistTraverseUp( mpiMERGEtask );
tree.ExecuteAllTasks();
mergefarnodes_time += omp_get_wtime() - beg;
mpi::PrintProgress( "[END] MergeFarNodes ...", tree.GetComm() );
beg = omp_get_wtime();
mpigofmm::SymmetrizeFarInteractions( tree );
BuildInteractionListPerRank( tree, false );
symmetrize_time += omp_get_wtime() - beg;
mpi::PrintProgress( "[BEG] Skeletonization ...", tree.GetComm() );
beg = omp_get_wtime();
tree.DependencyCleanUp();
tree.LocaTraverseUp( seqGETMTXtask, seqSKELtask );
tree.LocaTraverseUnOrdered( seqPROJtask );
tree.ExecuteAllTasks();
skel_time = omp_get_wtime() - beg;
beg = omp_get_wtime();
tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
tree.DistTraverseUnOrdered( mpiPROJtask );
tree.ExecuteAllTasks();
mpi_skel_time = omp_get_wtime() - beg;
mpi::PrintProgress( "[END] Skeletonization ...", tree.GetComm() );
ExchangeLET( tree, string( "skelgids" ) );
beg = omp_get_wtime();
cachefarnodes_time = omp_get_wtime() - beg;
tree.ExecuteAllTasks();
cachefarnodes_time = omp_get_wtime() - beg;
auto ratio = NonCompressedRatio( tree );
double exact_ratio = (double) m / n;
if ( rank == 0 && REPORT_COMPRESS_STATUS )
{
  compress_time += ann_time;
  compress_time += tree_time;
  compress_time += exchange_neighbor_time;
  compress_time += symmetrize_time;
  compress_time += skel_time;
  compress_time += mpi_skel_time;
  compress_time += mergefarnodes_time;
  compress_time += cachefarnodes_time;
  time_ratio = 100.0 / compress_time;
  printf( "========================================================\n" );
  printf( "GOFMM compression phase\n" );
  printf( "========================================================\n" );
  printf( "NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
  printf( "TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
  printf( "ExchangeNeighbors --------------------- %5.2lfs (%5.1lf%%)\n", exchange_neighbor_time, exchange_neighbor_time * time_ratio );
  printf( "MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
  printf( "Symmetrize ---------------------------- %5.2lfs (%5.1lf%%)\n", symmetrize_time, symmetrize_time * time_ratio );
  printf( "Skeletonization (HMLP Runtime) -------- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
  printf( "Skeletonization (MPI) ----------------- %5.2lfs (%5.1lf%%)\n", mpi_skel_time, mpi_skel_time * time_ratio );
  printf( "Cache KIJ ----------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
  printf( "========================================================\n" );
  printf( "%5.3lf%% and %5.3lf%% uncompressed ------- %5.2lfs (%5.1lf%%)\n",
      100 * ratio.first, 100 * ratio.second, compress_time, compress_time * time_ratio );
  printf( "========================================================\n\n" );
}
tree_ptr->DependencyCleanUp();
mpi::Barrier( tree.GetComm() );
catch ( const exception & e )
{
  cout << e.what() << endl;
}
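/*
 * Usage sketch (illustrative; it simply mirrors the driver at the end of this
 * listing): compress an SPD matrix and reuse the returned tree for evaluation.
 * Deleting tree_ptr is an assumption about ownership of the heap-allocated
 * tree, not something this listing states.
 *
 *   auto *tree_ptr = mpigofmm::Compress( K, NN, splitter, rkdtsplitter, config, CommGOFMM );
 *   auto &tree = *tree_ptr;
 *   auto u_rids = mpigofmm::Evaluate<true>( tree, w_rids );
 *   delete tree_ptr;
 */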
template<typename TREE, typename T>
pair<T, T> ComputeError( TREE &tree, size_t gid, Data<T> potentials )
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );
pair<T, T> ret( 0, 0 );
auto &K = *tree.setup.K;
auto &w = *tree.setup.w;
auto I = vector<size_t>( 1, gid );
auto &J = tree.treelist[ 0 ]->gids;
K.BcastIndices( I, gid % comm_size, tree.GetComm() );
/* Filled in from context (elided in this listing): the exact row K( I, J ). */
auto Kab = K( I, J );
auto loc_exact = potentials;
auto glb_exact = potentials;
/* loc_exact = Kab * w; the B-operand arguments are filled in from context. */
xgemm( "N", "N", Kab.row(), w.col(), w.row(),
    1.0, Kab.data(), Kab.row(),
         w.data(), w.row(),
    0.0, loc_exact.data(), loc_exact.row() );
mpi::Allreduce( loc_exact.data(), glb_exact.data(),
    loc_exact.size(), MPI_SUM, tree.GetComm() );
for ( uint64_t j = 0; j < w.col(); j ++ )
{
  T exac = glb_exact[ j ];
  T pred = potentials[ j ];
  ret.first += ( pred - exac ) * ( pred - exac );
  ret.second += exac * exac;
}
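/*
 * The returned pair is (sum of squared errors, sum of squared exact values);
 * callers form a relative error from it, as SelfTesting() does below:
 *
 *   auto sse_ssv = mpigofmm::ComputeError( tree, gid, potentials );
 *   auto relerr  = std::sqrt( sse_ssv.first / sse_ssv.second );
 */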
template<typename TREE>
void SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
using T = typename TREE::T;
int rank; mpi::Comm_rank( tree.GetComm(), &rank );
int size; mpi::Comm_size( tree.GetComm(), &size );
if ( ntest > n ) ntest = n;
vector<size_t> all_rhs( nrhs );
for ( size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;
auto u_rids = mpigofmm::Evaluate<true>( tree, w_rids );
assert( !u_rids.HasIllegalValue() );
printf( "========================================================\n" );
printf( "Accuracy report\n" );
printf( "========================================================\n" );
T nnerr_avg = 0.0, nonnerr_avg = 0.0, fmmerr_avg = 0.0;
T sse_2norm = 0.0, ssv_2norm = 0.0;
for ( size_t i = 0; i < ntest; i ++ )
{
  size_t tar = i * n / ntest;
  Data<T> potentials( (size_t)1, nrhs );
  if ( rank == ( tar % size ) ) potentials = u_rblk( vector<size_t>( 1, tar ), all_rhs );
  mpi::Bcast( potentials.data(), nrhs, tar % size, tree.GetComm() );
  auto sse_ssv = mpigofmm::ComputeError( tree, tar, potentials );
  auto fmmerr = sqrt( sse_ssv.first / sse_ssv.second );
  fmmerr_avg += fmmerr;
  sse_2norm += sse_ssv.first;
  ssv_2norm += sse_ssv.second;
  if ( i < 10 && rank == 0 )
    printf( "gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
        tar, 0.0, 0.0, fmmerr );
}
printf( "========================================================\n" );
printf( "Elementwise ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
    nnerr_avg / ntest, nonnerr_avg / ntest, fmmerr_avg / ntest );
printf( "F-norm ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
    0.0, 0.0, sqrt( sse_2norm / ssv_2norm ) );
printf( "========================================================\n" );
mpigofmm::DistFactorize( tree, lambda );
mpigofmm::ComputeError( tree, lambda, w_rids, u_rids );
template<typename SPDMATRIX>
using T = typename SPDMATRIX::T;
const int N_CHILDREN = 2;
SPLITTER splitter( K );
splitter.metric = cmd.metric;
RKDTSPLITTER rkdtsplitter( K );
rkdtsplitter.Kptr = &K;
rkdtsplitter.metric = cmd.metric;
/* Filled in from context (the opening of this statement is elided in the listing). */
gofmm::Configuration<T> config( cmd.metric,
    cmd.n, cmd.m, cmd.k, cmd.s, cmd.stol, cmd.budget );
auto *tree_ptr = mpigofmm::Compress( K, NN, splitter, rkdtsplitter, config, CommGOFMM );
auto &tree = *tree_ptr;
mpigofmm::SelfTesting( tree, 100, cmd.nrhs );