/**
 *  HMLP (High-Performance Machine Learning Primitives)
 *
 *  Copyright (C) 2014-2017, The University of Texas at Austin
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see the LICENSE file.
 *
 **/


#ifndef GOFMM_MPI_HPP
#define GOFMM_MPI_HPP

/** Inherit most of the classes from shared-memory GOFMM. */
#include <gofmm.hpp>
/** Use distributed metric trees. */
#include <tree_mpi.hpp>
#include <igofmm_mpi.hpp>
/** Use distributed matrices inspired by the Elemental notation. */
//#include <DistData.hpp>
/** Use STL and HMLP namespaces. */
using namespace std;
using namespace hmlp;


namespace hmlp
{
namespace mpigofmm
{

///**
// *  @brief This class does not have to inherit DistData, but it has to
// *         inherit DistVirtualMatrix<T>.
// */
//template<typename T>
//class DistSPDMatrix : public DistData<STAR, CBLK, T>
//{
//  public:
//
//    DistSPDMatrix( size_t m, size_t n, mpi::Comm comm ) :
//      DistData<STAR, CBLK, T>( m, n, comm )
//    {
//    };
//
//    /** ESSENTIAL: this is an abstract function */
//    virtual T operator()( size_t i, size_t j, mpi::Comm comm )
//    {
//      T Kij = 0;
//
//      /** MPI */
//      int size, rank;
//      hmlp::mpi::Comm_size( comm, &size );
//      hmlp::mpi::Comm_rank( comm, &rank );
//
//      std::vector<std::vector<size_t>> sendrids( size );
//      std::vector<std::vector<size_t>> recvrids( size );
//      std::vector<std::vector<size_t>> sendcids( size );
//      std::vector<std::vector<size_t>> recvcids( size );
//
//      /** request Kij from rank ( j % size ) */
//      sendrids[ i % size ].push_back( i );
//      sendcids[ j % size ].push_back( j );
//
//      /** exchange ids */
//      mpi::AlltoallVector( sendrids, recvrids, comm );
//      mpi::AlltoallVector( sendcids, recvcids, comm );
//
//      /** allocate buffer for data */
//      std::vector<std::vector<T>> senddata( size );
//      std::vector<std::vector<T>> recvdata( size );
//
//      /** fetch subrows */
//      for ( size_t p = 0; p < size; p ++ )
//      {
//        assert( recvrids[ p ].size() == recvcids[ p ].size() );
//        for ( size_t j = 0; j < recvcids[ p ].size(); j ++ )
//        {
//          size_t rid = recvrids[ p ][ j ];
//          size_t cid = recvcids[ p ][ j ];
//          senddata[ p ].push_back( (*this)( rid, cid ) );
//        }
//      }
//
//      /** exchange data */
//      mpi::AlltoallVector( senddata, recvdata, comm );
//
//      for ( size_t p = 0; p < size; p ++ )
//      {
//        assert( recvdata[ p ].size() <= 1 );
//        if ( recvdata[ p ].size() ) Kij = recvdata[ p ][ 0 ];
//      }
//
//      return Kij;
//    };
//
//    /** ESSENTIAL: return a submatrix */
//    virtual hmlp::Data<T> operator()
//    ( std::vector<size_t> &imap, std::vector<size_t> &jmap, hmlp::mpi::Comm comm )
//    {
//      hmlp::Data<T> KIJ( imap.size(), jmap.size() );
//
//      /** MPI */
//      int size, rank;
//      hmlp::mpi::Comm_size( comm, &size );
//      hmlp::mpi::Comm_rank( comm, &rank );
//
//      std::vector<std::vector<size_t>> jmapcids( size );
//
//      std::vector<std::vector<size_t>> sendrids( size );
//      std::vector<std::vector<size_t>> recvrids( size );
//      std::vector<std::vector<size_t>> sendcids( size );
//      std::vector<std::vector<size_t>> recvcids( size );
//
//      /** request KIJ from rank ( j % size ) */
//      for ( size_t j = 0; j < jmap.size(); j ++ )
//      {
//        size_t cid = jmap[ j ];
//        sendcids[ cid % size ].push_back( cid );
//        jmapcids[ cid % size ].push_back(   j );
//      }
//
//      for ( size_t p = 0; p < size; p ++ )
//      {
//        if ( sendcids[ p ].size() ) sendrids[ p ] = imap;
//      }
//
//      /** exchange ids */
//      mpi::AlltoallVector( sendrids, recvrids, comm );
//      mpi::AlltoallVector( sendcids, recvcids, comm );
//
//      /** allocate buffer for data */
//      std::vector<hmlp::Data<T>> senddata( size );
//      std::vector<hmlp::Data<T>> recvdata( size );
//
//      /** fetch submatrix */
//      for ( size_t p = 0; p < size; p ++ )
//      {
//        if ( recvcids[ p ].size() && recvrids[ p ].size() )
//        {
//          senddata[ p ] = (*this)( recvrids[ p ], recvcids[ p ] );
//        }
//      }
//
//      /** exchange data */
//      mpi::AlltoallVector( senddata, recvdata, comm );
//
//      /** merging data */
//      for ( size_t p = 0; p < size; p ++ )
//      {
//        assert( recvdata[ p ].size() == imap.size() * recvcids[ p ].size() );
//        recvdata[ p ].resize( imap.size(), recvcids[ p ].size() );
//        for ( size_t j = 0; j < recvcids[ p ].size(); j ++ )
//        {
//          for ( size_t i = 0; i < imap.size(); i ++ )
//          {
//            KIJ( i, jmapcids[ p ][ j ] ) = recvdata[ p ]( i, j );
//          }
//        }
//      }
//
//      return KIJ;
//    };
//
//    virtual hmlp::Data<T> operator()
//    ( std::vector<int> &imap, std::vector<int> &jmap, hmlp::mpi::Comm comm )
//    {
//      printf( "operator() not implemented yet\n" );
//      exit( 1 );
//    };
//
//    /** overload operator */
//
//  private:
//
//}; /** end class DistSPDMatrix */


/**
 *  @brief These are data shared by the whole local tree.
 *         Distributed setup inherits mpitree::Setup.
 */
template<typename SPDMATRIX, typename SPLITTER, typename T>
class Setup : public mpitree::Setup<SPLITTER, T>,
              public gofmm::Configuration<T>
{
  public:

    /** Shallow copy from the config. */
    void FromConfiguration( gofmm::Configuration<T> &config,
        SPDMATRIX &K, SPLITTER &splitter,
        DistData<STAR, CBLK, pair<T, size_t>>* NN_cblk )
    {
      this->CopyFrom( config );
      this->K = &K;
      this->splitter = splitter;
      this->NN_cblk = NN_cblk;
    };

    /** The SPDMATRIX (accessed with gids: dense, CSC, or OOC) */
    SPDMATRIX *K = NULL;

    /** rhs-by-n, all weights and potentials. */
    Data<T> *w = NULL;
    Data<T> *u = NULL;

    /** buffer space, either dimension needs to be n */
    Data<T> *input = NULL;
    Data<T> *output = NULL;

    /** regularization */
    T lambda = 0.0;

    /** whether the matrix is symmetric */
    //bool issymmetric = true;

    /** use ULV or Sherman-Morrison-Woodbury */
    bool do_ulv_factorization = true;

  private:

}; /** end class Setup */
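
/**
 *  Editorial usage sketch (not part of the original API surface): a Setup is
 *  typically filled from an existing gofmm::Configuration<T>.  The names
 *  config, K, splitter, NN_cblk, weights, and potentials below are assumed to
 *  be constructed elsewhere; only the members used here appear in this file.
 *
 *    mpigofmm::Setup<SPDMATRIX, SPLITTER, T> setup;
 *    setup.FromConfiguration( config, K, splitter, &NN_cblk );
 *    setup.w = &weights;     // rhs-by-n weights
 *    setup.u = &potentials;  // rhs-by-n potentials
 */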


/**
 *  @brief This task creates a hierarchical tree view for
 *         weights<RIDS> and potentials<RIDS>.
 */
template<typename NODE>
class DistTreeViewTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "TreeView" );
      label = to_string( arg->treelist_id );
      cost = 1.0;
    };

    /** Preorder dependencies (with a single source node) */
    void DependencyAnalysis() { arg->DependOnParent( this ); };

    void Execute( Worker* user_worker )
    {
      auto *node = arg;

      /** w and u can be Data<T> or DistData<RIDS,STAR,T> */
      auto &w = *(node->setup->w);
      auto &u = *(node->setup->u);

      /** Get the matrix view of this tree node. */
      auto &U = node->data.u_view;
      auto &W = node->data.w_view;

      /** Both w and u are column-majored, thus nontranspose. */
      U.Set( u );
      W.Set( w );

      /** Create sub matrix views for local nodes. */
      if ( !node->isleaf && !node->child )
      {
        assert( node->lchild && node->rchild );
        auto &UL = node->lchild->data.u_view;
        auto &UR = node->rchild->data.u_view;
        auto &WL = node->lchild->data.w_view;
        auto &WR = node->rchild->data.w_view;
        /**
         *  U = [ UL;   W = [ WL;
         *        UR; ]       WR; ]
         */
        U.Partition2x1( UL,
                        UR, node->lchild->n, TOP );
        W.Partition2x1( WL,
                        WR, node->lchild->n, TOP );
      }
    };

}; /** end class DistTreeViewTask */


/** @brief Split values into two halves according to the median. */
template<typename T>
vector<vector<size_t>> DistMedianSplit( vector<T> &values, mpi::Comm comm )
{
  int n = 0;
  int num_points_owned = values.size();
  /** n = sum( num_points_owned ) over all MPI processes in comm */
  mpi::Allreduce( &num_points_owned, &n, 1, MPI_SUM, comm );
  T median = combinatorics::Select( n / 2, values, comm );

  vector<vector<size_t>> split( 2 );
  vector<size_t> middle;

  if ( n == 0 ) return split;

  for ( size_t i = 0; i < values.size(); i ++ )
  {
    auto v = values[ i ];
    if ( std::fabs( v - median ) < 1E-6 ) middle.push_back( i );
    else if ( v < median ) split[ 0 ].push_back( i );
    else split[ 1 ].push_back( i );
  }

  int nmid = 0;
  int nlhs = 0;
  int nrhs = 0;
  int num_mid_owned = middle.size();
  int num_lhs_owned = split[ 0 ].size();
  int num_rhs_owned = split[ 1 ].size();

  /** nmid = sum( num_mid_owned ) over all MPI processes in comm. */
  mpi::Allreduce( &num_mid_owned, &nmid, 1, MPI_SUM, comm );
  mpi::Allreduce( &num_lhs_owned, &nlhs, 1, MPI_SUM, comm );
  mpi::Allreduce( &num_rhs_owned, &nrhs, 1, MPI_SUM, comm );

  /** Assign points in the middle to left or right. */
  if ( nmid )
  {
    int nlhs_required, nrhs_required;

    if ( nlhs > nrhs )
    {
      nlhs_required = ( n - 1 ) / 2 + 1 - nlhs;
      nrhs_required = nmid - nlhs_required;
    }
    else
    {
      nrhs_required = ( n - 1 ) / 2 + 1 - nrhs;
      nlhs_required = nmid - nrhs_required;
    }

    assert( nlhs_required >= 0 && nrhs_required >= 0 );

    /** Now decide the portion each side takes from the locally owned middle points. */
    double lhs_ratio = ( (double)nlhs_required ) / nmid;
    int nlhs_required_owned = num_mid_owned * lhs_ratio;
    int nrhs_required_owned = num_mid_owned - nlhs_required_owned;

    //printf( "rank %d [ %d %d ] [ %d %d ]\n",
    //    global_rank,
    //    nlhs_required_owned, nlhs_required,
    //    nrhs_required_owned, nrhs_required ); fflush( stdout );

    assert( nlhs_required_owned >= 0 && nrhs_required_owned >= 0 );

    for ( size_t i = 0; i < middle.size(); i ++ )
    {
      if ( i < nlhs_required_owned )
        split[ 0 ].push_back( middle[ i ] );
      else
        split[ 1 ].push_back( middle[ i ] );
    }
  }

  return split;
}; /** end DistMedianSplit() */
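
/**
 *  Editorial usage sketch (assuming a communicator comm and locally owned
 *  projection values): each rank passes its own values; the function returns
 *  local indices for the two halves of a global median split, with ties
 *  rebalanced so the two sides have roughly equal global size.
 *
 *    vector<double> values = ...;  // locally owned projection values
 *    auto split = DistMedianSplit( values, comm );
 *    // split[ 0 ]: local indices assigned to the left  half
 *    // split[ 1 ]: local indices assigned to the right half
 */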


/**
 *  @brief This is the main splitter used to build the Spd-Askit tree.
 *         First compute the approximate center using subsamples.
 *         Then find the two farthest points to do the projection.
 */
template<typename SPDMATRIX, int N_SPLIT, typename T>
struct centersplit : public gofmm::centersplit<SPDMATRIX, N_SPLIT, T>
{

  centersplit() : gofmm::centersplit<SPDMATRIX, N_SPLIT, T>() {};

  centersplit( SPDMATRIX& K ) : gofmm::centersplit<SPDMATRIX, N_SPLIT, T>( K ) {};

  /** Shared-memory operator. */
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
  {
    return gofmm::centersplit<SPDMATRIX, N_SPLIT, T>::operator() ( gids );
  };

  /** Distributed operator. */
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
  {
    /** All assertions */
    assert( N_SPLIT == 2 );
    assert( this->Kptr );

    /** MPI Support. */
    int size; mpi::Comm_size( comm, &size );
    int rank; mpi::Comm_rank( comm, &rank );
    auto &K = *(this->Kptr);

    /** Temporary buffer for accumulated and relative distances. */
    vector<T> temp( gids.size(), 0.0 );

    /** Collecting column samples of K. */
    auto column_samples = combinatorics::SampleWithoutReplacement(
        this->n_centroid_samples, gids );

    /** Bcast column_samples from rank 0. */
    mpi::Bcast( column_samples.data(), column_samples.size(), 0, comm );
    K.BcastIndices( column_samples, 0, comm );

    /** Compute all pairwise distances. */
    auto DIC = K.Distances( this->metric, gids, column_samples );

    /** Zero out the temporary buffer. */
    for ( auto & it : temp ) it = 0;

    /** Accumulate distances to the temporary buffer. */
    for ( size_t j = 0; j < DIC.col(); j ++ )
      for ( size_t i = 0; i < DIC.row(); i ++ )
        temp[ i ] += DIC( i, j );

    /** Find f2c (farthest from the center) among owned points. */
    auto idf2c = distance( temp.begin(), max_element( temp.begin(), temp.end() ) );

    /** Create a pair for MPI Allreduce */
    mpi::NumberIntPair<T> local_max_pair, max_pair;
    local_max_pair.val = temp[ idf2c ];
    local_max_pair.key = rank;

    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );

    /** Broadcast gidf2c from the MPI process which has the max_pair. */
    int gidf2c = gids[ idf2c ];
    mpi::Bcast( &gidf2c, 1, MPI_INT, max_pair.key, comm );

    //printf( "rank %d val %E key %d; global val %E key %d\n",
    //    rank, local_max_pair.val, local_max_pair.key,
    //    max_pair.val, max_pair.key ); fflush( stdout );
    //printf( "rank %d gidf2c %d\n", rank, gidf2c ); fflush( stdout );

    /** Collecting KIP and kpp */
    vector<size_t> P( 1, gidf2c );
    K.BcastIndices( P, max_pair.key, comm );

    /** Compute all pairwise distances. */
    auto DIP = K.Distances( this->metric, gids, P );

    /** Find f2f (farthest from the farthest point) among owned points. */
    auto idf2f = distance( DIP.begin(), max_element( DIP.begin(), DIP.end() ) );

    /** Create a pair for MPI Allreduce */
    local_max_pair.val = DIP[ idf2f ];
    local_max_pair.key = rank;

    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );

    /** Broadcast gidf2f from the MPI process which has the max_pair. */
    int gidf2f = gids[ idf2f ];
    mpi::Bcast( &gidf2f, 1, MPI_INT, max_pair.key, comm );

    //printf( "rank %d val %E key %d; global val %E key %d\n",
    //    rank, local_max_pair.val, local_max_pair.key,
    //    max_pair.val, max_pair.key ); fflush( stdout );
    //printf( "rank %d gidf2f %d\n", rank, gidf2f ); fflush( stdout );

    /** Collecting KIQ and kqq */
    vector<size_t> Q( 1, gidf2f );
    K.BcastIndices( Q, max_pair.key, comm );

    /** Compute all pairwise distances (against Q, not P). */
    auto DIQ = K.Distances( this->metric, gids, Q );

    /** We use relative distances (dip - diq) for clustering. */
    for ( size_t i = 0; i < temp.size(); i ++ )
      temp[ i ] = DIP[ i ] - DIQ[ i ];

    /** Split gids into two clusters using median split. */
    auto split = DistMedianSplit( temp, comm );

    /** Perform P2P redistribution. */
    mpi::Status status;
    vector<size_t> sent_gids;
    int partner = ( rank + size / 2 ) % size;
    if ( rank < size / 2 )
    {
      for ( auto it : split[ 1 ] )
        sent_gids.push_back( gids[ it ] );
      K.SendIndices( sent_gids, partner, comm );
      K.RecvIndices( partner, comm, &status );
    }
    else
    {
      for ( auto it : split[ 0 ] )
        sent_gids.push_back( gids[ it ] );
      K.RecvIndices( partner, comm, &status );
      K.SendIndices( sent_gids, partner, comm );
    }

    return split;
  };

}; /** end struct centersplit */
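
/**
 *  Editorial usage sketch: the distributed operator is invoked with the
 *  locally owned gids and the communicator of the tree node being split
 *  (called node_comm here only for illustration),
 *
 *    auto split = splitter( gids, node_comm );
 *
 *  while the shared-memory operator (gids only) takes over once the
 *  recursion reaches a single rank.
 */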


template<typename SPDMATRIX, int N_SPLIT, typename T>
struct randomsplit : public gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>
{

  randomsplit() : gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>() {};

  randomsplit( SPDMATRIX& K ) : gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>( K ) {};

  /** Shared-memory operator. */
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids ) const
  {
    return gofmm::randomsplit<SPDMATRIX, N_SPLIT, T>::operator() ( gids );
  };

  /** Distributed operator. */
  inline vector<vector<size_t> > operator() ( vector<size_t>& gids, mpi::Comm comm ) const
  {
    /** All assertions */
    assert( N_SPLIT == 2 );
    assert( this->Kptr );

    /** Declaration */
    int size, rank, global_rank, global_size;
    mpi::Comm_size( comm, &size );
    mpi::Comm_rank( comm, &rank );
    mpi::Comm_rank( MPI_COMM_WORLD, &global_rank );
    mpi::Comm_size( MPI_COMM_WORLD, &global_size );
    SPDMATRIX &K = *(this->Kptr);
    //vector<vector<size_t>> split( N_SPLIT );

    if ( size == global_size )
    {
      for ( size_t i = 0; i < gids.size(); i ++ )
        assert( gids[ i ] == i * size + rank );
    }

    /** Reduce to get the total size of gids. */
    int n = 0;
    int num_points_owned = gids.size();
    vector<T> temp( gids.size(), 0.0 );

    /** n = sum( num_points_owned ) over all MPI processes in comm */
    mpi::Allreduce( &num_points_owned, &n, 1, MPI_INT, MPI_SUM, comm );

    /** Early return */
    //if ( n == 0 ) return split;

    /** Randomly select two points p and q. */
    size_t gidf2c = 0, gidf2f = 0;
    if ( gids.size() )
    {
      gidf2c = gids[ std::rand() % gids.size() ];
      gidf2f = gids[ std::rand() % gids.size() ];
    }

    /** Create a pair <gids.size(), rank> for MPI Allreduce */
    mpi::NumberIntPair<T> local_max_pair, max_pair;
    local_max_pair.val = gids.size();
    local_max_pair.key = rank;

    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );

    /** Bcast gidf2c from the rank that has the most gids */
    mpi::Bcast( &gidf2c, 1, max_pair.key, comm );
    vector<size_t> P( 1, gidf2c );
    K.BcastIndices( P, max_pair.key, comm );

    /** Choose the second MPI rank */
    if ( rank == max_pair.key ) local_max_pair.val = 0;

    /** max_pair = max( local_max_pairs ) over all MPI processes in comm */
    mpi::Allreduce( &local_max_pair, &max_pair, 1, MPI_MAXLOC, comm );

    /** Bcast gidf2f from the rank that has the most gids */
    mpi::Bcast( &gidf2f, 1, max_pair.key, comm );
    vector<size_t> Q( 1, gidf2f );
    K.BcastIndices( Q, max_pair.key, comm );

    auto DIP = K.Distances( this->metric, gids, P );
    auto DIQ = K.Distances( this->metric, gids, Q );

    /** We use relative distances (dip - diq) for clustering. */
    for ( size_t i = 0; i < temp.size(); i ++ )
      temp[ i ] = DIP[ i ] - DIQ[ i ];

    /** Split gids into two clusters using median split. */
    auto split = DistMedianSplit( temp, comm );

    /** Perform P2P redistribution. */
    mpi::Status status;
    vector<size_t> sent_gids;
    int partner = ( rank + size / 2 ) % size;
    if ( rank < size / 2 )
    {
      for ( auto it : split[ 1 ] )
        sent_gids.push_back( gids[ it ] );
      K.SendIndices( sent_gids, partner, comm );
      K.RecvIndices( partner, comm, &status );
    }
    else
    {
      for ( auto it : split[ 0 ] )
        sent_gids.push_back( gids[ it ] );
      K.RecvIndices( partner, comm, &status );
      K.SendIndices( sent_gids, partner, comm );
    }

    return split;
  };

}; /** end struct randomsplit */


/**
 *  @brief Compute skeleton weights (distributed N2S).
 */
template<typename NODE>
void DistUpdateWeights( NODE *node )
{
  /** Derive type T from NODE. */
  using T = typename NODE::T;
  /** MPI Support. */
  mpi::Status status;
  auto comm = node->GetComm();
  int size = node->GetCommSize();
  int rank = node->GetCommRank();

  /** Early return if this is the root or there is no skeleton. */
  if ( !node->parent || !node->data.isskel ) return;

  if ( size < 2 )
  {
    /** This is the root of the local tree. */
    gofmm::UpdateWeights( node );
  }
  else
  {
    /** Gather shared data and create reference. */
    auto &w = *node->setup->w;
    size_t nrhs = w.col();

    /** Gather per-node data and create reference. */
    auto &data = node->data;
    auto &proj = data.proj;
    auto &w_skel = data.w_skel;

    /** The rank that holds the skeleton weight of the left child. */
    if ( rank == 0 )
    {
      size_t s  = proj.row();
      size_t sl = node->child->data.skels.size();
      size_t sr = proj.col() - sl;
      /** w_skel is s-by-nrhs, initial values are not important. */
      w_skel.resize( s, nrhs );
      /** Create matrix views. */
      View<T> P( false, proj ), PL, PR;
      View<T> W( false, w_skel ), WL( false, node->child->data.w_skel );
      /** P = [ PL, PR ] */
      P.Partition1x2( PL, PR, sl, LEFT );
      /** W = PL * WL */
      gemm::xgemm<GEMM_NB>( (T)1.0, PL, WL, (T)0.0, W );

      Data<T> w_skel_sib;
      mpi::ExchangeVector( w_skel, size / 2, 0, w_skel_sib, size / 2, 0, comm, &status );
      /** Reduce from my sibling. */
      #pragma omp parallel for
      for ( size_t i = 0; i < w_skel.size(); i ++ )
        w_skel[ i ] += w_skel_sib[ i ];
    }

    /** The rank that holds the skeleton weight of the right child. */
    if ( rank == size / 2 )
    {
      size_t s  = proj.row();
      size_t sr = node->child->data.skels.size();
      size_t sl = proj.col() - sr;
      /** w_skel is s-by-nrhs, initial values are not important. */
      w_skel.resize( s, nrhs );
      /** Create matrix views. */
      View<T> P( false, proj ), PL, PR;
      View<T> W( false, w_skel ), WR( false, node->child->data.w_skel );
      /** P = [ PL, PR ] */
      P.Partition1x2( PL, PR, sl, LEFT );
      /** W = PR * WR */
      gemm::xgemm<GEMM_NB>( (T)1.0, PR, WR, (T)0.0, W );

      Data<T> w_skel_sib;
      mpi::ExchangeVector( w_skel, 0, 0, w_skel_sib, 0, 0, comm, &status );
      w_skel.clear();
    }
  }
}; /** end DistUpdateWeights() */
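
/**
 *  Editorial summary of the distributed N2S step above: with the skeleton
 *  basis proj = [ PL, PR ] partitioned at the left child's skeleton size sl,
 *  the node's skeleton weights are
 *
 *    w_skel = PL * w_skel( left child ) + PR * w_skel( right child ),
 *
 *  where rank 0 forms the PL term, rank size/2 forms the PR term, and the
 *  two partial products are summed after mpi::ExchangeVector().
 */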


/**
 *  @brief Notice that NODE here is MPITree::Node.
 */
template<typename NODE, typename T>
class DistUpdateWeightsTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "DistN2S" );
      label = to_string( arg->treelist_id );

      /** Compute FLOPS and MOPS */
      double flops = 0.0, mops = 0.0;
      auto &gids = arg->gids;
      auto &skels = arg->data.skels;
      auto &w = *arg->setup->w;

      if ( !arg->child )
      {
        if ( arg->isleaf )
        {
          auto m = skels.size();
          auto n = w.col();
          auto k = gids.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
        else
        {
          auto &lskels = arg->lchild->data.skels;
          auto &rskels = arg->rchild->data.skels;
          auto m = skels.size();
          auto n = w.col();
          auto k = lskels.size() + rskels.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
      }
      else
      {
        if ( arg->GetCommRank() == 0 )
        {
          auto &lskels = arg->child->data.skels;
          auto m = skels.size();
          auto n = w.col();
          auto k = lskels.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
        if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
        {
          auto &rskels = arg->child->data.skels;
          auto m = skels.size();
          auto n = w.col();
          auto k = rskels.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
      }

      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** Assume computation bound */
      cost = flops / 1E+9;
      /** "HIGH" priority (critical path) */
      priority = true;
    };

    void DependencyAnalysis() { arg->DependOnChildren( this ); };

    void Execute( Worker* user_worker ) { DistUpdateWeights( arg ); };

}; /** end class DistUpdateWeightsTask */


/**
 *
 */
//template<bool NNPRUNE, typename NODE, typename T>
//class DistSkeletonsToSkeletonsTask : public Task
//{
//  public:
//
//    NODE *arg = NULL;
//
//    void Set( NODE *user_arg )
//    {
//      arg = user_arg;
//      name = string( "DistS2S" );
//      label = to_string( arg->treelist_id );
//      /** compute flops and mops */
//      double flops = 0.0, mops = 0.0;
//      auto &w = *arg->setup->w;
//      size_t m = arg->data.skels.size();
//      size_t n = w.col();
//
//      auto *FarNodes = &arg->FarNodes;
//      if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
//
//      for ( auto it = FarNodes->begin(); it != FarNodes->end(); it ++ )
//      {
//        size_t k = (*it)->data.skels.size();
//        flops += 2.0 * m * n * k;
//        mops  += m * k; // cost of Kab
//        mops  += 2.0 * ( m * n + n * k + k * n );
//      }
//
//      /** setup the event */
//      event.Set( label + name, flops, mops );
//
//      /** assume computation bound */
//      cost = flops / 1E+9;
//
//      /** "LOW" priority */
//      priority = false;
//    };
//
//    void DependencyAnalysis()
//    {
//      for ( auto p : arg->data.FarDependents )
//        hmlp_msg_dependency_analysis( 306, p, R, this );
//
//      auto *FarNodes = &arg->FarNodes;
//      if ( NNPRUNE ) FarNodes = &arg->NNFarNodes;
//      for ( auto it : *FarNodes ) it->DependencyAnalysis( R, this );
//
//      arg->DependencyAnalysis( RW, this );
//      this->TryEnqueue();
//    };
//
//    /**
//     *  @brief Notice that S2S depends on all Far interactions, which
//     *         may include local tree nodes or LET nodes.
//     *         For the HSS case, the only Far interaction is the sibling.
//     *         The skeleton weight of the sibling will always be exchanged
//     *         by default in N2S. Thus, currently we do not need
//     *         a distributed S2S, because the skeleton weight is already
//     *         in place.
//     */
//    void Execute( Worker* user_worker )
//    {
//      auto *node = arg;
//      /** MPI Support. */
//      auto comm = node->GetComm();
//      auto size = node->GetCommSize();
//      auto rank = node->GetCommRank();
//
//      if ( size < 2 )
//      {
//        gofmm::SkeletonsToSkeletons<NNPRUNE, NODE, T>( node );
//      }
//      else
//      {
//        /** Only the 0th rank (owner) will execute this task. */
//        if ( rank == 0 ) gofmm::SkeletonsToSkeletons<NNPRUNE, NODE, T>( node );
//      }
//    };
//
//}; /** end class DistSkeletonsToSkeletonsTask */
//

template<typename NODE, typename LETNODE, typename T>
class S2STask2 : public Task
{
  public:

    NODE *arg = NULL;

    vector<LETNODE*> Sources;

    int p = 0;

    Lock *lock = NULL;

    int *num_arrived_subtasks;

    void Set( NODE *user_arg, vector<LETNODE*> user_src, int user_p, Lock *user_lock,
        int *user_num_arrived_subtasks )
    {
      arg = user_arg;
      Sources = user_src;
      p = user_p;
      lock = user_lock;
      num_arrived_subtasks = user_num_arrived_subtasks;
      name = string( "S2S" );
      label = to_string( arg->treelist_id );

      /** Compute FLOPS and MOPS */
      double flops = 0.0, mops = 0.0;
      size_t nrhs = arg->setup->w->col();
      size_t m = arg->data.skels.size();
      for ( auto src : Sources )
      {
        size_t k = src->data.skels.size();
        flops += 2 * m * k * nrhs;
        mops  += 2 * ( m * k + ( m + k ) * nrhs );
        flops += 2 * m * nrhs;
        flops += m * k * ( 2 * 18 + 100 );
      }
      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** Assume computation bound */
      cost = flops / 1E+9;
      /** "HIGH" priority for the root. */
      if ( arg->treelist_id == 0 ) priority = true;
    };

    void DependencyAnalysis()
    {
      if ( p == hmlp_get_mpi_rank() )
      {
        for ( auto src : Sources ) src->DependencyAnalysis( R, this );
      }
      else hmlp_msg_dependency_analysis( 306, p, R, this );
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker )
    {
      auto *node = arg;
      if ( !node->parent || !node->data.isskel ) return;
      size_t nrhs = node->setup->w->col();
      auto &K = *node->setup->K;
      auto &I = node->data.skels;

      /** Temporary buffer */
      Data<T> u( I.size(), nrhs, 0.0 );

      for ( auto src : Sources )
      {
        auto &J = src->data.skels;
        auto &w = src->data.w_skel;
        bool is_cached = true;

        auto &KIJ = node->DistFar[ p ][ src->morton ];
        if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
        {
          //printf( "KIJ %lu %lu I %lu J %lu\n", KIJ.row(), KIJ.col(), I.size(), J.size() );
          KIJ = K( I, J );
          is_cached = false;
        }

        assert( w.col() == nrhs );
        assert( w.row() == J.size() );
        //xgemm
        //(
        //  "N", "N", u.row(), u.col(), w.row(),
        //  1.0, KIJ.data(), KIJ.row(),
        //         w.data(),   w.row(),
        //  1.0,   u.data(),   u.row()
        //);
        gemm::xgemm( (T)1.0, KIJ, w, (T)1.0, u );

        /** Free KIJ, if !is_cached. */
        if ( !is_cached )
        {
          KIJ.resize( 0, 0 );
          KIJ.shrink_to_fit();
        }
      }

      lock->Acquire();
      {
        auto &u_skel = node->data.u_skel;
        for ( size_t i = 0; i < u.size(); i ++ )
          u_skel[ i ] += u[ i ];
      }
      lock->Release();
      #pragma omp atomic update
      *num_arrived_subtasks += 1;
    };
}; /** end class S2STask2 */

template<typename NODE, typename LETNODE, typename T>
class S2SReduceTask2 : public Task
{
  public:

    NODE *arg = NULL;

    vector<S2STask2<NODE, LETNODE, T>*> subtasks;

    Lock lock;

    int num_arrived_subtasks = 0;

    const size_t batch_size = 2;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "S2SR" );
      label = to_string( arg->treelist_id );

      /** Reset u_skel */
      if ( arg )
      {
        size_t nrhs = arg->setup->w->col();
        auto &I = arg->data.skels;
        arg->data.u_skel.resize( 0, 0 );
        arg->data.u_skel.resize( I.size(), nrhs, 0 );
      }

      /** Create subtasks */
      for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
      {
        vector<LETNODE*> Sources;
        for ( auto &it : arg->DistFar[ p ] )
        {
          Sources.push_back( (*arg->morton2node)[ it.first ] );
          if ( Sources.size() == batch_size )
          {
            subtasks.push_back( new S2STask2<NODE, LETNODE, T>() );
            subtasks.back()->Submit();
            subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
            subtasks.back()->DependencyAnalysis();
            Sources.clear();
          }
        }
        if ( Sources.size() )
        {
          subtasks.push_back( new S2STask2<NODE, LETNODE, T>() );
          subtasks.back()->Submit();
          subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
          subtasks.back()->DependencyAnalysis();
          Sources.clear();
        }
      }
      /** Compute FLOPS and MOPS. */
      double flops = 0, mops = 0;
      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** "HIGH" priority */
      priority = true;
    };

    void DependencyAnalysis()
    {
      for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
      arg->DependencyAnalysis( RW, this );
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker )
    {
      /** All subtasks must have arrived before this reduce task runs. */
      assert( num_arrived_subtasks == subtasks.size() );
    };
}; /** end class S2SReduceTask2 */
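
/**
 *  Editorial note: S2SReduceTask2 splits the far interactions of a node into
 *  batches of batch_size S2STask2 subtasks per source rank; each subtask
 *  accumulates its partial product into u_skel under the shared lock, and
 *  the reduce task only asserts that every subtask has arrived.
 */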


template<bool NNPRUNE, typename NODE, typename T>
void DistSkeletonsToNodes( NODE *node )
{
  /** MPI Support. */
  auto comm = node->GetComm();
  auto size = node->GetCommSize();
  auto rank = node->GetCommRank();
  mpi::Status status;

  /** Gather shared data and create reference. */
  auto &K = *node->setup->K;
  auto &w = *node->setup->w;

  size_t nrhs = w.col();

  /** Early return if this is the root or has no skeleton. */
  if ( !node->parent || !node->data.isskel ) return;

  if ( size < 2 )
  {
    /** Call the shared-memory implementation. */
    gofmm::SkeletonsToNodes( node );
  }
  else
  {
    auto &data = node->data;
    auto &proj = data.proj;
    auto &u_skel = data.u_skel;

    if ( rank == 0 )
    {
      size_t sl = node->child->data.skels.size();
      size_t sr = proj.col() - sl;
      /** Send u_skel to my sibling. */
      mpi::SendVector( u_skel, size / 2, 0, comm );
      /** Create a transpose matrix view for proj. */
      View<T> P( true, proj ), PL, PR;
      View<T> U( false, u_skel ), UL( false, node->child->data.u_skel );
      /** P' = [ PL, PR ]' */
      P.Partition2x1( PL,
                      PR, sl, TOP );
      /** UL += PL' * U */
      gemm::xgemm<GEMM_NB>( (T)1.0, PL, U, (T)1.0, UL );
    }

    /** The rank that holds the skeleton potential of the right child. */
    if ( rank == size / 2 )
    {
      size_t s  = proj.row();
      size_t sr = node->child->data.skels.size();
      size_t sl = proj.col() - sr;
      /** Receive u_skel from my sibling. */
      mpi::RecvVector( u_skel, 0, 0, comm, &status );
      u_skel.resize( s, nrhs );
      /** Create a transpose matrix view for proj. */
      View<T> P( true, proj ), PL, PR;
      View<T> U( false, u_skel ), UR( false, node->child->data.u_skel );
      /** P' = [ PL, PR ]' */
      P.Partition2x1( PL,
                      PR, sl, TOP );
      /** UR += PR' * U */
      gemm::xgemm<GEMM_NB>( (T)1.0, PR, U, (T)1.0, UR );
    }
  }
}; /** end DistSkeletonsToNodes() */
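
/**
 *  Editorial summary of the distributed S2N step above: it applies the
 *  transpose of the N2S map.  With proj' = [ PL, PR ]' partitioned at the
 *  left child's skeleton size,
 *
 *    u_skel( left  child ) += PL' * u_skel,   (on rank 0)
 *    u_skel( right child ) += PR' * u_skel,   (on rank size/2)
 *
 *  after u_skel is forwarded from rank 0 to rank size/2.
 */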


template<bool NNPRUNE, typename NODE, typename T>
class DistSkeletonsToNodesTask : public Task
{
  public:

    NODE *arg;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PS2N" );
      label = to_string( arg->l );

      double flops = 0.0, mops = 0.0;
      auto &gids = arg->gids;
      auto &skels = arg->data.skels;
      auto &w = *arg->setup->w;

      if ( !arg->child )
      {
        if ( arg->isleaf )
        {
          auto m = skels.size();
          auto n = w.col();
          auto k = gids.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
        else
        {
          auto &lskels = arg->lchild->data.skels;
          auto &rskels = arg->rchild->data.skels;
          auto m = skels.size();
          auto n = w.col();
          auto k = lskels.size() + rskels.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
      }
      else
      {
        if ( arg->GetCommRank() == 0 )
        {
          auto &lskels = arg->child->data.skels;
          auto m = skels.size();
          auto n = w.col();
          auto k = lskels.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
        if ( arg->GetCommRank() == arg->GetCommSize() / 2 )
        {
          auto &rskels = arg->child->data.skels;
          auto m = skels.size();
          auto n = w.col();
          auto k = rskels.size();
          flops = 2.0 * m * n * k;
          mops  = 2.0 * ( m * n + m * k + k * n );
        }
      }

      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** Assume computation bound */
      cost = flops / 1E+9;
      /** "HIGH" priority (critical path) */
      priority = true;
    };

    void DependencyAnalysis() { arg->DependOnParent( this ); };

    void Execute( Worker* user_worker ) { DistSkeletonsToNodes<NNPRUNE, NODE, T>( arg ); };

}; /** end class DistSkeletonsToNodesTask */


template<typename NODE, typename T>
class L2LTask2 : public Task
{
  public:

    NODE *arg = NULL;

    /** A list of source node pointers. */
    vector<NODE*> Sources;

    int p = 0;

    /** Write lock */
    Lock *lock = NULL;

    int *num_arrived_subtasks;

    void Set( NODE *user_arg, vector<NODE*> user_src, int user_p, Lock *user_lock,
        int* user_num_arrived_subtasks )
    {
      arg = user_arg;
      Sources = user_src;
      p = user_p;
      lock = user_lock;
      num_arrived_subtasks = user_num_arrived_subtasks;
      name = string( "L2L" );
      label = to_string( arg->treelist_id );

      /** Compute FLOPS and MOPS. */
      double flops = 0.0, mops = 0.0;
      size_t nrhs = arg->setup->w->col();
      size_t m = arg->gids.size();
      for ( auto src : Sources )
      {
        size_t k = src->gids.size();
        flops += 2 * m * k * nrhs;
        mops  += 2 * ( m * k + ( m + k ) * nrhs );
        flops += 2 * m * nrhs;
        flops += m * k * ( 2 * 18 + 100 );
      }
      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** Assume computation bound */
      cost = flops / 1E+9;
      /** "LOW" priority */
      priority = false;
    };

    void DependencyAnalysis()
    {
      /** If p is a remote rank, this task depends on the incoming message. */
      if ( p != hmlp_get_mpi_rank() )
        hmlp_msg_dependency_analysis( 300, p, R, this );
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker )
    {
      auto *node = arg;
      size_t nrhs = node->setup->w->col();
      auto &K = *node->setup->K;
      auto &I = node->gids;

      double beg = omp_get_wtime();
      /** Temporary buffer */
      Data<T> u( I.size(), nrhs, 0.0 );
      size_t k = 0;

      for ( auto src : Sources )
      {
        /** Get the W view of this treenode (available for non-LET nodes). */
        View<T> &W = src->data.w_view;
        Data<T> &w = src->data.w_leaf;

        bool is_cached = true;
        auto &J = src->gids;
        auto &KIJ = node->DistNear[ p ][ src->morton ];
        if ( KIJ.row() != I.size() || KIJ.col() != J.size() )
        {
          KIJ = K( I, J );
          is_cached = false;
        }

        if ( W.col() == nrhs && W.row() == J.size() )
        {
          k += W.row();
          xgemm
          (
            "N", "N", u.row(), u.col(), W.row(),
            1.0, KIJ.data(), KIJ.row(),
                   W.data(),   W.ld(),
            1.0,   u.data(),   u.row()
          );
        }
        else
        {
          k += w.row();
          xgemm
          (
            "N", "N", u.row(), u.col(), w.row(),
            1.0, KIJ.data(), KIJ.row(),
                   w.data(),   w.row(),
            1.0,   u.data(),   u.row()
          );
        }

        /** Free KIJ, if !is_cached. */
        if ( !is_cached )
        {
          KIJ.resize( 0, 0 );
          KIJ.shrink_to_fit();
        }
      }

      double lock_beg = omp_get_wtime();
      lock->Acquire();
      {
        /** Get the U view of this treenode. */
        View<T> &U = node->data.u_view;
        for ( int j = 0; j < u.col(); j ++ )
          for ( int i = 0; i < u.row(); i ++ )
            U( i, j ) += u( i, j );
      }
      lock->Release();
      double lock_time = omp_get_wtime() - lock_beg;

      double gemm_time = omp_get_wtime() - beg;
      double GFLOPS = 2.0 * u.row() * u.col() * k / ( 1E+9 * gemm_time );
      //printf( "GEMM %4lu %4lu %4lu %lf GFLOPS, lock(%lf/%lf)\n",
      //    u.row(), u.col(), k, GFLOPS, lock_time, gemm_time ); fflush( stdout );
      #pragma omp atomic update
      *num_arrived_subtasks += 1;
    };
}; /** end class L2LTask2 */


template<typename NODE, typename T>
class L2LReduceTask2 : public Task
{
  public:

    NODE *arg = NULL;

    vector<L2LTask2<NODE, T>*> subtasks;

    Lock lock;

    int num_arrived_subtasks = 0;

    const size_t batch_size = 2;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "L2LR" );
      label = to_string( arg->treelist_id );
      /** Create subtasks */
      for ( int p = 0; p < hmlp_get_mpi_size(); p ++ )
      {
        vector<NODE*> Sources;
        for ( auto &it : arg->DistNear[ p ] )
        {
          Sources.push_back( (*arg->morton2node)[ it.first ] );
          if ( Sources.size() == batch_size )
          {
            subtasks.push_back( new L2LTask2<NODE, T>() );
            subtasks.back()->Submit();
            subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
            subtasks.back()->DependencyAnalysis();
            Sources.clear();
          }
        }
        if ( Sources.size() )
        {
          subtasks.push_back( new L2LTask2<NODE, T>() );
          subtasks.back()->Submit();
          subtasks.back()->Set( user_arg, Sources, p, &lock, &num_arrived_subtasks );
          subtasks.back()->DependencyAnalysis();
          Sources.clear();
        }
      }

      /** Compute FLOPS and MOPS */
      double flops = 0, mops = 0;
      /** Setup the event */
      event.Set( label + name, flops, mops );
      /** "LOW" priority */
      priority = false;
    };

    void DependencyAnalysis()
    {
      for ( auto task : subtasks ) Scheduler::DependencyAdd( task, this );
      arg->DependencyAnalysis( RW, this );
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker )
    {
      assert( num_arrived_subtasks == subtasks.size() );
    };
}; /** end class L2LReduceTask2 */


/**
 *  @brief (FMM specific) Compute Near( leaf nodes ). This is just like
 *         the neighbor list but the granularity is in nodes rather than
 *         points. The algorithm computes the node MortonIDs of the
 *         neighbor points, gets the pointers of these nodes, and inserts
 *         them into a std::set. std::set automatically removes duplicates.
 *         Here the insertion is performed twice to keep the list symmetric:
 *         that is, if alpha has beta in its list, then beta will also have
 *         alpha in its list.
 *
 *         Only leaf nodes have the list ``NearNodes''.
 *
 *         This list will later be used to get the FarNodes using a recursive
 *         node traversal scheme.
 */
template<typename TREE>
void FindNearInteractions( TREE &tree )
{
  mpi::PrintProgress( "[BEG] FindNearInteractions ...", tree.GetComm() );
  /** Derive type NODE from TREE. */
  using NODE = typename TREE::NODE;
  auto &setup = tree.setup;
  auto &NN = *setup.NN;
  double budget = setup.Budget();
  size_t n_leafs = ( 1 << tree.depth );
  /**
   *  The type here is tree::Node but not mpitree::Node.
   *  NearNodes and NNNearNodes also take tree::Node.
   *  This is ok, because they will only contain leaf nodes,
   *  which will never be distributed.
   *  However, FarNodes and NNFarNodes may contain distributed
   *  tree nodes. In this case, we have to do type casting.
   */
  auto level_beg = tree.treelist.begin() + n_leafs - 1;

  /** Traverse all leaf nodes. **/
  #pragma omp parallel for
  for ( size_t node_ind = 0; node_ind < n_leafs; node_ind ++ )
  {
    auto *node = *(level_beg + node_ind);
    auto &data = node->data;
    size_t n_nodes = ( 1 << node->l );

    /** Add myself to the near interaction list. */
    node->NNNearNodes.insert( node );
    node->NNNearNodeMortonIDs.insert( node->morton );

    /** Compute ballots for all near interactions. */
    multimap<size_t, size_t> sorted_ballot = gofmm::NearNodeBallots( node );

    /** Insert near node candidates until reaching the budget limit. */
    for ( auto it  = sorted_ballot.rbegin();
               it != sorted_ballot.rend(); it ++ )
    {
      /** Exit if we have enough near interactions. */
      if ( node->NNNearNodes.size() >= n_nodes * budget ) break;

      /**
       *  Get the node pointer from its MortonID. If the pointer does not
       *  exist yet, create a LET tree node for it.
       */
      #pragma omp critical
      {
        if ( !(*node->morton2node).count( (*it).second ) )
        {
          /** Create a LET node. */
          (*node->morton2node)[ (*it).second ] = new NODE( (*it).second );
        }
        /** Insert */
        auto *target = (*node->morton2node)[ (*it).second ];
        node->NNNearNodeMortonIDs.insert( (*it).second );
        node->NNNearNodes.insert( target );
      } /** end pragma omp critical */
    }
  } /** end for each owned leaf node in the local tree */
  mpi::PrintProgress( "[END] Finish FindNearInteractions ...", tree.GetComm() );
}; /** end FindNearInteractions() */
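
/**
 *  Editorial example (arbitrary numbers): for a local tree of depth l = 10
 *  there are n_nodes = 1 << 10 = 1024 leaf nodes.  With budget = 0.03 each
 *  leaf keeps itself plus the highest-ballot candidates until its near list
 *  reaches 1024 * 0.03 ~ 30 nodes, which bounds the direct-evaluation cost
 *  of the near interactions.
 */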


template<typename NODE>
void FindFarNodes( MortonHelper::Recursor r, NODE *target )
{
  /** Return when reaching the leaf level (recursion base case). */
  if ( r.second > target->l ) return;
  /** Compute the MortonID of the visiting node. */
  size_t node_morton = MortonHelper::MortonID( r );

  //bool prunable = true;
  auto & NearMortonIDs = target->NNNearNodeMortonIDs;

  /** Recur to children if the current node contains near interactions. */
  if ( MortonHelper::ContainAny( node_morton, NearMortonIDs ) )
  {
    FindFarNodes( MortonHelper::RecurLeft( r ), target );
    FindFarNodes( MortonHelper::RecurRight( r ), target );
  }
  else
  {
    if ( node_morton >= target->morton )
      target->NNFarNodeMortonIDs.insert( node_morton );
  }
}; /** end FindFarNodes() */
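
/**
 *  Editorial note: FindFarNodes() walks the Morton recursion from the root;
 *  a visited node whose MortonID contains none of target's near MortonIDs is
 *  prunable and is recorded as a far interaction, but only for
 *  node_morton >= target->morton.  The reverse direction is added later by
 *  SymmetrizeFarInteractions().
 */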
1628 |
|
|
1629 |
|
|
1630 |
|
|
1631 |
|
|
1632 |
|
|
1633 |
|
|
1634 |
|
template<typename TREE> |
1635 |
|
void SymmetrizeNearInteractions( TREE & tree ) |
1636 |
|
{ |
1637 |
|
mpi::PrintProgress( "[BEG] SymmetrizeNearInteractions ...", tree.GetComm() ); |
1638 |
|
|
1639 |
|
/** Derive type NODE from TREE. */ |
1640 |
|
using NODE = typename TREE::NODE; |
1641 |
|
/** MPI Support */ |
1642 |
|
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
1643 |
|
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
1644 |
|
|
1645 |
|
vector<vector<pair<size_t, size_t>>> sendlist( comm_size ); |
1646 |
|
vector<vector<pair<size_t, size_t>>> recvlist( comm_size ); |
1647 |
|
|
1648 |
|
|
1649 |
|
/** |
1650 |
|
* Traverse local leaf nodes: |
1651 |
|
* |
1652 |
|
* Loop over all near node MortonIDs, create |
1653 |
|
* |
1654 |
|
*/ |
1655 |
|
int n_nodes = 1 << tree.depth; |
1656 |
|
auto level_beg = tree.treelist.begin() + n_nodes - 1; |
1657 |
|
|
1658 |
|
#pragma omp parallel |
1659 |
|
{ |
1660 |
|
/** Create a per thread list. Merge them into sendlist afterward. */ |
1661 |
|
vector<vector<pair<size_t, size_t>>> list( comm_size ); |
1662 |
|
|
1663 |
|
#pragma omp for |
1664 |
|
for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ ) |
1665 |
|
{ |
1666 |
|
auto *node = *(level_beg + node_ind); |
1667 |
|
//auto & NearMortonIDs = node->NNNearNodeMortonIDs; |
1668 |
|
for ( auto it : node->NNNearNodeMortonIDs ) |
1669 |
|
{ |
1670 |
|
int dest = tree.Morton2Rank( it ); |
1671 |
|
if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest ); |
1672 |
|
list[ dest ].push_back( make_pair( it, node->morton ) ); |
1673 |
|
} |
1674 |
|
} /** end pramga omp for */ |
1675 |
|
|
1676 |
|
#pragma omp critical |
1677 |
|
{ |
1678 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1679 |
|
{ |
1680 |
|
sendlist[ p ].insert( sendlist[ p ].end(), |
1681 |
|
list[ p ].begin(), list[ p ].end() ); |
1682 |
|
} |
1683 |
|
} /** end pragma omp critical*/ |
1684 |
|
}; /** end pargma omp parallel */ |
1685 |
|
|
1686 |
|
|
1687 |
|
/** Alltoallv */ |
1688 |
|
mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() ); |
1689 |
|
|
1690 |
|
|
1691 |
|
/** Loop over queries. */ |
1692 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1693 |
|
{ |
1694 |
|
for ( auto & query : recvlist[ p ] ) |
1695 |
|
{ |
1696 |
|
/** Check if query node is allocated? */ |
1697 |
|
#pragma omp critical |
1698 |
|
{ |
1699 |
|
auto* node = tree.morton2node[ query.first ]; |
1700 |
|
if ( !tree.morton2node.count( query.second ) ) |
1701 |
|
{ |
1702 |
|
tree.morton2node[ query.second ] = new NODE( query.second ); |
1703 |
|
} |
1704 |
|
node->data.lock.Acquire(); |
1705 |
|
{ |
1706 |
|
node->NNNearNodes.insert( tree.morton2node[ query.second ] ); |
1707 |
|
node->NNNearNodeMortonIDs.insert( query.second ); |
1708 |
|
} |
1709 |
|
node->data.lock.Release(); |
1710 |
|
} |
1711 |
|
} /** end for each query */
1712 |
|
} |
1713 |
|
mpi::Barrier( tree.GetComm() ); |
1714 |
|
mpi::PrintProgress( "[END] SymmetrizeNearInteractions ...", tree.GetComm() ); |
1715 |
|
}; /** end SymmetrizeNearInteractions() */ |
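
/**
* Note (illustrative): after symmetrization the near lists are mutual. If
* leaf alpha on rank 0 recorded beta (owned by rank 1) in Near( alpha ),
* then the pair ( Morton(beta), Morton(alpha) ) is shipped to rank 1, which
* inserts alpha into Near( beta ), creating a local LET node for alpha if
* one does not exist yet.
*/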
template<typename TREE> |
1719 |
|
void SymmetrizeFarInteractions( TREE & tree ) |
1720 |
|
{ |
1721 |
|
mpi::PrintProgress( "[BEG] SymmetrizeFarInteractions ...", tree.GetComm() ); |
1722 |
|
|
1723 |
|
/** Derive type NODE from TREE. */ |
1724 |
|
using NODE = typename TREE::NODE; |
1725 |
|
///** MPI Support. */ |
1726 |
|
//int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
1727 |
|
//int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
1728 |
|
|
1729 |
|
vector<vector<pair<size_t, size_t>>> sendlist( tree.GetCommSize() ); |
1730 |
|
vector<vector<pair<size_t, size_t>>> recvlist( tree.GetCommSize() ); |
1731 |
|
|
1732 |
|
/** Local traversal */ |
1733 |
|
#pragma omp parallel |
1734 |
|
{ |
1735 |
|
/** Create a per thread list. Merge them into sendlist afterward. */ |
1736 |
|
vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() ); |
1737 |
|
|
1738 |
|
#pragma omp for |
1739 |
|
for ( size_t i = 1; i < tree.treelist.size(); i ++ ) |
1740 |
|
{ |
1741 |
|
auto *node = tree.treelist[ i ]; |
1742 |
|
for ( auto it = node->NNFarNodeMortonIDs.begin(); |
1743 |
|
it != node->NNFarNodeMortonIDs.end(); it ++ ) |
1744 |
|
{ |
1745 |
|
/** Allocate if not exist */ |
1746 |
|
#pragma omp critical |
1747 |
|
{ |
1748 |
|
if ( !tree.morton2node.count( *it ) ) |
1749 |
|
{ |
1750 |
|
tree.morton2node[ *it ] = new NODE( *it ); |
1751 |
|
} |
1752 |
|
node->NNFarNodes.insert( tree.morton2node[ *it ] ); |
1753 |
|
} |
1754 |
|
int dest = tree.Morton2Rank( *it ); |
1755 |
|
if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest ); |
1756 |
|
list[ dest ].push_back( make_pair( *it, node->morton ) ); |
1757 |
|
} |
1758 |
|
} |
1759 |
|
|
1760 |
|
#pragma omp critical |
1761 |
|
{ |
1762 |
|
for ( int p = 0; p < tree.GetCommSize(); p ++ ) |
1763 |
|
{ |
1764 |
|
sendlist[ p ].insert( sendlist[ p ].end(), |
1765 |
|
list[ p ].begin(), list[ p ].end() ); |
1766 |
|
} |
1767 |
|
} /** end pragma omp critical */
1768 |
|
} |
1769 |
|
|
1770 |
|
|
1771 |
|
/** Distributed traversal */ |
1772 |
|
#pragma omp parallel |
1773 |
|
{ |
1774 |
|
/** Create a per thread list. Merge them into sendlist afterward. */ |
1775 |
|
vector<vector<pair<size_t, size_t>>> list( tree.GetCommSize() ); |
1776 |
|
|
1777 |
|
#pragma omp for |
1778 |
|
for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ ) |
1779 |
|
{ |
1780 |
|
auto *node = tree.mpitreelists[ i ]; |
1781 |
|
for ( auto it = node->NNFarNodeMortonIDs.begin(); |
1782 |
|
it != node->NNFarNodeMortonIDs.end(); it ++ ) |
1783 |
|
{ |
1784 |
|
/** Allocate if not exist */ |
1785 |
|
#pragma omp critical |
1786 |
|
{ |
1787 |
|
if ( !tree.morton2node.count( *it ) ) |
1788 |
|
{ |
1789 |
|
tree.morton2node[ *it ] = new NODE( *it ); |
1790 |
|
} |
1791 |
|
node->NNFarNodes.insert( tree.morton2node[ *it ] ); |
1792 |
|
} |
1793 |
|
int dest = tree.Morton2Rank( *it ); |
1794 |
|
if ( dest >= tree.GetCommSize() ) printf( "%8lu dest %d\n", *it, dest ); fflush( stdout ); |
1795 |
|
list[ dest ].push_back( make_pair( *it, node->morton ) ); |
1796 |
|
} |
1797 |
|
} |
1798 |
|
|
1799 |
|
#pragma omp critical |
1800 |
|
{ |
1801 |
|
for ( int p = 0; p < tree.GetCommSize(); p ++ ) |
1802 |
|
{ |
1803 |
|
sendlist[ p ].insert( sendlist[ p ].end(), |
1804 |
|
list[ p ].begin(), list[ p ].end() ); |
1805 |
|
} |
1806 |
|
} /** end pragma omp critical */
1807 |
|
} |
1808 |
|
|
1809 |
|
/** Alltoallv */ |
1810 |
|
mpi::AlltoallVector( sendlist, recvlist, tree.GetComm() ); |
1811 |
|
|
1812 |
|
/** Loop over queries */ |
1813 |
|
for ( int p = 0; p < tree.GetCommSize(); p ++ ) |
1814 |
|
{ |
1815 |
|
//#pragma omp parallel for |
1816 |
|
for ( auto & query : recvlist[ p ] ) |
1817 |
|
{ |
1818 |
|
/** Check if query node is allocated? */ |
1819 |
|
#pragma omp critical |
1820 |
|
{ |
1821 |
|
if ( !tree.morton2node.count( query.second ) ) |
1822 |
|
{ |
1823 |
|
tree.morton2node[ query.second ] = new NODE( query.second ); |
1824 |
|
//printf( "rank %d, %8lu level %lu creates far LET %8lu (symmetrize)\n", |
1825 |
|
// comm_rank, node->morton, node->l, query.second ); |
1826 |
|
} |
1827 |
|
auto* node = tree.morton2node[ query.first ]; |
1828 |
|
node->data.lock.Acquire(); |
1829 |
|
{ |
1830 |
|
node->NNFarNodes.insert( tree.morton2node[ query.second ] ); |
1831 |
|
node->NNFarNodeMortonIDs.insert( query.second ); |
1832 |
|
} |
1833 |
|
node->data.lock.Release(); |
1834 |
|
assert( tree.Morton2Rank( node->morton ) == tree.GetCommRank() ); |
1835 |
|
} /** end pragma omp critical */ |
1836 |
|
} /** end for each query */
1837 |
|
} |
1838 |
|
|
1839 |
|
mpi::Barrier( tree.GetComm() ); |
1840 |
|
mpi::PrintProgress( "[END] SymmetrizeFarInteractions ...", tree.GetComm() ); |
1841 |
|
}; /** end SymmetrizeFarInteractions() */ |
/** |
1846 |
|
* TODO: need send and recv interaction lists for each rank |
1847 |
|
* |
1848 |
|
* SendNNNear[ rank ][ local morton ] |
1849 |
|
* RecvNNNear[ rank ][ remote morton ] |
1850 |
|
* |
1851 |
|
* for each leaf alpha and beta in Near(alpha) |
1852 |
|
* SendNNNear[ rank(beta) ] += Morton(alpha) |
1853 |
|
* |
1854 |
|
* Alltoallv( SendNNNear, rbuff ); |
1855 |
|
* |
1856 |
|
* for each rank |
1857 |
|
* RecvNNNear[ rank ][ remote morton ] = offset in rbuff |
1858 |
|
* |
1859 |
|
*/ |
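
/**
* Small example (illustrative, two ranks): if leaf alpha on rank 0 has beta
* in Near( alpha ) and beta is owned by rank 1, then rank 0 performs
*
*   SendNNNear[ 1 ] += Morton( alpha )
*
* and, after the Alltoallv, rank 1 records
*
*   RecvNNNear[ 0 ][ Morton( alpha ) ] = offset in rbuff,
*
* which is how NearSentToRank and NearRecvFromRank are filled below.
*/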
1860 |
|
template<typename TREE> |
1861 |
|
void BuildInteractionListPerRank( TREE &tree, bool is_near ) |
1862 |
|
{ |
1863 |
|
/** Derive type T from TREE. */ |
1864 |
|
using T = typename TREE::T; |
1865 |
|
/** MPI Support. */ |
1866 |
|
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
1867 |
|
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
1868 |
|
|
1869 |
|
/** Interaction set per rank in MortonID. */ |
1870 |
|
vector<set<size_t>> lists( comm_size ); |
1871 |
|
|
1872 |
|
if ( is_near ) |
1873 |
|
{ |
1874 |
|
/** Traverse leaf nodes (near interaction lists). */
1875 |
|
int n_nodes = 1 << tree.depth; |
1876 |
|
auto level_beg = tree.treelist.begin() + n_nodes - 1; |
1877 |
|
|
1878 |
|
#pragma omp parallel |
1879 |
|
{ |
1880 |
|
/** Create a per thread list. Merge them into sendlist afterward. */ |
1881 |
|
vector<set<size_t>> list( comm_size ); |
1882 |
|
|
1883 |
|
#pragma omp for |
1884 |
|
for ( int node_ind = 0; node_ind < n_nodes; node_ind ++ ) |
1885 |
|
{ |
1886 |
|
auto *node = *(level_beg + node_ind); |
1887 |
|
auto & NearMortonIDs = node->NNNearNodeMortonIDs; |
1888 |
|
node->DistNear.resize( comm_size ); |
1889 |
|
for ( auto it : NearMortonIDs ) |
1890 |
|
{ |
1891 |
|
int dest = tree.Morton2Rank( it ); |
1892 |
|
if ( dest >= comm_size ) printf( "%8lu dest %d\n", it, dest ); |
1893 |
|
if ( dest != comm_rank ) list[ dest ].insert( node->morton ); |
1894 |
|
node->DistNear[ dest ][ it ] = Data<T>(); |
1895 |
|
} |
1896 |
|
} /** end pragma omp for */
1897 |
|
|
1898 |
|
#pragma omp critical |
1899 |
|
{ |
1900 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1901 |
|
lists[ p ].insert( list[ p ].begin(), list[ p ].end() ); |
1902 |
|
} /** end pragma omp critical */
}; /** end pragma omp parallel */
1904 |
|
|
1905 |
|
|
1906 |
|
/** Cast set to vector. */ |
1907 |
|
vector<vector<size_t>> recvlist( comm_size ); |
1908 |
|
if ( !tree.NearSentToRank.size() ) tree.NearSentToRank.resize( comm_size ); |
1909 |
|
if ( !tree.NearRecvFromRank.size() ) tree.NearRecvFromRank.resize( comm_size ); |
1910 |
|
#pragma omp parallel for |
1911 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1912 |
|
{ |
1913 |
|
tree.NearSentToRank[ p ].insert( tree.NearSentToRank[ p ].end(), |
1914 |
|
lists[ p ].begin(), lists[ p ].end() ); |
1915 |
|
} |
1916 |
|
|
1917 |
|
/** Use buffer recvlist to catch Alltoallv results. */ |
1918 |
|
mpi::AlltoallVector( tree.NearSentToRank, recvlist, tree.GetComm() ); |
1919 |
|
|
1920 |
|
/** Cast vector of vectors to vector of maps */ |
1921 |
|
#pragma omp parallel for |
1922 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1923 |
|
for ( int i = 0; i < recvlist[ p ].size(); i ++ ) |
1924 |
|
tree.NearRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i; |
1925 |
|
} |
1926 |
|
else |
1927 |
|
{ |
1928 |
|
#pragma omp parallel |
1929 |
|
{ |
1930 |
|
/** Create a per thread list. Merge them into sendlist afterward. */ |
1931 |
|
vector<set<size_t>> list( comm_size ); |
1932 |
|
|
1933 |
|
/** Local traversal */ |
1934 |
|
#pragma omp for |
1935 |
|
for ( size_t i = 1; i < tree.treelist.size(); i ++ ) |
1936 |
|
{ |
1937 |
|
auto *node = tree.treelist[ i ]; |
1938 |
|
node->DistFar.resize( comm_size ); |
1939 |
|
for ( auto it = node->NNFarNodeMortonIDs.begin(); |
1940 |
|
it != node->NNFarNodeMortonIDs.end(); it ++ ) |
1941 |
|
{ |
1942 |
|
int dest = tree.Morton2Rank( *it ); |
1943 |
|
if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest ); |
1944 |
|
if ( dest != comm_rank ) |
1945 |
|
{ |
1946 |
|
list[ dest ].insert( node->morton ); |
1947 |
|
//node->data.FarDependents.insert( dest ); |
1948 |
|
} |
1949 |
|
node->DistFar[ dest ][ *it ] = Data<T>(); |
1950 |
|
} |
1951 |
|
} |
1952 |
|
|
1953 |
|
/** Distributed traversal */ |
1954 |
|
#pragma omp for |
1955 |
|
for ( size_t i = 0; i < tree.mpitreelists.size(); i ++ ) |
1956 |
|
{ |
1957 |
|
auto *node = tree.mpitreelists[ i ]; |
1958 |
|
node->DistFar.resize( comm_size ); |
1959 |
|
/** Add to the list iff this MPI rank owns the distributed node */ |
1960 |
|
if ( tree.Morton2Rank( node->morton ) == comm_rank ) |
1961 |
|
{ |
1962 |
|
for ( auto it = node->NNFarNodeMortonIDs.begin(); |
1963 |
|
it != node->NNFarNodeMortonIDs.end(); it ++ ) |
1964 |
|
{ |
1965 |
|
int dest = tree.Morton2Rank( *it ); |
1966 |
|
if ( dest >= comm_size ) printf( "%8lu dest %d\n", *it, dest ); |
1967 |
|
if ( dest != comm_rank ) |
1968 |
|
{ |
1969 |
|
list[ dest ].insert( node->morton ); |
1970 |
|
//node->data.FarDependents.insert( dest ); |
1971 |
|
} |
1972 |
|
node->DistFar[ dest ][ *it ] = Data<T>(); |
1973 |
|
} |
1974 |
|
} |
1975 |
|
} |
1976 |
|
/** Merge lists from all threads */ |
1977 |
|
#pragma omp critical |
1978 |
|
{ |
1979 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1980 |
|
lists[ p ].insert( list[ p ].begin(), list[ p ].end() ); |
1981 |
|
} /** end pragma omp critical */

}; /** end pragma omp parallel */
1984 |
|
|
1985 |
|
/** Cast set to vector */ |
1986 |
|
vector<vector<size_t>> recvlist( comm_size ); |
1987 |
|
if ( !tree.FarSentToRank.size() ) tree.FarSentToRank.resize( comm_size ); |
1988 |
|
if ( !tree.FarRecvFromRank.size() ) tree.FarRecvFromRank.resize( comm_size ); |
1989 |
|
#pragma omp parallel for |
1990 |
|
for ( int p = 0; p < comm_size; p ++ ) |
1991 |
|
{ |
1992 |
|
tree.FarSentToRank[ p ].insert( tree.FarSentToRank[ p ].end(), |
1993 |
|
lists[ p ].begin(), lists[ p ].end() ); |
1994 |
|
} |
1995 |
|
|
1996 |
|
|
1997 |
|
/** Use buffer recvlist to catch Alltoallv results. */ |
1998 |
|
mpi::AlltoallVector( tree.FarSentToRank, recvlist, tree.GetComm() ); |
1999 |
|
|
2000 |
|
/** Cast vector of vectors to vector of maps */ |
2001 |
|
#pragma omp parallel for |
2002 |
|
for ( int p = 0; p < comm_size; p ++ ) |
2003 |
|
for ( int i = 0; i < recvlist[ p ].size(); i ++ ) |
2004 |
|
tree.FarRecvFromRank[ p ][ recvlist[ p ][ i ] ] = i; |
2005 |
|
} |
2006 |
|
|
2007 |
|
mpi::Barrier( tree.GetComm() ); |
2008 |
|
}; /** end BuildInteractionListPerRank() */ |
template<typename TREE> |
2012 |
|
pair<double, double> NonCompressedRatio( TREE &tree ) |
2013 |
|
{ |
2014 |
|
/** Tree MPI communicator */ |
2015 |
|
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
2016 |
|
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
2017 |
|
|
2018 |
|
/** Use double for accumulation. */ |
2019 |
|
double ratio_n = 0.0; |
2020 |
|
double ratio_f = 0.0; |
2021 |
|
|
2022 |
|
|
2023 |
|
/** Traverse all nodes in the local tree. */ |
2024 |
|
for ( auto &tar : tree.treelist ) |
2025 |
|
{ |
2026 |
|
if ( tar->isleaf ) |
2027 |
|
{ |
2028 |
|
for ( auto nearID : tar->NNNearNodeMortonIDs ) |
2029 |
|
{ |
2030 |
|
auto *src = tree.morton2node[ nearID ]; |
2031 |
|
assert( src ); |
2032 |
|
double m = tar->gids.size(); |
2033 |
|
double n = src->gids.size(); |
2034 |
|
double N = tree.n; |
2035 |
|
ratio_n += ( m / N ) * ( n / N ); |
2036 |
|
} |
2037 |
|
} |
2038 |
|
|
2039 |
|
for ( auto farID : tar->NNFarNodeMortonIDs ) |
2040 |
|
{ |
2041 |
|
auto *src = tree.morton2node[ farID ]; |
2042 |
|
assert( src ); |
2043 |
|
double m = tar->data.skels.size(); |
2044 |
|
double n = src->data.skels.size(); |
2045 |
|
double N = tree.n; |
2046 |
|
ratio_f += ( m / N ) * ( n / N ); |
2047 |
|
} |
2048 |
|
} |
2049 |
|
|
2050 |
|
/** Traverse all nodes in the distributed tree. */ |
2051 |
|
for ( auto &tar : tree.mpitreelists ) |
2052 |
|
{ |
2053 |
|
if ( !tar->child || tar->GetCommRank() ) continue; |
2054 |
|
for ( auto farID : tar->NNFarNodeMortonIDs ) |
2055 |
|
{ |
2056 |
|
auto *src = tree.morton2node[ farID ]; |
2057 |
|
assert( src ); |
2058 |
|
double m = tar->data.skels.size(); |
2059 |
|
double n = src->data.skels.size(); |
2060 |
|
double N = tree.n; |
2061 |
|
ratio_f += ( m / N ) * ( n / N ); |
2062 |
|
} |
2063 |
|
} |
2064 |
|
|
2065 |
|
/** Allreduce total evaluations from all MPI processes. */ |
2066 |
|
pair<double, double> ret( 0, 0 ); |
2067 |
|
mpi::Allreduce( &ratio_n, &(ret.first), 1, MPI_SUM, tree.GetComm() ); |
2068 |
|
mpi::Allreduce( &ratio_f, &(ret.second), 1, MPI_SUM, tree.GetComm() ); |
2069 |
|
|
2070 |
|
return ret; |
2071 |
|
}; |
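
/**
* Interpretation (our reading, not part of the original comments): the first
* entry of the returned pair accumulates sum( |alpha| * |beta| ) / N^2 over
* all near (direct) pairs, i.e. the fraction of K that is evaluated exactly,
* while the second entry does the same with skeleton sizes over all far
* pairs, i.e. the relative cost of the compressed interactions.
*/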
template<typename T, typename TREE> |
2076 |
|
void PackNear( TREE &tree, string option, int p, |
2077 |
|
vector<size_t> &sendsizes, |
2078 |
|
vector<size_t> &sendskels, |
2079 |
|
vector<T> &sendbuffs ) |
2080 |
|
{ |
2081 |
|
vector<size_t> offsets( 1, 0 ); |
2082 |
|
|
2083 |
|
for ( auto it : tree.NearSentToRank[ p ] ) |
2084 |
|
{ |
2085 |
|
auto *node = tree.morton2node[ it ]; |
2086 |
|
auto &gids = node->gids; |
2087 |
|
if ( !option.compare( string( "leafgids" ) ) ) |
2088 |
|
{ |
2089 |
|
sendsizes.push_back( gids.size() ); |
2090 |
|
sendskels.insert( sendskels.end(), gids.begin(), gids.end() ); |
2091 |
|
} |
2092 |
|
else |
2093 |
|
{ |
2094 |
|
auto &w_view = node->data.w_view; |
2095 |
|
sendsizes.push_back( gids.size() * w_view.col() ); |
2096 |
|
offsets.push_back( sendsizes.back() + offsets.back() ); |
2097 |
|
} |
2098 |
|
} |
2099 |
|
|
2100 |
|
if ( offsets.size() ) sendbuffs.resize( offsets.back() ); |
2101 |
|
|
2102 |
|
if ( !option.compare( string( "leafweights" ) ) ) |
2103 |
|
{ |
2104 |
|
#pragma omp parallel for |
2105 |
|
for ( size_t i = 0; i < tree.NearSentToRank[ p ].size(); i ++ ) |
2106 |
|
{ |
2107 |
|
auto *node = tree.morton2node[ tree.NearSentToRank[ p ][ i ] ]; |
2108 |
|
auto &gids = node->gids; |
2109 |
|
auto &w_view = node->data.w_view; |
2110 |
|
auto w_leaf = w_view.toData(); |
2111 |
|
size_t offset = offsets[ i ]; |
2112 |
|
for ( size_t j = 0; j < w_leaf.size(); j ++ ) |
2113 |
|
sendbuffs[ offset + j ] = w_leaf[ j ]; |
2114 |
|
} |
2115 |
|
} |
2116 |
|
}; |
template<typename T, typename TREE> |
2120 |
|
void UnpackLeaf( TREE &tree, string option, int p, |
2121 |
|
const vector<size_t> &recvsizes, |
2122 |
|
const vector<size_t> &recvskels, |
2123 |
|
const vector<T> &recvbuffs ) |
2124 |
|
{ |
2125 |
|
vector<size_t> offsets( 1, 0 ); |
2126 |
|
for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it ); |
2127 |
|
|
2128 |
|
for ( auto it : tree.NearRecvFromRank[ p ] ) |
2129 |
|
{ |
2130 |
|
auto *node = tree.morton2node[ it.first ]; |
2131 |
|
if ( !option.compare( string( "leafgids" ) ) ) |
2132 |
|
{ |
2133 |
|
auto &gids = node->gids; |
2134 |
|
size_t i = it.second; |
2135 |
|
gids.reserve( recvsizes[ i ] ); |
2136 |
|
for ( uint64_t j = offsets[ i + 0 ]; |
2137 |
|
j < offsets[ i + 1 ]; |
2138 |
|
j ++ ) |
2139 |
|
{ |
2140 |
|
gids.push_back( recvskels[ j ] ); |
2141 |
|
} |
2142 |
|
} |
2143 |
|
else |
2144 |
|
{ |
2145 |
|
/** Number of right hand sides */ |
2146 |
|
size_t nrhs = tree.setup.w->col(); |
2147 |
|
auto &w_leaf = node->data.w_leaf; |
2148 |
|
size_t i = it.second; |
2149 |
|
w_leaf.resize( recvsizes[ i ] / nrhs, nrhs ); |
2150 |
|
//printf( "%d recv w_leaf from %d [%lu %lu]\n", |
2151 |
|
// comm_rank, p, w_leaf.row(), w_leaf.col() ); fflush( stdout ); |
2152 |
|
for ( uint64_t j = offsets[ i + 0 ], jj = 0; |
2153 |
|
j < offsets[ i + 1 ]; |
2154 |
|
j ++, jj ++ ) |
2155 |
|
{ |
2156 |
|
w_leaf[ jj ] = recvbuffs[ j ]; |
2157 |
|
} |
2158 |
|
} |
2159 |
|
} |
2160 |
|
}; |
template<typename T, typename TREE> |
2164 |
|
void PackFar( TREE &tree, string option, int p, |
2165 |
|
vector<size_t> &sendsizes, |
2166 |
|
vector<size_t> &sendskels, |
2167 |
|
vector<T> &sendbuffs ) |
2168 |
|
{ |
2169 |
|
for ( auto it : tree.FarSentToRank[ p ] ) |
2170 |
|
{ |
2171 |
|
auto *node = tree.morton2node[ it ]; |
2172 |
|
auto &skels = node->data.skels; |
2173 |
|
if ( !option.compare( string( "skelgids" ) ) ) |
2174 |
|
{ |
2175 |
|
sendsizes.push_back( skels.size() ); |
2176 |
|
sendskels.insert( sendskels.end(), skels.begin(), skels.end() ); |
2177 |
|
} |
2178 |
|
else |
2179 |
|
{ |
2180 |
|
auto &w_skel = node->data.w_skel; |
2181 |
|
sendsizes.push_back( w_skel.size() ); |
2182 |
|
sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() ); |
2183 |
|
} |
2184 |
|
} |
2185 |
|
}; /** end PackFar() */ |
/** @brief Pack a list of weights and their sizes to two messages. */ |
2199 |
|
template<typename TREE, typename T> |
2200 |
|
void PackWeights( TREE &tree, int p, |
2201 |
|
vector<T> &sendbuffs, vector<size_t> &sendsizes ) |
2202 |
|
{ |
2203 |
|
for ( auto it : tree.NearSentToRank[ p ] ) |
2204 |
|
{ |
2205 |
|
auto *node = tree.morton2node[ it ]; |
2206 |
|
auto w_leaf = node->data.w_view.toData(); |
2207 |
|
sendbuffs.insert( sendbuffs.end(), w_leaf.begin(), w_leaf.end() ); |
2208 |
|
sendsizes.push_back( w_leaf.size() ); |
2209 |
|
} |
2210 |
|
}; /** end PackWeights() */ |
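
/**
* Design note (our sketch): each Pack/Unpack pair ships a flat value buffer
* plus a parallel array of per-node sizes; the receiver rebuilds per-node
* offsets with a prefix sum over the sizes, so no nested vectors and no
* per-node messages are needed on the wire.
*/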
/** @brief Unpack a list of weights and their sizes. */ |
2214 |
|
template<typename TREE, typename T> |
2215 |
|
void UnpackWeights( TREE &tree, int p, |
2216 |
|
const vector<T> recvbuffs, const vector<size_t> &recvsizes ) |
2217 |
|
{ |
2218 |
|
vector<size_t> offsets( 1, 0 ); |
2219 |
|
for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it ); |
2220 |
|
|
2221 |
|
for ( auto it : tree.NearRecvFromRank[ p ] ) |
2222 |
|
{ |
2223 |
|
/** Get LET node pointer. */ |
2224 |
|
auto *node = tree.morton2node[ it.first ]; |
2225 |
|
/** Number of right hand sides */ |
2226 |
|
size_t nrhs = tree.setup.w->col(); |
2227 |
|
auto &w_leaf = node->data.w_leaf; |
2228 |
|
size_t i = it.second; |
2229 |
|
w_leaf.resize( recvsizes[ i ] / nrhs, nrhs ); |
2230 |
|
for ( uint64_t j = offsets[ i + 0 ], jj = 0; |
2231 |
|
j < offsets[ i + 1 ]; |
2232 |
|
j ++, jj ++ ) |
2233 |
|
{ |
2234 |
|
w_leaf[ jj ] = recvbuffs[ j ]; |
2235 |
|
} |
2236 |
|
} |
2237 |
|
}; /** end UnpackWeights() */ |
/** @brief Pack a list of skeletons and their sizes to two messages. */ |
2242 |
|
template<typename TREE> |
2243 |
|
void PackSkeletons( TREE &tree, int p, |
2244 |
|
vector<size_t> &sendbuffs, vector<size_t> &sendsizes ) |
2245 |
|
{ |
2246 |
|
for ( auto it : tree.FarSentToRank[ p ] ) |
2247 |
|
{ |
2248 |
|
/** Get LET node pointer. */ |
2249 |
|
auto *node = tree.morton2node[ it ]; |
2250 |
|
auto &skels = node->data.skels; |
2251 |
|
sendbuffs.insert( sendbuffs.end(), skels.begin(), skels.end() ); |
2252 |
|
sendsizes.push_back( skels.size() ); |
2253 |
|
} |
2254 |
|
}; /** end PackSkeletons() */ |
/** @brief Unpack a list of skeletons and their sizes. */ |
2258 |
|
template<typename TREE> |
2259 |
|
void UnpackSkeletons( TREE &tree, int p, |
2260 |
|
const vector<size_t> recvbuffs, const vector<size_t> &recvsizes ) |
2261 |
|
{ |
2262 |
|
vector<size_t> offsets( 1, 0 ); |
2263 |
|
for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it ); |
2264 |
|
|
2265 |
|
for ( auto it : tree.FarRecvFromRank[ p ] ) |
2266 |
|
{ |
2267 |
|
/** Get LET node pointer. */ |
2268 |
|
auto *node = tree.morton2node[ it.first ]; |
2269 |
|
auto &skels = node->data.skels; |
2270 |
|
size_t i = it.second; |
2271 |
|
skels.clear(); |
2272 |
|
skels.reserve( recvsizes[ i ] ); |
2273 |
|
for ( uint64_t j = offsets[ i + 0 ]; |
2274 |
|
j < offsets[ i + 1 ]; |
2275 |
|
j ++ ) |
2276 |
|
{ |
2277 |
|
skels.push_back( recvbuffs[ j ] ); |
2278 |
|
} |
2279 |
|
} |
2280 |
|
}; /** end UnpackSkeletons() */ |
/** @brief Pack a list of skeleton weights and their sizes to two messages. */ |
2285 |
|
template<typename TREE, typename T> |
2286 |
|
void PackSkeletonWeights( TREE &tree, int p, |
2287 |
|
vector<T> &sendbuffs, vector<size_t> &sendsizes ) |
2288 |
|
{ |
2289 |
|
for ( auto it : tree.FarSentToRank[ p ] ) |
2290 |
|
{ |
2291 |
|
auto *node = tree.morton2node[ it ]; |
2292 |
|
auto &w_skel = node->data.w_skel; |
2293 |
|
sendbuffs.insert( sendbuffs.end(), w_skel.begin(), w_skel.end() ); |
2294 |
|
sendsizes.push_back( w_skel.size() ); |
2295 |
|
} |
2296 |
|
}; /** end PackSkeletonWeights() */ |
/** @brief Unpack a list of skeleton weights and their sizes. */
2300 |
|
template<typename TREE, typename T> |
2301 |
|
void UnpackSkeletonWeights( TREE &tree, int p, |
2302 |
|
const vector<T> recvbuffs, const vector<size_t> &recvsizes ) |
2303 |
|
{ |
2304 |
|
vector<size_t> offsets( 1, 0 ); |
2305 |
|
for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it ); |
2306 |
|
|
2307 |
|
for ( auto it : tree.FarRecvFromRank[ p ] ) |
2308 |
|
{ |
2309 |
|
/** Get LET node pointer. */ |
2310 |
|
auto *node = tree.morton2node[ it.first ]; |
2311 |
|
/** Number of right hand sides */ |
2312 |
|
size_t nrhs = tree.setup.w->col(); |
2313 |
|
auto &w_skel = node->data.w_skel; |
2314 |
|
size_t i = it.second; |
2315 |
|
w_skel.resize( recvsizes[ i ] / nrhs, nrhs ); |
2316 |
|
for ( uint64_t j = offsets[ i + 0 ], jj = 0; |
2317 |
|
j < offsets[ i + 1 ]; |
2318 |
|
j ++, jj ++ ) |
2319 |
|
{ |
2320 |
|
w_skel[ jj ] = recvbuffs[ j ]; |
2321 |
|
} |
2322 |
|
} |
2323 |
|
}; /** end UnpackSkeletonWeights() */ |
template<typename T, typename TREE> |
2331 |
|
void UnpackFar( TREE &tree, string option, int p, |
2332 |
|
const vector<size_t> &recvsizes, |
2333 |
|
const vector<size_t> &recvskels, |
2334 |
|
const vector<T> &recvbuffs ) |
2335 |
|
{ |
2336 |
|
vector<size_t> offsets( 1, 0 ); |
2337 |
|
for ( auto it : recvsizes ) offsets.push_back( offsets.back() + it ); |
2338 |
|
|
2339 |
|
for ( auto it : tree.FarRecvFromRank[ p ] ) |
2340 |
|
{ |
2341 |
|
/** Get LET node pointer */ |
2342 |
|
auto *node = tree.morton2node[ it.first ]; |
2343 |
|
if ( !option.compare( string( "skelgids" ) ) ) |
2344 |
|
{ |
2345 |
|
auto &skels = node->data.skels; |
2346 |
|
size_t i = it.second; |
2347 |
|
skels.clear(); |
2348 |
|
skels.reserve( recvsizes[ i ] ); |
2349 |
|
for ( uint64_t j = offsets[ i + 0 ]; |
2350 |
|
j < offsets[ i + 1 ]; |
2351 |
|
j ++ ) |
2352 |
|
{ |
2353 |
|
skels.push_back( recvskels[ j ] ); |
2354 |
|
} |
2355 |
|
} |
2356 |
|
else |
2357 |
|
{ |
2358 |
|
/** Number of right hand sides */ |
2359 |
|
size_t nrhs = tree.setup.w->col(); |
2360 |
|
auto &w_skel = node->data.w_skel; |
2361 |
|
size_t i = it.second; |
2362 |
|
w_skel.resize( recvsizes[ i ] / nrhs, nrhs ); |
2363 |
|
//printf( "%d recv w_skel (%8lu) from %d [%lu %lu], i %lu, offset[%lu %lu] \n", |
2364 |
|
// comm_rank, (*it).first, p, w_skel.row(), w_skel.col(), i, |
2365 |
|
// offsets[ p ][ i + 0 ], offsets[ p ][ i + 1 ] ); fflush( stdout ); |
2366 |
|
for ( uint64_t j = offsets[ i + 0 ], jj = 0; |
2367 |
|
j < offsets[ i + 1 ]; |
2368 |
|
j ++, jj ++ ) |
2369 |
|
{ |
2370 |
|
w_skel[ jj ] = recvbuffs[ j ]; |
2371 |
|
//if ( jj < 5 ) printf( "%E ", w_skel[ jj ] ); fflush( stdout ); |
2372 |
|
} |
2373 |
|
//printf( "\n" ); fflush( stdout ); |
2374 |
|
} |
2375 |
|
} |
2376 |
|
}; |
template<typename T, typename TREE> |
2380 |
|
class PackNearTask : public SendTask<T, TREE> |
2381 |
|
{ |
2382 |
|
public: |
2383 |
|
|
2384 |
|
PackNearTask( TREE *tree, int src, int tar, int key ) |
2385 |
|
: SendTask<T, TREE>( tree, src, tar, key ) |
2386 |
|
{ |
2387 |
|
/** Submit and perform dependency analysis automatically. */
2388 |
|
this->Submit(); |
2389 |
|
this->DependencyAnalysis(); |
2390 |
|
}; |
2391 |
|
|
2392 |
|
void DependencyAnalysis() |
2393 |
|
{ |
2394 |
|
TREE &tree = *(this->arg); |
2395 |
|
tree.DependOnNearInteractions( this->tar, this ); |
2396 |
|
}; |
2397 |
|
|
2398 |
|
/** Instantiate Pack() for SendTask. */
2399 |
|
void Pack() |
2400 |
|
{ |
2401 |
|
PackWeights( *this->arg, this->tar, |
2402 |
|
this->send_buffs, this->send_sizes ); |
2403 |
|
}; |
2404 |
|
|
2405 |
|
}; /** end class PackNearTask */ |
/**
* AlltoallvTask performs MPI_Alltoallv asynchronously. Overall there
* will be (p - 1) tasks per MPI rank. Each task performs Isend once
* the dependencies toward the destination are fulfilled.
*
* To receive the results, each MPI rank also actively runs a
* ListenerTask. The listener keeps polling for incoming messages
* that match. Once the received results are secured, it releases
* the dependent tasks.
*/
2421 |
|
template<typename T, typename TREE> |
2422 |
|
class UnpackLeafTask : public RecvTask<T, TREE> |
2423 |
|
{ |
2424 |
|
public: |
2425 |
|
|
2426 |
|
UnpackLeafTask( TREE *tree, int src, int tar, int key ) |
2427 |
|
: RecvTask<T, TREE>( tree, src, tar, key ) |
2428 |
|
{ |
2429 |
|
/** Submit and perform dependency analysis automatically. */
2430 |
|
this->Submit(); |
2431 |
|
this->DependencyAnalysis(); |
2432 |
|
}; |
2433 |
|
|
2434 |
|
void Unpack() |
2435 |
|
{ |
2436 |
|
UnpackWeights( *this->arg, this->src, |
2437 |
|
this->recv_buffs, this->recv_sizes ); |
2438 |
|
}; |
2439 |
|
|
2440 |
|
}; /** end class UnpackLeafTask */ |
/** @brief Pack and send skeleton weights to a remote rank. */
2444 |
|
template<typename T, typename TREE> |
2445 |
|
class PackFarTask : public SendTask<T, TREE> |
2446 |
|
{ |
2447 |
|
public: |
2448 |
|
|
2449 |
|
PackFarTask( TREE *tree, int src, int tar, int key ) |
2450 |
|
: SendTask<T, TREE>( tree, src, tar, key ) |
2451 |
|
{ |
2452 |
|
/** Submit and perform dependency analysis automatically. */
2453 |
|
this->Submit(); |
2454 |
|
this->DependencyAnalysis(); |
2455 |
|
}; |
2456 |
|
|
2457 |
|
void DependencyAnalysis() |
2458 |
|
{ |
2459 |
|
TREE &tree = *(this->arg); |
2460 |
|
tree.DependOnFarInteractions( this->tar, this ); |
2461 |
|
}; |
2462 |
|
|
2463 |
|
/** Instantiate Pack() for SendTask. */
2464 |
|
void Pack() |
2465 |
|
{ |
2466 |
|
PackSkeletonWeights( *this->arg, this->tar, |
2467 |
|
this->send_buffs, this->send_sizes ); |
2468 |
|
}; |
2469 |
|
|
2470 |
|
}; /** end class PackFarTask */ |
/** @brief Receive and unpack skeleton weights from a remote rank. */
2474 |
|
template<typename T, typename TREE> |
2475 |
|
class UnpackFarTask : public RecvTask<T, TREE> |
2476 |
|
{ |
2477 |
|
public: |
2478 |
|
|
2479 |
|
UnpackFarTask( TREE *tree, int src, int tar, int key ) |
2480 |
|
: RecvTask<T, TREE>( tree, src, tar, key ) |
2481 |
|
{ |
2482 |
|
/** Submit and perform dependency analysis automatically. */
2483 |
|
this->Submit(); |
2484 |
|
this->DependencyAnalysis(); |
2485 |
|
}; |
2486 |
|
|
2487 |
|
void Unpack() |
2488 |
|
{ |
2489 |
|
UnpackSkeletonWeights( *this->arg, this->src, |
2490 |
|
this->recv_buffs, this->recv_sizes ); |
2491 |
|
}; |
2492 |
|
|
2493 |
|
}; /** end class UnpackFarTask */ |
/** |
2506 |
|
* Send my skeletons (in gids and params) to other ranks |
2507 |
|
* using FarSentToRank[:]. |
2508 |
|
* |
2509 |
|
* Recv skeletons from other ranks |
2510 |
|
* using FarRecvFromRank[:]. |
2511 |
|
*/ |
2512 |
|
template<typename TREE> |
2513 |
|
void ExchangeLET( TREE &tree, string option ) |
2514 |
|
{ |
2515 |
|
/** Derive type T from TREE. */ |
2516 |
|
using T = typename TREE::T; |
2517 |
|
/** MPI Support. */ |
2518 |
|
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
2519 |
|
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
2520 |
|
|
2521 |
|
/** Buffers for sizes and skeletons */ |
2522 |
|
vector<vector<size_t>> sendsizes( comm_size ); |
2523 |
|
vector<vector<size_t>> recvsizes( comm_size ); |
2524 |
|
vector<vector<size_t>> sendskels( comm_size ); |
2525 |
|
vector<vector<size_t>> recvskels( comm_size ); |
2526 |
|
vector<vector<T>> sendbuffs( comm_size ); |
2527 |
|
vector<vector<T>> recvbuffs( comm_size ); |
2528 |
|
|
2529 |
|
/** Pack */ |
2530 |
|
#pragma omp parallel for |
2531 |
|
for ( int p = 0; p < comm_size; p ++ ) |
2532 |
|
{ |
2533 |
|
if ( !option.compare( 0, 4, "leaf" ) ) |
2534 |
|
{ |
2535 |
|
PackNear( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] ); |
2536 |
|
} |
2537 |
|
else if ( !option.compare( 0, 4, "skel" ) ) |
2538 |
|
{ |
2539 |
|
PackFar( tree, option, p, sendsizes[ p ], sendskels[ p ], sendbuffs[ p ] ); |
2540 |
|
} |
2541 |
|
else |
2542 |
|
{ |
2543 |
|
printf( "ExchangeLET: option <%s> not available.\n", option.data() ); |
2544 |
|
exit( 1 ); |
2545 |
|
} |
2546 |
|
} |
2547 |
|
|
2548 |
|
/** Alltoallv */ |
2549 |
|
mpi::AlltoallVector( sendsizes, recvsizes, tree.GetComm() ); |
2550 |
|
if ( !option.compare( string( "skelgids" ) ) || |
2551 |
|
!option.compare( string( "leafgids" ) ) ) |
2552 |
|
{ |
2553 |
|
auto &K = *tree.setup.K; |
2554 |
|
mpi::AlltoallVector( sendskels, recvskels, tree.GetComm() ); |
2555 |
|
K.RequestIndices( recvskels ); |
2556 |
|
} |
2557 |
|
else |
2558 |
|
{ |
2559 |
|
double beg = omp_get_wtime(); |
2560 |
|
mpi::AlltoallVector( sendbuffs, recvbuffs, tree.GetComm() ); |
2561 |
|
double a2av_time = omp_get_wtime() - beg; |
2562 |
|
if ( comm_rank == 0 ) printf( "a2av_time %lfs\n", a2av_time ); |
2563 |
|
} |
2564 |
|
|
2565 |
|
|
2566 |
|
/** Unpack */
2567 |
|
#pragma omp parallel for |
2568 |
|
for ( int p = 0; p < comm_size; p ++ ) |
2569 |
|
{ |
2570 |
|
if ( !option.compare( 0, 4, "leaf" ) ) |
2571 |
|
{ |
2572 |
|
UnpackLeaf( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] ); |
2573 |
|
} |
2574 |
|
else if ( !option.compare( 0, 4, "skel" ) ) |
2575 |
|
{ |
2576 |
|
UnpackFar( tree, option, p, recvsizes[ p ], recvskels[ p ], recvbuffs[ p ] ); |
2577 |
|
} |
2578 |
|
else |
2579 |
|
{ |
2580 |
|
printf( "ExchangeLET: option <%s> not available.\n", option.data() ); |
2581 |
|
exit( 1 ); |
2582 |
|
} |
2583 |
|
} |
2584 |
|
|
2585 |
|
|
2586 |
|
}; /** end ExchangeLET() */ |
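
/**
* Usage sketch (illustrative; the option strings follow PackNear() and
* PackFar() above): LET exchange is typically a two-step process, first the
* index sets and then the numerical values, e.g.
*
*   ExchangeLET( tree, string( "leafgids" ) );     // near-leaf gids
*   ExchangeLET( tree, string( "leafweights" ) );  // near-leaf weights
*   ExchangeLET( tree, string( "skelgids" ) );     // far skeleton gids
*   ExchangeLET( tree, string( "skelweights" ) );  // far skeleton weights
*/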
template<typename T, typename TREE> |
2591 |
|
void AsyncExchangeLET( TREE &tree, string option ) |
2592 |
|
{ |
2593 |
|
/** MPI */ |
2594 |
|
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
2595 |
|
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
2596 |
|
|
2597 |
|
/** Create sending tasks. */ |
2598 |
|
for ( int p = 0; p < comm_size; p ++ ) |
2599 |
|
{ |
2600 |
|
if ( !option.compare( 0, 4, "leaf" ) ) |
2601 |
|
{ |
2602 |
|
auto *task = new PackNearTask<T, TREE>( &tree, comm_rank, p, 300 ); |
2603 |
|
/** Set src, tar, and key (tags). */ |
2604 |
|
//task->Set( &tree, comm_rank, p, 300 ); |
2605 |
|
//task->Submit(); |
2606 |
|
//task->DependencyAnalysis(); |
2607 |
|
} |
2608 |
|
else if ( !option.compare( 0, 4, "skel" ) ) |
2609 |
|
{ |
2610 |
|
auto *task = new PackFarTask<T, TREE>( &tree, comm_rank, p, 306 ); |
2611 |
|
/** Set src, tar, and key (tags). */ |
2612 |
|
//task->Set( &tree, comm_rank, p, 306 ); |
2613 |
|
//task->Submit(); |
2614 |
|
//task->DependencyAnalysis(); |
2615 |
|
} |
2616 |
|
else |
2617 |
|
{ |
2618 |
|
printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() ); |
2619 |
|
exit( 1 ); |
2620 |
|
} |
2621 |
|
} |
2622 |
|
|
2623 |
|
/** Create receiving tasks */ |
2624 |
|
for ( int p = 0; p < comm_size; p ++ ) |
2625 |
|
{ |
2626 |
|
if ( !option.compare( 0, 4, "leaf" ) ) |
2627 |
|
{ |
2628 |
|
auto *task = new UnpackLeafTask<T, TREE>( &tree, p, comm_rank, 300 ); |
2629 |
|
/** Set src, tar, and key (tags). */ |
2630 |
|
//task->Set( &tree, p, comm_rank, 300 ); |
2631 |
|
//task->Submit(); |
2632 |
|
//task->DependencyAnalysis(); |
2633 |
|
} |
2634 |
|
else if ( !option.compare( 0, 4, "skel" ) ) |
2635 |
|
{ |
2636 |
|
auto *task = new UnpackFarTask<T, TREE>( &tree, p, comm_rank, 306 ); |
2637 |
|
/** Set src, tar, and key (tags). */ |
2638 |
|
//task->Set( &tree, p, comm_rank, 306 ); |
2639 |
|
//task->Submit(); |
2640 |
|
//task->DependencyAnalysis(); |
2641 |
|
} |
2642 |
|
else |
2643 |
|
{ |
2644 |
|
printf( "AsyncExchangeLET: option <%s> not available.\n", option.data() ); |
2645 |
|
exit( 1 ); |
2646 |
|
} |
2647 |
|
} |
2648 |
|
|
2649 |
|
}; /** AsyncExchangeLET() */ |
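
/**
* Design note (our sketch): AsyncExchangeLET() only enqueues one Pack/Send
* task and one Unpack/Recv task per peer rank (message key 300 for leaf data
* and 306 for skeleton data); the runtime overlaps these with computation,
* whereas ExchangeLET() above is a bulk-synchronous Alltoallv.
*/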
template<typename T, typename TREE> |
2655 |
|
void ExchangeNeighbors( TREE &tree ) |
2656 |
|
{ |
2657 |
|
mpi::PrintProgress( "[BEG] ExchangeNeighbors ...", tree.GetComm() ); |
2658 |
|
|
2659 |
|
int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank ); |
2660 |
|
int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size ); |
2661 |
|
|
2662 |
|
/** Alltoallv buffers */ |
2663 |
|
vector<vector<size_t>> send_buff( comm_size ); |
2664 |
|
vector<vector<size_t>> recv_buff( comm_size ); |
2665 |
|
|
2666 |
|
/** NN<STAR, CIDS, pair<T, size_t>> */ |
2667 |
|
unordered_set<size_t> requested_gids; |
2668 |
|
auto &NN = *tree.setup.NN; |
2669 |
|
|
2670 |
|
/** Remove duplication. */ |
2671 |
|
for ( auto & it : NN ) |
2672 |
|
{ |
2673 |
|
if ( it.second >= 0 && it.second < tree.n ) |
2674 |
|
requested_gids.insert( it.second ); |
2675 |
|
} |
2676 |
|
|
2677 |
|
/** Remove owned gids. */ |
2678 |
|
for ( auto it : tree.treelist[ 0 ]->gids ) requested_gids.erase( it ); |
2679 |
|
|
2680 |
|
/** Assume gid is owned by (gid % size) */ |
2681 |
|
for ( auto it : requested_gids )
2682 |
|
{ |
2683 |
|
int p = it % comm_size; |
2684 |
|
if ( p != comm_rank ) send_buff[ p ].push_back( it ); |
2685 |
|
} |
2686 |
|
|
2687 |
|
/** Redistribute K. */ |
2688 |
|
auto &K = *tree.setup.K; |
2689 |
|
K.RequestIndices( send_buff ); |
2690 |
|
|
2691 |
|
mpi::PrintProgress( "[END] ExchangeNeighbors ...", tree.GetComm() ); |
2692 |
|
}; /** end ExchangeNeighbors() */ |
template<bool SYMMETRIC, typename NODE, typename T> |
2705 |
|
void MergeFarNodes( NODE *node ) |
2706 |
|
{ |
2707 |
|
/** if I don't have any skeleton, then I'm nobody's far field */ |
2708 |
|
//if ( !node->data.isskel ) return; |
2709 |
|
|
2710 |
|
/** |
2711 |
|
* Examine "Near" interaction list |
2712 |
|
*/ |
2713 |
|
//if ( node->isleaf ) |
2714 |
|
//{ |
2715 |
|
// auto & NearMortonIDs = node->NNNearNodeMortonIDs; |
2716 |
|
// #pragma omp critical |
2717 |
|
// { |
2718 |
|
// int rank; |
2719 |
|
// mpi::Comm_rank( MPI_COMM_WORLD, &rank ); |
2720 |
|
// string outfile = to_string( rank ); |
2721 |
|
// FILE * pFile = fopen( outfile.data(), "a+" ); |
2722 |
|
// fprintf( pFile, "(%8lu) ", node->morton ); |
2723 |
|
// for ( auto it = NearMortonIDs.begin(); it != NearMortonIDs.end(); it ++ ) |
2724 |
|
// fprintf( pFile, "%8lu, ", (*it) ); |
2725 |
|
// fprintf( pFile, "\n" ); //fflush( stdout ); |
2726 |
|
// } |
2727 |
|
|
2728 |
|
// //auto & NearNodes = node->NNNearNodes; |
2729 |
|
// //for ( auto it = NearNodes.begin(); it != NearNodes.end(); it ++ ) |
2730 |
|
// //{ |
2731 |
|
// // if ( !(*it)->NNNearNodes.count( node ) ) |
2732 |
|
// // { |
2733 |
|
// // printf( "(%8lu) misses %lu\n", (*it)->morton, node->morton ); fflush( stdout ); |
2734 |
|
// // } |
2735 |
|
// //} |
2736 |
|
//}; |
2737 |
|
|
2738 |
|
|
2739 |
|
/** Add my sibling (in the same level) to far interaction lists */ |
2740 |
|
assert( !node->FarNodeMortonIDs.size() ); |
2741 |
|
assert( !node->FarNodes.size() ); |
2742 |
|
node->FarNodeMortonIDs.insert( node->sibling->morton ); |
2743 |
|
node->FarNodes.insert( node->sibling ); |
2744 |
|
|
2745 |
|
/** Construct NN far interaction lists */ |
2746 |
|
if ( node->isleaf ) |
2747 |
|
{ |
2748 |
|
FindFarNodes( MortonHelper::Root(), node ); |
2749 |
|
} |
2750 |
|
else |
2751 |
|
{ |
2752 |
|
/** Merge Far( lchild ) and Far( rchild ) from children */ |
2753 |
|
auto *lchild = node->lchild; |
2754 |
|
auto *rchild = node->rchild; |
2755 |
|
|
2756 |
|
/** case: NNPRUNE (FMM specific) */ |
2757 |
|
auto &pNNFarNodes = node->NNFarNodeMortonIDs; |
2758 |
|
auto &lNNFarNodes = lchild->NNFarNodeMortonIDs; |
2759 |
|
auto &rNNFarNodes = rchild->NNFarNodeMortonIDs; |
2760 |
|
|
2761 |
|
/** Far( parent ) = Far( lchild ) intersects Far( rchild ) */ |
2762 |
|
for ( auto it = lNNFarNodes.begin(); |
2763 |
|
it != lNNFarNodes.end(); it ++ ) |
2764 |
|
{ |
2765 |
|
if ( rNNFarNodes.count( *it ) ) |
2766 |
|
{ |
2767 |
|
pNNFarNodes.insert( *it ); |
2768 |
|
} |
2769 |
|
} |
2770 |
|
/** Far( lchild ) \= Far( parent ); Far( rchild ) \= Far( parent ) */ |
2771 |
|
for ( auto it = pNNFarNodes.begin(); |
2772 |
|
it != pNNFarNodes.end(); it ++ ) |
2773 |
|
{ |
2774 |
|
lNNFarNodes.erase( *it ); |
2775 |
|
rNNFarNodes.erase( *it ); |
2776 |
|
} |
2777 |
|
} |
2778 |
|
|
2779 |
|
}; /** end MergeFarNodes() */ |
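
/**
* Tiny example (illustrative): with Far( lchild ) = { a, b } and
* Far( rchild ) = { b, c }, the merge above yields Far( parent ) = { b },
* Far( lchild ) = { a }, and Far( rchild ) = { c }; interactions shared by
* both children are promoted to the parent and removed from the children.
*/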
template<bool SYMMETRIC, typename NODE, typename T> |
2784 |
|
class MergeFarNodesTask : public Task |
2785 |
|
{ |
2786 |
|
public: |
2787 |
|
|
2788 |
|
NODE *arg; |
2789 |
|
|
2790 |
|
void Set( NODE *user_arg ) |
2791 |
|
{ |
2792 |
|
arg = user_arg; |
2793 |
|
name = string( "merge" ); |
2794 |
|
label = to_string( arg->treelist_id ); |
2795 |
|
/** we don't know the exact cost here */ |
2796 |
|
cost = 5.0; |
2797 |
|
/** high priority */ |
2798 |
|
priority = true; |
2799 |
|
}; |
2800 |
|
|
2801 |
|
/** read this node and write to children */ |
2802 |
|
void DependencyAnalysis() |
2803 |
|
{ |
2804 |
|
arg->DependencyAnalysis( RW, this ); |
2805 |
|
if ( !arg->isleaf ) |
2806 |
|
{ |
2807 |
|
arg->lchild->DependencyAnalysis( RW, this ); |
2808 |
|
arg->rchild->DependencyAnalysis( RW, this ); |
2809 |
|
} |
2810 |
|
this->TryEnqueue(); |
2811 |
|
}; |
2812 |
|
|
2813 |
|
void Execute( Worker* user_worker ) |
2814 |
|
{ |
2815 |
|
MergeFarNodes<SYMMETRIC, NODE, T>( arg ); |
2816 |
|
}; |
2817 |
|
|
2818 |
|
}; /** end class MergeFarNodesTask */ |
template<bool SYMMETRIC, typename NODE, typename T> |
2832 |
|
void DistMergeFarNodes( NODE *node ) |
2833 |
|
{ |
2834 |
|
/** MPI */ |
2835 |
|
mpi::Status status; |
2836 |
|
mpi::Comm comm = node->GetComm(); |
2837 |
|
int comm_size = node->GetCommSize(); |
2838 |
|
int comm_rank = node->GetCommRank(); |
2839 |
|
|
2840 |
|
/** if I don't have any skeleton, then I'm nobody's far field */ |
2841 |
|
//if ( !node->data.isskel ) return; |
2842 |
|
|
2843 |
|
|
2844 |
|
/** Early return if this is the root node. */ |
2845 |
|
if ( !node->parent ) return; |
2846 |
|
|
2847 |
|
/** Distributed treenode */ |
2848 |
|
if ( node->GetCommSize() < 2 ) |
2849 |
|
{ |
2850 |
|
MergeFarNodes<SYMMETRIC, NODE, T>( node ); |
2851 |
|
} |
2852 |
|
else |
2853 |
|
{ |
2854 |
|
/** merge Far( lchild ) and Far( rchild ) from children */ |
2855 |
|
auto *child = node->child; |
2856 |
|
|
2857 |
|
if ( comm_rank == 0 ) |
2858 |
|
{ |
2859 |
|
auto &pNNFarNodes = node->NNFarNodeMortonIDs; |
2860 |
|
auto &lNNFarNodes = child->NNFarNodeMortonIDs; |
2861 |
|
vector<size_t> recvFarNodes; |
2862 |
|
|
2863 |
|
/** Recv rNNFarNodes */ |
2864 |
|
mpi::RecvVector( recvFarNodes, comm_size / 2, 0, comm, &status ); |
2865 |
|
|
2866 |
|
/** Far( parent ) = Far( lchild ) intersects Far( rchild ). */ |
2867 |
|
for ( auto it : recvFarNodes ) |
2868 |
|
{ |
2869 |
|
if ( lNNFarNodes.count( it ) ) |
2870 |
|
{ |
2871 |
|
pNNFarNodes.insert( it ); |
2872 |
|
} |
2873 |
|
} |
2874 |
|
|
2875 |
|
/** Reuse space to send pNNFarNodes. */ |
2876 |
|
recvFarNodes.clear(); |
2877 |
|
recvFarNodes.reserve( pNNFarNodes.size() ); |
2878 |
|
|
2879 |
|
/** Far( lchild ) \= Far( parent ); Far( rchild ) \= Far( parent ). */ |
2880 |
|
for ( auto it : pNNFarNodes ) |
2881 |
|
{ |
2882 |
|
lNNFarNodes.erase( it ); |
2883 |
|
recvFarNodes.push_back( it ); |
2884 |
|
} |
2885 |
|
|
2886 |
|
/** Send pNNFarNodes. */ |
2887 |
|
mpi::SendVector( recvFarNodes, comm_size / 2, 0, comm ); |
2888 |
|
} |
2889 |
|
|
2890 |
|
|
2891 |
|
if ( comm_rank == comm_size / 2 ) |
2892 |
|
{ |
2893 |
|
auto &rNNFarNodes = child->NNFarNodeMortonIDs; |
2894 |
|
vector<size_t> sendFarNodes( rNNFarNodes.begin(), rNNFarNodes.end() ); |
2895 |
|
|
2896 |
|
/** Send rNNFarNodes. */ |
2897 |
|
mpi::SendVector( sendFarNodes, 0, 0, comm ); |
2898 |
|
/** Reuse sendFarNodes to receive pNNFarNodes. */
2899 |
|
mpi::RecvVector( sendFarNodes, 0, 0, comm, &status ); |
2900 |
|
/** Far( lchild ) \= Far( parent ); Far( rchild ) \= Far( parent ) */ |
2901 |
|
for ( auto it : sendFarNodes ) rNNFarNodes.erase( it ); |
2902 |
|
} |
2903 |
|
} |
2904 |
|
|
2905 |
|
}; /** end DistMergeFarNodes() */ |
template<bool SYMMETRIC, typename NODE, typename T> |
2910 |
|
class DistMergeFarNodesTask : public Task |
2911 |
|
{ |
2912 |
|
public: |
2913 |
|
|
2914 |
|
NODE *arg = NULL; |
2915 |
|
|
2916 |
|
void Set( NODE *user_arg ) |
2917 |
|
{ |
2918 |
|
arg = user_arg; |
2919 |
|
name = string( "dist-merge" ); |
2920 |
|
label = to_string( arg->treelist_id ); |
2921 |
|
/** we don't know the exact cost here */ |
2922 |
|
cost = 5.0; |
2923 |
|
/** high priority */ |
2924 |
|
priority = true; |
2925 |
|
}; |
2926 |
|
|
2927 |
|
/** read this node and write to children */ |
2928 |
|
void DependencyAnalysis() |
2929 |
|
{ |
2930 |
|
arg->DependencyAnalysis( RW, this ); |
2931 |
|
if ( !arg->isleaf ) |
2932 |
|
{ |
2933 |
|
if ( arg->GetCommSize() > 1 ) |
2934 |
|
{ |
2935 |
|
arg->child->DependencyAnalysis( RW, this ); |
2936 |
|
} |
2937 |
|
else |
2938 |
|
{ |
2939 |
|
arg->lchild->DependencyAnalysis( RW, this ); |
2940 |
|
arg->rchild->DependencyAnalysis( RW, this ); |
2941 |
|
} |
2942 |
|
} |
2943 |
|
this->TryEnqueue(); |
2944 |
|
}; |
2945 |
|
|
2946 |
|
void Execute( Worker* user_worker ) |
2947 |
|
{ |
2948 |
|
DistMergeFarNodes<SYMMETRIC, NODE, T>( arg ); |
2949 |
|
}; |
2950 |
|
|
2951 |
|
}; /** end class DistMergeFarNodesTask */ |
/** @brief Cache the KIJ blocks of far (low-rank) LET interactions. */
2961 |
|
template<bool NNPRUNE, typename NODE> |
2962 |
|
class CacheFarNodesTask : public Task |
2963 |
|
{ |
2964 |
|
public: |
2965 |
|
|
2966 |
|
NODE *arg = NULL; |
2967 |
|
|
2968 |
|
void Set( NODE *user_arg ) |
2969 |
|
{ |
2970 |
|
arg = user_arg; |
2971 |
|
name = string( "FKIJ" ); |
2972 |
|
label = to_string( arg->treelist_id ); |
2973 |
|
/** Compute FLOPS and MOPS. */ |
2974 |
|
double flops = 0, mops = 0; |
2975 |
|
/** We don't know the exact cost here. */ |
2976 |
|
cost = 5.0; |
2977 |
|
}; |
2978 |
|
|
2979 |
|
void DependencyAnalysis() |
2980 |
|
{ |
2981 |
|
arg->DependencyAnalysis( RW, this ); |
2982 |
|
this->TryEnqueue(); |
2983 |
|
}; |
2984 |
|
|
2985 |
|
void Execute( Worker* user_worker ) |
2986 |
|
{ |
2987 |
|
auto *node = arg; |
2988 |
|
auto &K = *node->setup->K; |
2989 |
|
|
2990 |
|
for ( int p = 0; p < node->DistFar.size(); p ++ ) |
2991 |
|
{ |
2992 |
|
for ( auto &it : node->DistFar[ p ] ) |
2993 |
|
{ |
2994 |
|
auto *src = (*node->morton2node)[ it.first ]; |
2995 |
|
auto &I = node->data.skels; |
2996 |
|
auto &J = src->data.skels; |
2997 |
|
it.second = K( I, J ); |
2998 |
|
//printf( "Cache I %lu J %lu\n", I.size(), J.size() ); fflush( stdout ); |
2999 |
|
} |
3000 |
|
} |
3001 |
|
}; |
3002 |
|
|
3003 |
|
}; /** end class CacheFarNodesTask */ |
/** @brief Cache the KIJ blocks of near (direct) LET interactions. */
3013 |
|
template<bool NNPRUNE, typename NODE> |
3014 |
|
class CacheNearNodesTask : public Task |
3015 |
|
{ |
3016 |
|
public: |
3017 |
|
|
3018 |
|
NODE *arg = NULL; |
3019 |
|
|
3020 |
|
void Set( NODE *user_arg ) |
3021 |
|
{ |
3022 |
|
arg = user_arg; |
3023 |
|
name = string( "NKIJ" ); |
3024 |
|
label = to_string( arg->treelist_id ); |
3025 |
|
/** We don't know the exact cost here */ |
3026 |
|
cost = 5.0; |
3027 |
|
}; |
3028 |
|
|
3029 |
|
void DependencyAnalysis() |
3030 |
|
{ |
3031 |
|
arg->DependencyAnalysis( RW, this ); |
3032 |
|
this->TryEnqueue(); |
3033 |
|
}; |
3034 |
|
|
3035 |
|
void Execute( Worker* user_worker ) |
3036 |
|
{ |
3037 |
|
auto *node = arg; |
3038 |
|
auto &K = *node->setup->K; |
3039 |
|
|
3040 |
|
for ( int p = 0; p < node->DistNear.size(); p ++ ) |
3041 |
|
{ |
3042 |
|
for ( auto &it : node->DistNear[ p ] ) |
3043 |
|
{ |
3044 |
|
auto *src = (*node->morton2node)[ it.first ]; |
3045 |
|
auto &I = node->gids; |
3046 |
|
auto &J = src->gids; |
3047 |
|
it.second = K( I, J ); |
3048 |
|
//printf( "Cache I %lu J %lu\n", I.size(), J.size() ); fflush( stdout ); |
3049 |
|
} |
3050 |
|
} |
3051 |
|
}; |
3052 |
|
|
3053 |
|
}; /** end class CacheNearNodesTask */ |
template<typename NODE, typename T> |
3067 |
|
void DistRowSamples( NODE *node, size_t nsamples ) |
3068 |
|
{ |
3069 |
|
/** MPI */ |
3070 |
|
mpi::Comm comm = node->GetComm(); |
3071 |
|
int size = node->GetCommSize(); |
3072 |
|
int rank = node->GetCommRank(); |
3073 |
|
|
3074 |
|
/** gather shared data and create reference */ |
3075 |
|
auto &K = *node->setup->K; |
3076 |
|
|
3077 |
|
/** amap contains nsamples of row gids of K */ |
3078 |
|
vector<size_t> &I = node->data.candidate_rows; |
3079 |
|
|
3080 |
|
/** Clean up candidates from previous iteration */ |
3081 |
|
I.clear(); |
3082 |
|
|
3083 |
|
/** Fill in snids first. */
3084 |
|
if ( rank == 0 ) |
3085 |
|
{ |
3086 |
|
/** reserve space */ |
3087 |
|
I.reserve( nsamples ); |
3088 |
|
|
3089 |
|
auto &snids = node->data.snids; |
3090 |
|
multimap<T, size_t> ordered_snids = gofmm::flip_map( snids ); |
3091 |
|
|
3092 |
|
for ( auto it = ordered_snids.begin(); |
3093 |
|
it != ordered_snids.end(); it++ ) |
3094 |
|
{ |
3095 |
|
/** (*it) has type pair<T, size_t> */ |
3096 |
|
I.push_back( (*it).second ); |
3097 |
|
if ( I.size() >= nsamples ) break; |
3098 |
|
} |
3099 |
|
} |
3100 |
|
|
3101 |
|
/** buffer space */ |
3102 |
|
vector<size_t> candidates( nsamples ); |
3103 |
|
|
3104 |
|
size_t n_required = nsamples - I.size(); |
3105 |
|
|
3106 |
|
/** bcast the termination criteria */ |
3107 |
|
mpi::Bcast( &n_required, 1, 0, comm ); |
3108 |
|
|
3109 |
|
while ( n_required ) |
3110 |
|
{ |
3111 |
|
if ( rank == 0 ) |
3112 |
|
{ |
3113 |
|
for ( size_t i = 0; i < nsamples; i ++ ) |
3114 |
|
{ |
3115 |
|
auto important_sample = K.ImportantSample( 0 ); |
3116 |
|
candidates[ i ] = important_sample.second; |
3117 |
|
} |
3118 |
|
} |
3119 |
|
|
3120 |
|
/** Bcast candidates */ |
3121 |
|
mpi::Bcast( candidates.data(), candidates.size(), 0, comm ); |
3122 |
|
|
3123 |
|
/** validation */ |
3124 |
|
vector<size_t> vconsensus( nsamples, 0 ); |
3125 |
|
vector<size_t> validation = node->setup->ContainAny( candidates, node->morton ); |
3126 |
|
|
3127 |
|
/** reduce validation */ |
3128 |
|
mpi::Reduce( validation.data(), vconsensus.data(), nsamples, MPI_SUM, 0, comm ); |
3129 |
|
|
3130 |
|
if ( rank == 0 ) |
3131 |
|
{ |
3132 |
|
for ( size_t i = 0; i < nsamples; i ++ ) |
3133 |
|
{ |
3134 |
|
/** Exit if there are enough samples. */
3135 |
|
if ( I.size() >= nsamples ) |
3136 |
|
{ |
3137 |
|
I.resize( nsamples ); |
3138 |
|
break; |
3139 |
|
} |
3140 |
|
/** Push the candidate to I after validation */ |
3141 |
|
if ( !vconsensus[ i ] ) |
3142 |
|
{ |
3143 |
|
if ( find( I.begin(), I.end(), candidates[ i ] ) == I.end() ) |
3144 |
|
I.push_back( candidates[ i ] ); |
3145 |
|
} |
3146 |
|
}; |
3147 |
|
|
3148 |
|
/** Update n_required */ |
3149 |
|
n_required = nsamples - I.size(); |
3150 |
|
} |
3151 |
|
|
3152 |
|
/** Bcast the termination criteria */ |
3153 |
|
mpi::Bcast( &n_required, 1, 0, comm ); |
3154 |
|
} |
3155 |
|
|
3156 |
|
}; /** end DistRowSamples() */ |
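
/**
* Protocol sketch (our summary): rank 0 seeds the candidate rows from the
* sorted sampling neighbors (snids); then the node communicator loops:
* rank 0 broadcasts candidate gids, every rank flags candidates that fall
* inside this node (ContainAny), the flags are summed back to rank 0, and
* only unflagged, not-yet-chosen gids are kept until nsamples rows have
* been collected.
*/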
/** @brief Involves MPI routines. */
3164 |
|
template<bool NNPRUNE, typename NODE> |
3165 |
|
void DistSkeletonKIJ( NODE *node ) |
3166 |
|
{ |
3167 |
|
/** Derive type T from NODE. */ |
3168 |
|
using T = typename NODE::T; |
3169 |
|
/** Early return */ |
3170 |
|
if ( !node->parent ) return; |
3171 |
|
/** Gather shared data and create reference. */ |
3172 |
|
auto &K = *(node->setup->K); |
3173 |
|
/** Gather per node data and create reference. */ |
3174 |
|
auto &data = node->data; |
3175 |
|
auto &candidate_rows = data.candidate_rows; |
3176 |
|
auto &candidate_cols = data.candidate_cols; |
3177 |
|
auto &KIJ = data.KIJ; |
3178 |
|
|
3179 |
|
/** MPI Support. */ |
3180 |
|
auto comm = node->GetComm(); |
3181 |
|
auto size = node->GetCommSize(); |
3182 |
|
auto rank = node->GetCommRank(); |
3183 |
|
mpi::Status status; |
3184 |
|
|
3185 |
|
if ( size < 2 ) |
3186 |
|
{ |
3187 |
|
/** This node is the root of the local tree. */ |
3188 |
|
gofmm::SkeletonKIJ<NNPRUNE>( node ); |
3189 |
|
} |
3190 |
|
else |
3191 |
|
{ |
3192 |
|
/** This node (mpitree::Node) belongs to the distributed tree and is
* only executed by the 0th and the (size/2)th ranks of
* the node communicator. At this moment, the children have been
* skeletonized. Thus, we should first update (isskel) on
* all MPI processes. Then we gather information for the
* skeletonization.
*/
3199 |
|
NODE *child = node->child; |
3200 |
|
size_t nsamples = 0; |
3201 |
|
|
3202 |
|
/** Bcast (isskel) to all MPI processes using children's communicator. */ |
3203 |
|
int child_isskel = child->data.isskel; |
3204 |
|
mpi::Bcast( &child_isskel, 1, 0, child->GetComm() ); |
3205 |
|
child->data.isskel = child_isskel; |
3206 |
|
|
3207 |
|
|
3208 |
|
/** rank-0 owns data of this node, and it also owns the left child. */ |
3209 |
|
if ( rank == 0 ) |
3210 |
|
{ |
3211 |
|
candidate_cols = child->data.skels; |
3212 |
|
vector<size_t> rskel; |
3213 |
|
/** Receive rskel from my sibling. */ |
3214 |
|
mpi::RecvVector( rskel, size / 2, 10, comm, &status ); |
3215 |
|
/** Correspondingly, we need to redistribute the matrix K. */ |
3216 |
|
K.RecvIndices( size / 2, comm, &status ); |
3217 |
|
/** Concatenate [ lskels, rskels ]. */
3218 |
|
candidate_cols.insert( candidate_cols.end(), rskel.begin(), rskel.end() ); |
3219 |
|
/** Use twice the number of skeletons. */
3220 |
|
nsamples = 2 * candidate_cols.size(); |
3221 |
|
/** Make sure we have at least 2m samples. */
3222 |
|
if ( nsamples < 2 * node->setup->LeafNodeSize() ) |
3223 |
|
nsamples = 2 * node->setup->LeafNodeSize(); |
3224 |
|
|
3225 |
|
/** Gather rsnids. */ |
3226 |
|
auto &lsnids = node->child->data.snids; |
3227 |
|
vector<T> recv_rsdist; |
3228 |
|
vector<size_t> recv_rsnids; |
3229 |
|
|
3230 |
|
/** Receive rsnids from size / 2. */ |
3231 |
|
mpi::RecvVector( recv_rsdist, size / 2, 20, comm, &status ); |
3232 |
|
mpi::RecvVector( recv_rsnids, size / 2, 30, comm, &status ); |
3233 |
|
/** Correspondingly, we need to redistribute the matrix K. */ |
3234 |
|
K.RecvIndices( size / 2, comm, &status ); |
3235 |
|
|
3236 |
|
|
3237 |
|
/** Merge snids and update the smallest distance. */ |
3238 |
|
auto &snids = node->data.snids; |
3239 |
|
snids = lsnids; |
3240 |
|
|
3241 |
|
for ( size_t i = 0; i < recv_rsdist.size(); i ++ ) |
3242 |
|
{ |
3243 |
|
pair<size_t, T> query( recv_rsnids[ i ], recv_rsdist[ i ] ); |
3244 |
|
auto ret = snids.insert( query ); |
3245 |
|
if ( !ret.second ) |
3246 |
|
{ |
3247 |
|
if ( ret.first->second > recv_rsdist[ i ] ) |
3248 |
|
ret.first->second = recv_rsdist[ i ]; |
3249 |
|
} |
3250 |
|
} |
3251 |
|
|
3252 |
|
/** Remove gids from snids */ |
3253 |
|
for ( auto gid : node->gids ) snids.erase( gid ); |
3254 |
|
} |
3255 |
|
|
3256 |
|
if ( rank == size / 2 ) |
3257 |
|
{ |
3258 |
|
/** Send rskel to rank 0. */ |
3259 |
|
mpi::SendVector( child->data.skels, 0, 10, comm ); |
3260 |
|
/** Correspondingly, we need to redistribute the matrix K. */ |
3261 |
|
K.SendIndices( child->data.skels, 0, comm ); |
3262 |
|
|
3263 |
|
/** Gather rsnids */ |
3264 |
|
auto &rsnids = node->child->data.snids; |
3265 |
|
vector<T> send_rsdist; |
3266 |
|
vector<size_t> send_rsnids; |
3267 |
|
|
3268 |
|
/** reserve space and push in from map */ |
3269 |
|
send_rsdist.reserve( rsnids.size() ); |
3270 |
|
send_rsnids.reserve( rsnids.size() ); |
3271 |
|
|
3272 |
|
for ( auto it = rsnids.begin(); it != rsnids.end(); it ++ ) |
3273 |
|
{ |
3274 |
|
/** (*it) has type std::pair<size_t, T> */ |
3275 |
|
send_rsnids.push_back( (*it).first ); |
3276 |
|
send_rsdist.push_back( (*it).second ); |
3277 |
|
} |
3278 |
|
|
3279 |
|
/** send rsnids to rank-0 */ |
3280 |
|
mpi::SendVector( send_rsdist, 0, 20, comm ); |
3281 |
|
mpi::SendVector( send_rsnids, 0, 30, comm ); |
3282 |
|
|
3283 |
|
/** Correspondingly, we need to redistribute the matrix K. */ |
3284 |
|
K.SendIndices( send_rsnids, 0, comm ); |
3285 |
|
} |
3286 |
|
|
3287 |
|
/** Bcast nsamples. */ |
3288 |
|
mpi::Bcast( &nsamples, 1, 0, comm ); |
3289 |
|
/** Distributed row samples. */ |
3290 |
|
DistRowSamples<NODE, T>( node, nsamples ); |
3291 |
|
/** only rank-0 has non-empty I and J sets */ |
3292 |
|
if ( rank != 0 ) |
3293 |
|
{ |
3294 |
|
assert( !candidate_rows.size() ); |
3295 |
|
assert( !candidate_cols.size() ); |
3296 |
|
} |
3297 |
|
/** |
3298 |
|
* Now rank-0 has the correct ( I, J ). All other ranks in the |
3299 |
|
* communicator must flush their I and J sets before evaluation. |
3300 |
|
* All MPI processes must participate in operator ().
3301 |
|
*/ |
3302 |
|
KIJ = K( candidate_rows, candidate_cols ); |
3303 |
|
} |
3304 |
|
}; /** end DistSkeletonKIJ() */ |
/** @brief Task wrapper for DistSkeletonKIJ(). */
3310 |
|
template<bool NNPRUNE, typename NODE, typename T> |
3311 |
|
class DistSkeletonKIJTask : public Task |
3312 |
|
{ |
3313 |
|
public: |
3314 |
|
|
3315 |
|
NODE *arg = NULL; |
3316 |
|
|
3317 |
|
void Set( NODE *user_arg ) |
3318 |
|
{ |
3319 |
|
arg = user_arg; |
3320 |
|
name = string( "par-gskm" ); |
3321 |
|
label = to_string( arg->treelist_id ); |
3322 |
|
/** We don't know the exact cost here */ |
3323 |
|
cost = 5.0; |
3324 |
|
/** "High" priority */ |
3325 |
|
priority = true; |
3326 |
|
}; |
3327 |
|
|
3328 |
|
void DependencyAnalysis() { arg->DependOnChildren( this ); }; |
3329 |
|
|
3330 |
|
void Execute( Worker* user_worker ) { DistSkeletonKIJ<NNPRUNE>( arg ); }; |
3331 |
|
|
3332 |
|
}; /** end class DistSkeletonKIJTask */ |
/** |
3346 |
|
* @brief Skeletonization with interpolative decomposition. |
3347 |
|
*/ |
3348 |
|
template<typename NODE, typename T> |
3349 |
|
void DistSkeletonize( NODE *node ) |
3350 |
|
{ |
3351 |
|
/** early return if we do not need to skeletonize */ |
3352 |
|
if ( !node->parent ) return; |
3353 |
|
|
3354 |
|
/** gather shared data and create reference */ |
3355 |
|
auto &K = *(node->setup->K); |
3356 |
|
auto &NN = *(node->setup->NN); |
3357 |
|
auto maxs = node->setup->MaximumRank(); |
3358 |
|
auto stol = node->setup->Tolerance(); |
3359 |
|
bool secure_accuracy = node->setup->SecureAccuracy(); |
3360 |
|
bool use_adaptive_ranks = node->setup->UseAdaptiveRanks(); |
3361 |
|
|
3362 |
|
/** gather per node data and create reference */ |
3363 |
|
auto &data = node->data; |
3364 |
|
auto &skels = data.skels; |
3365 |
|
auto &proj = data.proj; |
3366 |
|
auto &jpvt = data.jpvt; |
3367 |
|
auto &KIJ = data.KIJ; |
3368 |
|
auto &candidate_cols = data.candidate_cols; |
3369 |
|
|
3370 |
|
/** interpolative decomposition */ |
3371 |
|
size_t N = K.col(); |
3372 |
|
size_t m = KIJ.row(); |
3373 |
|
size_t n = KIJ.col(); |
3374 |
|
size_t q = node->n; |
3375 |
|
|
3376 |
|
if ( secure_accuracy ) |
3377 |
|
{ |
3378 |
|
/** TODO: need to check of both children's isskel to preceed */ |
3379 |
|
} |
3380 |
|
|
3381 |
|
|
3382 |
|
/** Bill's l2 norm scaling factor */ |
3383 |
|
T scaled_stol = std::sqrt( (T)n / q ) * std::sqrt( (T)m / (N - q) ) * stol; |
3384 |
|
|
3385 |
|
/** account for uniform sampling */ |
3386 |
|
scaled_stol *= std::sqrt( (T)q / N ); |
3387 |
|
|
3388 |
|
lowrank::id |
3389 |
|
( |
3390 |
|
use_adaptive_ranks, secure_accuracy, |
3391 |
|
KIJ.row(), KIJ.col(), maxs, scaled_stol, |
3392 |
|
KIJ, skels, proj, jpvt |
3393 |
|
); |
3394 |
|
|
3395 |
|
/** Free KIJ for spaces */ |
3396 |
|
KIJ.resize( 0, 0 ); |
3397 |
|
KIJ.shrink_to_fit(); |
3398 |
|
|
3399 |
|
/** depending on the flag, decide isskel or not */ |
3400 |
|
if ( secure_accuracy ) |
3401 |
|
{ |
3402 |
|
/** TODO: this needs to be bcast to other nodes */ |
3403 |
|
data.isskel = (skels.size() != 0); |
3404 |
|
} |
3405 |
|
else |
3406 |
|
{ |
3407 |
|
assert( skels.size() ); |
3408 |
|
assert( proj.size() ); |
3409 |
|
assert( jpvt.size() ); |
3410 |
|
data.isskel = true; |
3411 |
|
} |
3412 |
|
|
3413 |
|
/** Relabel skeletions with the real gids */ |
3414 |
|
for ( size_t i = 0; i < skels.size(); i ++ ) |
3415 |
|
{ |
3416 |
|
skels[ i ] = candidate_cols[ skels[ i ] ]; |
3417 |
|
} |
3418 |
|
|
3419 |
|
|
3420 |
|
}; /** end DistSkeletonize() */ |
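
/**
 *  Note (illustrative): the tolerance handed to lowrank::id() above is not the
 *  raw user tolerance. With m = KIJ.row(), n = KIJ.col(), q = node->n, and
 *  N = K.col(), the code forms
 *
 *    scaled_stol = sqrt( (T)n / q ) * sqrt( (T)m / ( N - q ) ) * stol
 *                * sqrt( (T)q / N ),
 *
 *  i.e. the user tolerance is rescaled for the sampled submatrix size and for
 *  uniform sampling before the (optionally adaptive-rank) ID picks skeletons.
 */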



template<typename NODE, typename T>
class SkeletonizeTask : public hmlp::Task
{
  public:

    NODE *arg;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "SK" );
      label = to_string( arg->treelist_id );
      /** We don't know the exact cost here. */
      cost = 5.0;
      /** "High" priority. */
      priority = true;
    };

    void GetEventRecord()
    {
      double flops = 0.0, mops = 0.0;

      auto &K = *arg->setup->K;
      size_t n = arg->data.proj.col();
      size_t m = 2 * n;
      size_t k = arg->data.proj.row();

      /** GEQP3 */
      flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
      mops  += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );

      /* TRSM */
      flops += k * ( k - 1 ) * ( n + 1 );
      mops  += 2.0 * ( k * k + k * n );

      event.Set( label + name, flops, mops );
      arg->data.skeletonize = event;
    };

    void DependencyAnalysis()
    {
      arg->DependencyAnalysis( RW, this );
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker )
    {
      //printf( "%d Par-Skel beg\n", global_rank );

      DistSkeletonize<NODE, T>( arg );

      //printf( "%d Par-Skel end\n", global_rank );
    };

}; /** end class SkeletonizeTask */



/**
 *  @brief Distributed skeletonization task: rank 0 of the node communicator
 *         performs the ID, then the results are broadcast.
 */
template<typename NODE, typename T>
class DistSkeletonizeTask : public hmlp::Task
{
  public:

    NODE *arg;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PSK" );
      label = to_string( arg->treelist_id );

      /** We don't know the exact cost here. */
      cost = 5.0;
      /** "High" priority. */
      priority = true;
    };

    void GetEventRecord()
    {
      double flops = 0.0, mops = 0.0;

      auto &K = *arg->setup->K;
      size_t n = arg->data.proj.col();
      size_t m = 2 * n;
      size_t k = arg->data.proj.row();

      if ( arg->GetCommRank() == 0 )
      {
        /** GEQP3 */
        flops += ( 4.0 / 3.0 ) * n * n * ( 3 * m - n );
        mops  += ( 2.0 / 3.0 ) * n * n * ( 3 * m - n );

        /* TRSM */
        flops += k * ( k - 1 ) * ( n + 1 );
        mops  += 2.0 * ( k * k + k * n );
      }

      event.Set( label + name, flops, mops );
      arg->data.skeletonize = event;
    };

    void DependencyAnalysis()
    {
      arg->DependencyAnalysis( RW, this );
      this->TryEnqueue();
    };

    void Execute( Worker* user_worker )
    {
      mpi::Comm comm = arg->GetComm();

      double beg = omp_get_wtime();
      if ( arg->GetCommRank() == 0 )
      {
        DistSkeletonize<NODE, T>( arg );
      }
      double skel_t = omp_get_wtime() - beg;

      /** Bcast isskel to every MPI process in the same comm. */
      int isskel = arg->data.isskel;
      mpi::Bcast( &isskel, 1, 0, comm );
      arg->data.isskel = isskel;

      /** Bcast skels to every MPI process in the same comm. */
      auto &skels = arg->data.skels;
      size_t nskels = skels.size();
      mpi::Bcast( &nskels, 1, 0, comm );
      if ( skels.size() != nskels ) skels.resize( nskels );
      mpi::Bcast( skels.data(), skels.size(), 0, comm );

    };

}; /** end class DistSkeletonizeTask */
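
/**
 *  Note (illustrative): only rank 0 of the node communicator runs the ID, so
 *  the task above must rebroadcast results whose length is unknown on the
 *  other ranks. The pattern is size-first, then payload; a minimal sketch
 *  under the same assumptions (hmlp::mpi wrappers, root rank 0):
 *
 *    size_t nskels = skels.size();                // meaningful on root only
 *    mpi::Bcast( &nskels, 1, 0, comm );           // 1. agree on the length
 *    skels.resize( nskels );                      // 2. allocate on non-roots
 *    mpi::Bcast( skels.data(), nskels, 0, comm ); // 3. ship the payload
 */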



/**
 *  @brief Compute the ID coefficient (interpolation) matrix proj on rank 0 of
 *         the node communicator, then broadcast it.
 */
template<typename NODE>
class InterpolateTask : public Task
{
  public:

    NODE *arg = NULL;

    void Set( NODE *user_arg )
    {
      arg = user_arg;
      name = string( "PROJ" );
      label = to_string( arg->treelist_id );
      // Need an accurate cost model.
      cost = 1.0;
    };

    void DependencyAnalysis() { arg->DependOnNoOne( this ); };

    void Execute( Worker* user_worker )
    {
      /** MPI support. */
      auto comm = arg->GetComm();
      /** Only executed by rank 0. */
      if ( arg->GetCommRank() == 0 ) gofmm::Interpolate( arg );

      auto &proj = arg->data.proj;
      size_t nrow = proj.row();
      size_t ncol = proj.col();
      mpi::Bcast( &nrow, 1, 0, comm );
      mpi::Bcast( &ncol, 1, 0, comm );
      if ( proj.row() != nrow || proj.col() != ncol ) proj.resize( nrow, ncol );
      mpi::Bcast( proj.data(), proj.size(), 0, comm );
    };

}; /** end class InterpolateTask */




/**
 *  @brief Evaluate u = K * w approximately with the distributed GOFMM tree
 *         (a.k.a. ComputeAll). Both weights and potentials are in the
 *         [RIDS,STAR] distribution.
 */
template<bool NNPRUNE = true, typename TREE, typename T>
DistData<RIDS, STAR, T> Evaluate( TREE &tree, DistData<RIDS, STAR, T> &weights )
{
  try
  {
    /** MPI support. */
    int size; mpi::Comm_size( tree.GetComm(), &size );
    int rank; mpi::Comm_rank( tree.GetComm(), &rank );
    /** Derive type NODE and MPINODE from TREE. */
    using NODE = typename TREE::NODE;
    using MPINODE = typename TREE::MPINODE;

    /**
     *  All timers. The staged timers (telescope, LET exchange, L2L, S2S/S2N)
     *  are only populated by the synchronous path, which is currently
     *  commented out below; initialize them to zero so the report stays
     *  well defined.
     */
    double beg, time_ratio, evaluation_time = 0.0;
    double direct_evaluation_time = 0.0, computeall_time = 0.0;
    double telescope_time = 0.0, let_exchange_time = 0.0, async_time = 0.0;
    double overhead_time = 0.0;
    double forward_permute_time, backward_permute_time;

    /** Clean up all r/w dependencies left on tree nodes. */
    tree.DependencyCleanUp();

    /** n-by-nrhs, initialize potentials. */
    size_t n = weights.row();
    size_t nrhs = weights.col();

    /** Potentials must be in the [RIDS,STAR] distribution. */
    auto &gids_owned = tree.treelist[ 0 ]->gids;
    DistData<RIDS, STAR, T> potentials( n, nrhs, gids_owned, tree.GetComm() );
    potentials.setvalue( 0.0 );

    /** Provide pointers. */
    tree.setup.w = &weights;
    tree.setup.u = &potentials;

    /** TreeView (downward traversal). */
    gofmm::TreeViewTask<NODE> seqVIEWtask;
    mpigofmm::DistTreeViewTask<MPINODE> mpiVIEWtask;
    /** Telescope (upward traversal). */
    gofmm::UpdateWeightsTask<NODE, T> seqN2Stask;
    mpigofmm::DistUpdateWeightsTask<MPINODE, T> mpiN2Stask;
    /** L2L (sum of direct evaluations). */
    //mpigofmm::DistLeavesToLeavesTask<NNPRUNE, NODE, T> seqL2Ltask;
    //mpigofmm::L2LReduceTask<NODE, T> seqL2LReducetask;
    mpigofmm::L2LReduceTask2<NODE, T> seqL2LReducetask2;
    /** S2S (sum of low-rank approximations). */
    //gofmm::SkeletonsToSkeletonsTask<NNPRUNE, NODE, T> seqS2Stask;
    //mpigofmm::DistSkeletonsToSkeletonsTask<NNPRUNE, MPINODE, T> mpiS2Stask;
    //mpigofmm::S2SReduceTask<NODE, T> seqS2SReducetask;
    //mpigofmm::S2SReduceTask<MPINODE, T> mpiS2SReducetask;
    mpigofmm::S2SReduceTask2<NODE, NODE, T> seqS2SReducetask2;
    mpigofmm::S2SReduceTask2<MPINODE, NODE, T> mpiS2SReducetask2;
    /** Telescope (downward traversal). */
    gofmm::SkeletonsToNodesTask<NNPRUNE, NODE, T> seqS2Ntask;
    mpigofmm::DistSkeletonsToNodesTask<NNPRUNE, MPINODE, T> mpiS2Ntask;

    /** Global barrier and timer. */
    mpi::Barrier( tree.GetComm() );

    //{
    //  /** Stage 1: TreeView and upward telescoping */
    //  beg = omp_get_wtime();
    //  tree.DependencyCleanUp();
    //  tree.DistTraverseDown( mpiVIEWtask );
    //  tree.LocaTraverseDown( seqVIEWtask );
    //  tree.LocaTraverseUp( seqN2Stask );
    //  tree.DistTraverseUp( mpiN2Stask );
    //  hmlp_run();
    //  mpi::Barrier( tree.GetComm() );
    //  telescope_time = omp_get_wtime() - beg;

    //  /** Stage 2: LET exchange */
    //  beg = omp_get_wtime();
    //  ExchangeLET<T>( tree, string( "skelweights" ) );
    //  mpi::Barrier( tree.GetComm() );
    //  ExchangeLET<T>( tree, string( "leafweights" ) );
    //  mpi::Barrier( tree.GetComm() );
    //  let_exchange_time = omp_get_wtime() - beg;

    //  /** Stage 3: L2L */
    //  beg = omp_get_wtime();
    //  tree.DependencyCleanUp();
    //  tree.LocaTraverseLeafs( seqL2LReducetask2 );
    //  hmlp_run();
    //  mpi::Barrier( tree.GetComm() );
    //  direct_evaluation_time = omp_get_wtime() - beg;

    //  /** Stage 4: S2S and downward telescoping */
    //  beg = omp_get_wtime();
    //  tree.DependencyCleanUp();
    //  tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
    //  tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
    //  tree.DistTraverseDown( mpiS2Ntask );
    //  tree.LocaTraverseDown( seqS2Ntask );
    //  hmlp_run();
    //  mpi::Barrier( tree.GetComm() );
    //  computeall_time = omp_get_wtime() - beg;
    //}


    /** Global barrier and timer. */
    potentials.setvalue( 0.0 );
    mpi::Barrier( tree.GetComm() );

    /** Stage 1: TreeView and upward telescoping. */
    beg = omp_get_wtime();
    tree.DependencyCleanUp();
    tree.DistTraverseDown( mpiVIEWtask );
    tree.LocaTraverseDown( seqVIEWtask );
    tree.ExecuteAllTasks();
    /** Stage 2: redistribute weights from IDS to LET. */
    AsyncExchangeLET<T>( tree, string( "leafweights" ) );
    /** Stage 3: N2S. */
    tree.LocaTraverseUp( seqN2Stask );
    tree.DistTraverseUp( mpiN2Stask );
    /** Stage 4: redistribute skeleton weights from IDS to LET. */
    AsyncExchangeLET<T>( tree, string( "skelweights" ) );
    /** Stage 5: L2L. */
    tree.LocaTraverseLeafs( seqL2LReducetask2 );
    /** Stage 6: S2S. */
    tree.LocaTraverseUnOrdered( seqS2SReducetask2 );
    tree.DistTraverseUnOrdered( mpiS2SReducetask2 );
    /** Stage 7: S2N. */
    tree.DistTraverseDown( mpiS2Ntask );
    tree.LocaTraverseDown( seqS2Ntask );
    overhead_time = omp_get_wtime() - beg;
    tree.ExecuteAllTasks();
    async_time = omp_get_wtime() - beg;


    /** Compute the breakdown cost. */
    evaluation_time += direct_evaluation_time;
    evaluation_time += telescope_time;
    evaluation_time += let_exchange_time;
    evaluation_time += computeall_time;
    time_ratio = ( evaluation_time > 0.0 ) ? 100 / evaluation_time : 0.0;

    if ( rank == 0 && REPORT_EVALUATE_STATUS )
    {
      printf( "========================================================\n");
      printf( "GOFMM evaluation phase\n" );
      printf( "========================================================\n");
      //printf( "Allocate ------------------------------ %5.2lfs (%5.1lf%%)\n",
      //    allocate_time, allocate_time * time_ratio );
      //printf( "Forward permute ----------------------- %5.2lfs (%5.1lf%%)\n",
      //    forward_permute_time, forward_permute_time * time_ratio );
      printf( "Upward telescope ---------------------- %5.2lfs (%5.1lf%%)\n",
          telescope_time, telescope_time * time_ratio );
      printf( "LET exchange -------------------------- %5.2lfs (%5.1lf%%)\n",
          let_exchange_time, let_exchange_time * time_ratio );
      printf( "L2L ----------------------------------- %5.2lfs (%5.1lf%%)\n",
          direct_evaluation_time, direct_evaluation_time * time_ratio );
      printf( "S2S, S2N ------------------------------ %5.2lfs (%5.1lf%%)\n",
          computeall_time, computeall_time * time_ratio );
      //printf( "Backward permute ---------------------- %5.2lfs (%5.1lf%%)\n",
      //    backward_permute_time, backward_permute_time * time_ratio );
      printf( "========================================================\n");
      printf( "Evaluate ------------------------------ %5.2lfs (%5.1lf%%)\n",
          evaluation_time, evaluation_time * time_ratio );
      printf( "Evaluate (Async) ---------------------- %5.2lfs (%5.2lfs)\n",
          async_time, overhead_time );
      printf( "========================================================\n\n");
    }

    return potentials;
  }
  catch ( const exception & e )
  {
    cout << e.what() << endl;
    exit( 1 );
  }
}; /** end Evaluate() */
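
/**
 *  Usage sketch (illustrative only): assuming a tree returned by
 *  mpigofmm::Compress() and weights already in the [RIDS,STAR] distribution,
 *  an approximate matvec looks like
 *
 *    DistData<RIDS, STAR, double> w( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
 *    w.randn();
 *    auto u = mpigofmm::Evaluate<true>( tree, w );  // u ~ K * w, also [RIDS,STAR]
 */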



/**
 *  @brief Evaluate u = K * w with weights given in the [RBLK,STAR]
 *         distribution; a thin wrapper that redistributes to [RIDS,STAR]
 *         and back.
 */
template<bool NNPRUNE = true, typename TREE, typename T>
DistData<RBLK, STAR, T> Evaluate( TREE &tree, DistData<RBLK, STAR, T> &w_rblk )
{
  size_t n = w_rblk.row();
  size_t nrhs = w_rblk.col();
  /** Redistribute weights from RBLK to RIDS. */
  DistData<RIDS, STAR, T> w_rids( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
  w_rids = w_rblk;
  /** Evaluation with the RIDS distribution. */
  auto u_rids = Evaluate<NNPRUNE>( tree, w_rids );
  mpi::Barrier( tree.GetComm() );
  /** Redistribute potentials from RIDS to RBLK. */
  DistData<RBLK, STAR, T> u_rblk( n, nrhs, tree.GetComm() );
  u_rblk = u_rids;
  /** Return potentials in the RBLK distribution. */
  return u_rblk;
}; /** end Evaluate() */



/**
 *  @brief Iterative randomized-tree search for approximate nearest neighbors
 *         of all points (ANN).
 */
template<typename SPLITTER, typename T, typename SPDMATRIX>
DistData<STAR, CBLK, pair<T, size_t>> FindNeighbors
(
  SPDMATRIX &K,
  SPLITTER splitter,
  gofmm::Configuration<T> &config,
  mpi::Comm CommGOFMM,
  size_t n_iter = 10
)
{
  /** Instantiation for the randomized metric tree. */
  using DATA = gofmm::NodeData<T>;
  using SETUP = mpigofmm::Setup<SPDMATRIX, SPLITTER, T>;
  using TREE = mpitree::Tree<SETUP, DATA>;
  /** Derive type NODE from TREE. */
  using NODE = typename TREE::NODE;
  /** Get all user-defined parameters. */
  DistanceMetric metric = config.MetricType();
  size_t n = config.ProblemSize();
  size_t k = config.NeighborSize();
  /** Iterative all nearest-neighbor search (ANN). */
  pair<T, size_t> init( numeric_limits<T>::max(), n );
  gofmm::NeighborsTask<NODE, T> NEIGHBORStask;
  TREE rkdt( CommGOFMM );
  rkdt.setup.FromConfiguration( config, K, splitter, NULL );
  return rkdt.AllNearestNeighbor( n_iter, n, k, init, NEIGHBORStask );
}; /** end FindNeighbors() */
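
/**
 *  Usage sketch (illustrative only): the neighbor search can also be invoked
 *  on its own before Compress(), e.g. with a random splitter and a
 *  configuration built from user parameters (mirroring LaunchHelper() below):
 *
 *    gofmm::Configuration<double> config( metric, n, m, k, s, stol, budget );
 *    auto NN_cblk = mpigofmm::FindNeighbors( K, rkdtsplitter, config, CommGOFMM );
 *    // NN_cblk is k-by-n in [STAR,CBLK] distribution and can be fed to Compress().
 */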




/**
 *  @brief Template of the compress routine: build the distributed GOFMM tree
 *         and the hierarchical low-rank approximation of K.
 */
template<typename SPLITTER, typename RKDTSPLITTER, typename T, typename SPDMATRIX>
mpitree::Tree<mpigofmm::Setup<SPDMATRIX, SPLITTER, T>, gofmm::NodeData<T>>
*Compress
(
  SPDMATRIX &K,
  DistData<STAR, CBLK, pair<T, size_t>> &NN_cblk,
  SPLITTER splitter,
  RKDTSPLITTER rkdtsplitter,
  gofmm::Configuration<T> &config,
  mpi::Comm CommGOFMM
)
{
  try
  {
    /** MPI size and rank. */
    int size; mpi::Comm_size( CommGOFMM, &size );
    int rank; mpi::Comm_rank( CommGOFMM, &rank );

    /** Get all user-defined parameters. */
    DistanceMetric metric = config.MetricType();
    size_t n = config.ProblemSize();
    size_t m = config.LeafNodeSize();
    size_t k = config.NeighborSize();
    size_t s = config.MaximumRank();

    /** Options. */
    const bool SYMMETRIC = true;
    const bool NNPRUNE = true;
    const bool CACHE = true;

    /** Instantiation for the GOFMM metric tree. */
    using SETUP = mpigofmm::Setup<SPDMATRIX, SPLITTER, T>;
    using DATA = gofmm::NodeData<T>;
    using TREE = mpitree::Tree<SETUP, DATA>;
    /** Derive type NODE and MPINODE from TREE. */
    using NODE = typename TREE::NODE;
    using MPINODE = typename TREE::MPINODE;

    /** All timers. */
    double beg, omptask45_time, omptask_time, ref_time;
    double time_ratio, compress_time = 0.0, other_time = 0.0;
    double ann_time, tree_time, skel_time, mpi_skel_time;
    double mergefarnodes_time = 0.0, cachefarnodes_time = 0.0;
    double local_skel_time, dist_skel_time, let_time;
    double nneval_time, nonneval_time, fmm_evaluation_time, symbolic_evaluation_time;
    double exchange_neighbor_time, symmetrize_time;

    /** Iterative all nearest-neighbor search (ANN). */
    beg = omp_get_wtime();
    if ( k && NN_cblk.row() * NN_cblk.col() != k * n )
    {
      NN_cblk = mpigofmm::FindNeighbors( K, rkdtsplitter,
          config, CommGOFMM );
    }
    ann_time = omp_get_wtime() - beg;

    /** Initialize the metric ball tree using approximate center split. */
    auto *tree_ptr = new TREE( CommGOFMM );
    auto &tree = *tree_ptr;

    /** Global configuration for the metric tree. */
    tree.setup.FromConfiguration( config, K, splitter, &NN_cblk );

    /** Metric ball tree partitioning. */
    beg = omp_get_wtime();
    tree.TreePartition();
    tree_time = omp_get_wtime() - beg;

    /** Get the tree permutation. */
    vector<size_t> perm = tree.GetPermutation();
    if ( rank == 0 )
    {
      ofstream perm_file( "perm.txt" );
      for ( auto &id : perm ) perm_file << id << " ";
      perm_file.close();
    }


    /** Redistribute neighbors, i.e. NN[ *, CIDS ] = NN[ *, CBLK ]. */
    DistData<STAR, CIDS, pair<T, size_t>> NN( k, n, tree.treelist[ 0 ]->gids, tree.GetComm() );
    NN = NN_cblk;
    tree.setup.NN = &NN;
    beg = omp_get_wtime();
    ExchangeNeighbors<T>( tree );
    exchange_neighbor_time = omp_get_wtime() - beg;


    beg = omp_get_wtime();
    /** Construct near interaction lists. */
    FindNearInteractions( tree );
    /** Symmetrize interaction pairs by Alltoallv. */
    mpigofmm::SymmetrizeNearInteractions( tree );
    /** Split node interaction lists per MPI rank. */
    BuildInteractionListPerRank( tree, true );
    /** Exchange {leafs} and {params(leafs)}. */
    ExchangeLET( tree, string( "leafgids" ) );
    symmetrize_time = omp_get_wtime() - beg;


    /** Find and merge far interactions. */
    mpi::PrintProgress( "[BEG] MergeFarNodes ...", tree.GetComm() );
    beg = omp_get_wtime();
    tree.DependencyCleanUp();
    MergeFarNodesTask<true, NODE, T> seqMERGEtask;
    DistMergeFarNodesTask<true, MPINODE, T> mpiMERGEtask;
    tree.LocaTraverseUp( seqMERGEtask );
    tree.DistTraverseUp( mpiMERGEtask );
    tree.ExecuteAllTasks();
    mergefarnodes_time += omp_get_wtime() - beg;
    mpi::PrintProgress( "[END] MergeFarNodes ...", tree.GetComm() );

    /** Symmetrize interaction pairs by Alltoallv. */
    beg = omp_get_wtime();
    mpigofmm::SymmetrizeFarInteractions( tree );
    /** Split node interaction lists per MPI rank. */
    BuildInteractionListPerRank( tree, false );
    symmetrize_time += omp_get_wtime() - beg;

    mpi::PrintProgress( "[BEG] Skeletonization ...", tree.GetComm() );
    /** Skeletonization. */
    beg = omp_get_wtime();
    tree.DependencyCleanUp();
    /** Gather sample rows and skeleton columns, then ID. */
    gofmm::SkeletonKIJTask<NNPRUNE, NODE, T> seqGETMTXtask;
    mpigofmm::DistSkeletonKIJTask<NNPRUNE, MPINODE, T> mpiGETMTXtask;
    mpigofmm::SkeletonizeTask<NODE, T> seqSKELtask;
    mpigofmm::DistSkeletonizeTask<MPINODE, T> mpiSKELtask;
    tree.LocaTraverseUp( seqGETMTXtask, seqSKELtask );
    //tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
    /** Compute the coefficient matrix of the ID. */
    gofmm::InterpolateTask<NODE> seqPROJtask;
    mpigofmm::InterpolateTask<MPINODE> mpiPROJtask;
    tree.LocaTraverseUnOrdered( seqPROJtask );
    //tree.DistTraverseUnOrdered( mpiPROJtask );

    /** Cache near KIJ interactions. */
    mpigofmm::CacheNearNodesTask<NNPRUNE, NODE> seqNEARKIJtask;
    //tree.LocaTraverseLeafs( seqNEARKIJtask );

    tree.ExecuteAllTasks();
    skel_time = omp_get_wtime() - beg;

    beg = omp_get_wtime();
    tree.DistTraverseUp( mpiGETMTXtask, mpiSKELtask );
    tree.DistTraverseUnOrdered( mpiPROJtask );
    tree.ExecuteAllTasks();
    mpi_skel_time = omp_get_wtime() - beg;
    mpi::PrintProgress( "[END] Skeletonization ...", tree.GetComm() );


    /** Exchange {skels} and {params(skels)}. */
    ExchangeLET( tree, string( "skelgids" ) );

    beg = omp_get_wtime();
    /** Cache near KIJ interactions. */
    //mpigofmm::CacheNearNodesTask<NNPRUNE, NODE> seqNEARKIJtask;
    //tree.LocaTraverseLeafs( seqNEARKIJtask );
    /** Cache far KIJ interactions. */
    mpigofmm::CacheFarNodesTask<NNPRUNE, NODE> seqFARKIJtask;
    mpigofmm::CacheFarNodesTask<NNPRUNE, MPINODE> mpiFARKIJtask;
    //tree.LocaTraverseUnOrdered( seqFARKIJtask );
    //tree.DistTraverseUnOrdered( mpiFARKIJtask );
    tree.ExecuteAllTasks();
    cachefarnodes_time = omp_get_wtime() - beg;


    /** Compute the ratio of exact evaluation. */
    auto ratio = NonCompressedRatio( tree );

    double exact_ratio = (double) m / n;

    if ( rank == 0 && REPORT_COMPRESS_STATUS )
    {
      compress_time += ann_time;
      compress_time += tree_time;
      compress_time += exchange_neighbor_time;
      compress_time += symmetrize_time;
      compress_time += skel_time;
      compress_time += mpi_skel_time;
      compress_time += mergefarnodes_time;
      compress_time += cachefarnodes_time;
      time_ratio = 100.0 / compress_time;
      printf( "========================================================\n");
      printf( "GOFMM compression phase\n" );
      printf( "========================================================\n");
      printf( "NeighborSearch ------------------------ %5.2lfs (%5.1lf%%)\n", ann_time, ann_time * time_ratio );
      printf( "TreePartitioning ---------------------- %5.2lfs (%5.1lf%%)\n", tree_time, tree_time * time_ratio );
      printf( "ExchangeNeighbors --------------------- %5.2lfs (%5.1lf%%)\n", exchange_neighbor_time, exchange_neighbor_time * time_ratio );
      printf( "MergeFarNodes ------------------------- %5.2lfs (%5.1lf%%)\n", mergefarnodes_time, mergefarnodes_time * time_ratio );
      printf( "Symmetrize ---------------------------- %5.2lfs (%5.1lf%%)\n", symmetrize_time, symmetrize_time * time_ratio );
      printf( "Skeletonization (HMLP runtime) -------- %5.2lfs (%5.1lf%%)\n", skel_time, skel_time * time_ratio );
      printf( "Skeletonization (MPI) ----------------- %5.2lfs (%5.1lf%%)\n", mpi_skel_time, mpi_skel_time * time_ratio );
      printf( "Cache KIJ ----------------------------- %5.2lfs (%5.1lf%%)\n", cachefarnodes_time, cachefarnodes_time * time_ratio );
      printf( "========================================================\n");
      printf( "%5.3lf%% and %5.3lf%% uncompressed --------- %5.2lfs (%5.1lf%%)\n",
          100 * ratio.first, 100 * ratio.second, compress_time, compress_time * time_ratio );
      printf( "========================================================\n\n");
    }

    /** Clean up all w/r dependencies on tree nodes. */
    tree_ptr->DependencyCleanUp();
    /** Global barrier to make sure all processes have completed. */
    mpi::Barrier( tree.GetComm() );

    return tree_ptr;
  }
  catch ( const exception & e )
  {
    cout << e.what() << endl;
    exit( 1 );
  }
}; /** end Compress() */
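
/**
 *  Usage sketch (illustrative only, mirroring LaunchHelper() below):
 *  Compress() returns a heap-allocated distributed tree that the caller owns:
 *
 *    auto *tree_ptr = mpigofmm::Compress( K, NN, splitter, rkdtsplitter, config, CommGOFMM );
 *    auto &tree = *tree_ptr;
 *    // ... Evaluate(), DistFactorize(), ...
 *    delete tree_ptr;
 */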



/**
 *  @brief Compute the squared error and the squared 2-norm of the exact
 *         values for one testing gid, i.e. the pair ( SSE, SSV ).
 */
template<typename TREE, typename T>
pair<T, T> ComputeError( TREE &tree, size_t gid, Data<T> potentials )
{
  int comm_rank; mpi::Comm_rank( tree.GetComm(), &comm_rank );
  int comm_size; mpi::Comm_size( tree.GetComm(), &comm_size );

  /** ( sum of squared errors, squared 2-norm of true values ). */
  pair<T, T> ret( 0, 0 );

  auto &K = *tree.setup.K;
  auto &w = *tree.setup.w;

  auto I = vector<size_t>( 1, gid );
  auto &J = tree.treelist[ 0 ]->gids;

  /** Bcast gid and its parameters to all MPI processes. */
  K.BcastIndices( I, gid % comm_size, tree.GetComm() );

  Data<T> Kab = K( I, J );

  auto loc_exact = potentials;
  auto glb_exact = potentials;

  xgemm( "N", "N", Kab.row(), w.col(), w.row(),
      1.0, Kab.data(), Kab.row(),
           w.data(), w.row(),
      0.0, loc_exact.data(), loc_exact.row() );
  //gemm::xgemm( (T)1.0, Kab, w, (T)0.0, loc_exact );

  /** Allreduce u( gid, : ) = K( gid, CBLK ) * w( RBLK, : ). */
  mpi::Allreduce( loc_exact.data(), glb_exact.data(),
      loc_exact.size(), MPI_SUM, tree.GetComm() );

  for ( uint64_t j = 0; j < w.col(); j ++ )
  {
    T exac = glb_exact[ j ];
    T pred = potentials[ j ];
    /** Accumulate SSE and squared 2-norm. */
    ret.first  += ( pred - exac ) * ( pred - exac );
    ret.second += exac * exac;
  }

  return ret;
}; /** end ComputeError() */
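
/**
 *  Note (illustrative): ComputeError() returns ( SSE, SSV ) for one testing
 *  gid, i.e. the squared error and the squared 2-norm of the exact row
 *  u( gid, : ). A caller forms the relative error per gid as
 *
 *    relative_error = sqrt( ret.first / ret.second ),
 *
 *  and may accumulate the pairs over many gids to report an F-norm style
 *  error sqrt( sum SSE / sum SSV ), as SelfTesting() does below.
 */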



/**
 *  @brief End-to-end self test: evaluate u ~ K * w, compare against exact
 *         rows, then exercise the distributed factorization.
 */
template<typename TREE>
void SelfTesting( TREE &tree, size_t ntest, size_t nrhs )
{
  /** Derive type T from TREE. */
  using T = typename TREE::T;
  /** MPI support. */
  int rank; mpi::Comm_rank( tree.GetComm(), &rank );
  int size; mpi::Comm_size( tree.GetComm(), &size );
  /** Size (length) of each right-hand side. */
  size_t n = tree.n;
  /** Shrink ntest if ntest > n. */
  if ( ntest > n ) ntest = n;
  /** all_rhs = [ 0, 1, ..., nrhs - 1 ]. */
  vector<size_t> all_rhs( nrhs );
  for ( size_t rhs = 0; rhs < nrhs; rhs ++ ) all_rhs[ rhs ] = rhs;

  //auto A = tree.CheckAllInteractions();

  /** Input and output in RIDS and RBLK. */
  DistData<RIDS, STAR, T> w_rids( n, nrhs, tree.treelist[ 0 ]->gids, tree.GetComm() );
  DistData<RBLK, STAR, T> u_rblk( n, nrhs, tree.GetComm() );
  /** Initialize with random N( 0, 1 ). */
  w_rids.randn();
  /** Evaluate u ~ K * w. */
  auto u_rids = mpigofmm::Evaluate<true>( tree, w_rids );
  /** Sanity check for INF and NAN. */
  assert( !u_rids.HasIllegalValue() );
  /** Redistribute potentials from RIDS to RBLK. */
  u_rblk = u_rids;
  /** Report elementwise and F-norm accuracy. */
  if ( rank == 0 )
  {
    printf( "========================================================\n");
    printf( "Accuracy report\n" );
    printf( "========================================================\n");
  }
  /** All statistics. */
  T nnerr_avg = 0.0, nonnerr_avg = 0.0, fmmerr_avg = 0.0;
  T sse_2norm = 0.0, ssv_2norm = 0.0;
  /** Loop over all testing gids and right-hand sides. */
  for ( size_t i = 0; i < ntest; i ++ )
  {
    size_t tar = i * n / ntest;
    Data<T> potentials( (size_t)1, nrhs );
    if ( rank == ( tar % size ) ) potentials = u_rblk( vector<size_t>( 1, tar ), all_rhs );
    /** Bcast potentials to all MPI processes. */
    mpi::Bcast( potentials.data(), nrhs, tar % size, tree.GetComm() );
    /** Compare potentials with the exact MATVEC. */
    auto sse_ssv = mpigofmm::ComputeError( tree, tar, potentials );
    /** Compute the element-wise relative 2-norm error. */
    auto fmmerr = sqrt( sse_ssv.first / sse_ssv.second );
    /** Accumulate the element-wise error. */
    fmmerr_avg += fmmerr;
    /** Accumulate SSE and SSV. */
    sse_2norm += sse_ssv.first;
    ssv_2norm += sse_ssv.second;
    /** Only print 10 values. */
    if ( i < 10 && rank == 0 )
    {
      printf( "gid %6lu, ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
          tar, 0.0, 0.0, fmmerr );
    }
  }
  if ( rank == 0 )
  {
    printf( "========================================================\n");
    printf( "Elementwise ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
        nnerr_avg / ntest , nonnerr_avg / ntest, fmmerr_avg / ntest );
    printf( "F-norm      ASKIT %3.1E, HODLR %3.1E, GOFMM %3.1E\n",
        0.0, 0.0, sqrt( sse_2norm / ssv_2norm ) );
    printf( "========================================================\n");
  }

  /** Factorization with regularization lambda. */
  T lambda = 10.0;
  mpigofmm::DistFactorize( tree, lambda );
  mpigofmm::ComputeError( tree, lambda, w_rids, u_rids );
}; /** end SelfTesting() */



/** @brief Instantiate the splitters here. */
template<typename SPDMATRIX>
void LaunchHelper( SPDMATRIX &K, gofmm::CommandLineHelper &cmd, mpi::Comm CommGOFMM )
{
  using T = typename SPDMATRIX::T;
  const int N_CHILDREN = 2;
  /** Use geometry-oblivious splitters. */
  using SPLITTER = mpigofmm::centersplit<SPDMATRIX, N_CHILDREN, T>;
  using RKDTSPLITTER = mpigofmm::randomsplit<SPDMATRIX, N_CHILDREN, T>;
  /** GOFMM tree splitter. */
  SPLITTER splitter( K );
  splitter.Kptr = &K;
  splitter.metric = cmd.metric;
  /** Randomized tree splitter. */
  RKDTSPLITTER rkdtsplitter( K );
  rkdtsplitter.Kptr = &K;
  rkdtsplitter.metric = cmd.metric;
  /** Create a configuration from all user-defined arguments. */
  gofmm::Configuration<T> config( cmd.metric,
      cmd.n, cmd.m, cmd.k, cmd.s, cmd.stol, cmd.budget );
  /** (Optional) provide neighbors; leave uninitialized otherwise. */
  DistData<STAR, CBLK, pair<T, size_t>> NN( 0, cmd.n, CommGOFMM );
  /** Compress matrix K. */
  auto *tree_ptr = mpigofmm::Compress( K, NN, splitter, rkdtsplitter, config, CommGOFMM );
  auto &tree = *tree_ptr;

  /** Examine accuracies. */
  mpigofmm::SelfTesting( tree, 100, cmd.nrhs );

  /** Delete tree_ptr. */
  delete tree_ptr;
}; /** end LaunchHelper() */
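
/**
 *  Driver sketch (illustrative only; the real executables live outside this
 *  header). A hypothetical main() would build an SPD matrix, parse the command
 *  line, and hand both to LaunchHelper(). Names such as hmlp_init(),
 *  hmlp_finalize(), SPDMatrix, and randspd() are assumptions borrowed from the
 *  shared-memory GOFMM drivers, not definitions made in this file:
 *
 *    int main( int argc, char *argv[] )
 *    {
 *      hmlp_init( &argc, &argv );
 *      gofmm::CommandLineHelper cmd( argc, argv );
 *      SPDMatrix<double> K( cmd.n, cmd.n );
 *      K.randspd( 0.0, 1.0 );
 *      mpigofmm::LaunchHelper( K, cmd, MPI_COMM_WORLD );
 *      hmlp_finalize();
 *      return 0;
 *    }
 */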


}; /** end namespace mpigofmm */
}; /** end namespace hmlp */

#endif /** define GOFMM_MPI_HPP */