1 |
|
/** |
2 |
|
* HMLP (High-Performance Machine Learning Primitives) |
3 |
|
* |
4 |
|
* Copyright (C) 2014-2017, The University of Texas at Austin |
5 |
|
* |
6 |
|
* This program is free software: you can redistribute it and/or modify |
7 |
|
* it under the terms of the GNU General Public License as published by |
8 |
|
* the Free Software Foundation, either version 3 of the License, or |
9 |
|
* (at your option) any later version. |
10 |
|
* |
11 |
|
* This program is distributed in the hope that it will be useful, |
12 |
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of |
13 |
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
14 |
|
* GNU General Public License for more details. |
15 |
|
* |
16 |
|
* You should have received a copy of the GNU General Public License |
17 |
|
* along with this program. If not, see the LICENSE file. |
18 |
|
* |
19 |
|
**/ |
20 |
|
|
21 |
|
|
22 |
|
|
23 |
|
#ifndef IGOFMM_HPP |
24 |
|
#define IGOFMM_HPP |
25 |
|
|
26 |
|
|
27 |
|
/** Use STL and HMLP namespaces. */ |
28 |
|
using namespace std; |
29 |
|
using namespace hmlp; |
30 |
|
|
31 |
|
|
32 |
|
|
33 |
|
|
34 |
|
namespace hmlp |
35 |
|
{ |
36 |
|
namespace gofmm |
37 |
|
{ |
38 |
|
|
39 |
|
|
40 |
|
/**
 *
 *  for each level
 *    for each alpha
 *
 *      if ( leaf )
 *
 *        LL' = Chol( Kaa )
 *        U   = inv( L ) * proj'   or
 *        U'  = proj * inv( L' )
 *        QR  = qr( U )            or
 *        LQ' = lq( U' )
 *
 *      else
 *
 *        LL' = Chol( I + k.R * C * k.R' )
 *
 *        if ( not root )
 *
 *          U   = inv( L ) * [ l.R       * proj'       or
 *                                 r.R ]
 *          U'  = proj' * [ l.R       * inv( L' )
 *                              r.R ]
 *          QR  = qr( U )          or
 *          LQ' = lq( U' )
 *
 **/
67 |
|
|
68 |
|
|
69 |
|
template<typename T> |
70 |
|
class Factor |
71 |
|
{ |
72 |
|
public: |
73 |
|
|
74 |
|
/** Default constructor: every member keeps its in-class initializer. */
Factor() {}
75 |
|
|
76 |
|
void SetupFactor |
77 |
|
( |
78 |
|
bool issymmetric, bool do_ulv_factorization, |
79 |
|
bool isleaf, bool isroot, |
80 |
|
/** n == nl + nr (left + right) */ |
81 |
|
size_t n, size_t nl, size_t nr, |
82 |
|
/** s <= sl + sr */ |
83 |
|
size_t s, size_t sl, size_t sr |
84 |
|
) |
85 |
|
{ |
86 |
|
this->issymmetric = issymmetric; |
87 |
|
this->do_ulv_factorization = do_ulv_factorization; |
88 |
|
this->isleaf = isleaf; |
89 |
|
this->isroot = isroot; |
90 |
|
this->n = n; this->nl = nl; this->nr = nr; |
91 |
|
this->s = s; this->sl = sl; this->sr = sr; |
92 |
|
}; |
93 |
|
|
94 |
|
void SetupFactor |
95 |
|
( |
96 |
|
bool issymmetric, bool do_ulv_factorization, |
97 |
|
bool isleaf, bool isroot, |
98 |
|
size_t n, size_t nl, size_t nr, |
99 |
|
size_t s, size_t sl, size_t sr, |
100 |
|
/** n-by-?; its rank depends on mu sibling */ |
101 |
|
Data<T> &U, |
102 |
|
/** ?-by-n; its rank depends on my sibling */ |
103 |
|
Data<T> &V |
104 |
|
) |
105 |
|
{ |
106 |
|
SetupFactor( issymmetric, do_ulv_factorization, |
107 |
|
isleaf, isroot, n, nl, nr, s, sl, sr ); |
108 |
|
}; |
109 |
|
|
110 |
|
bool DoULVFactorization() |
111 |
|
{ |
112 |
|
return do_ulv_factorization; |
113 |
|
}; |
114 |
|
|
115 |
|
/** @return the issymmetric flag recorded by SetupFactor(). */
bool IsSymmetric()
{
  return issymmetric;
}
116 |
|
|
117 |
|
|
118 |
|
|
119 |
|
|
120 |
|
|
121 |
|
void CheckCondition() |
122 |
|
{ |
123 |
|
assert( do_ulv_factorization && issymmetric ); |
124 |
|
T max_diag = 0.0; |
125 |
|
T min_diag = 0.0; |
126 |
|
|
127 |
|
for ( size_t i = 0; i < Z.row(); i ++ ) |
128 |
|
{ |
129 |
|
T abs_diag = std::abs( Z( i, i ) ); |
130 |
|
|
131 |
|
if ( !i ) |
132 |
|
{ |
133 |
|
max_diag = abs_diag; |
134 |
|
min_diag = abs_diag; |
135 |
|
} |
136 |
|
|
137 |
|
if ( abs_diag > max_diag ) max_diag = abs_diag; |
138 |
|
if ( abs_diag < min_diag ) min_diag = abs_diag; |
139 |
|
} |
140 |
|
|
141 |
|
printf( "condiditon( Z ): min_diag %3.1E max_diag %3.1E ratio %3.1E\n", |
142 |
|
min_diag, max_diag, min_diag / max_diag ); |
143 |
|
}; |
144 |
|
|
145 |
|
/**
 *  @brief Leaf factorization: copy Kaa into Z, compute its pivoted LU
 *         factorization in place and warn if the estimated 1-norm condition
 *         number exceeds 1E+6.
 *
 *  @param Kaa  n-by-n matrix (asserted); it is copied, not modified.
 *
 *  On return Z holds the output of xgetrf and ipiv the pivot indices.
 */
void Factorize( Data<T> &Kaa )
{
  assert( isleaf );
  assert( Kaa.row() == n ); assert( Kaa.col() == n );

  /** Initialize with Kaa. */
  Z = Kaa;

  /** Record the partial pivoting order. */
  ipiv.resize( n, 0 );

  /**
   *  Matrix 1-norm of Z (maximum absolute column sum). It must be taken
   *  before xgetrf overwrites Z, and xgecon expects exactly this norm of
   *  the original matrix as its "anorm" argument.
   */
  T nrm1 = 0.0;
  for ( size_t j = 0; j < Z.col(); j ++ )
  {
    T colsum = 0.0;
    for ( size_t i = 0; i < Z.row(); i ++ ) colsum += std::abs( Z( i, j ) );
    if ( colsum > nrm1 ) nrm1 = colsum;
  }

  /** Pivoted LU factorization. */
  xgetrf( n, n, Z.data(), n, ipiv.data() );

  /** Estimate the reciprocal 1-norm condition number from the LU factors. */
  T rcond1 = 0.0;
  Data<T> work( Z.row(), 4 );
  std::vector<int> iwork( Z.row() );
  xgecon( "1", Z.row(), Z.data(), Z.row(), nrm1,
      &rcond1, work.data(), iwork.data() );
  if ( 1.0 / rcond1 > 1E+6 )
    printf( "Warning! large 1-norm condition number %3.1E, nrm1( Z ) %3.1E\n",
        1.0 / rcond1, nrm1 );
} /** end Factorize() */
173 |
|
|
174 |
|
|
175 |
|
/** |
176 |
|
* Kaa = [ P [ L11 [ I [ U11 U12 |
177 |
|
* I ] L21 I ] C ] I ] |
178 |
|
*/ |
179 |
|
/**
 *  @brief Partial LU factorization of A after a change of basis.
 *
 *  Z receives a copy of A and is transformed by ChangeBasis( Z ). Z is then
 *  viewed as the 2-by-2 partition [ Ztl, Ztr; Zbl, Zbr ]. The top block row
 *  [ Ztl, Ztr ] is factorized with partial pivoting, Zbl is multiplied by
 *  the inverse of the upper-triangular factor stored in Ztl, and Zbr is
 *  overwritten with the Schur complement Zbr - Zbl * Ztr.
 *
 *  NOTE(review): the partition is requested with ( s, s, BOTTOMRIGHT ),
 *  presumably making Zbr the s-by-s block -- confirm against View<T>.
 */
void PartialFactorize( Data<T> &A )
{
  /** Similarity transformation of a copy of A. */
  Z = A;
  ChangeBasis( Z );

  /** Matrix views of Z. */
  Zv.Set( false, Z );
  Zv.Partition2x2( Ztl, Ztr,
                   Zbl, Zbr, s, s, BOTTOMRIGHT );

  /** One pivot index per row of the top block row. */
  ipiv.resize( Ztl.row(), 0 );

  /** [ Ztl, Ztr ] = P * L * U, computed in place. */
  xgetrf( Ztl.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );

  /** Zbl = Zbl * U^{-1}, with U the upper triangle of Ztl. */
  xtrsm( "Right", "Upper", "No transpose", "Non-unit",
      Zbl.row(), Zbl.col(),
      1.0, Ztl.data(), Ztl.ld(),
           Zbl.data(), Zbl.ld() );

  /** Schur complement: Zbr -= Zbl * Ztr. */
  xgemm( "No transpose", "No transpose",
      Zbr.row(), Zbr.col(), Ztl.col(),
      -1.0, Zbl.data(), Zbl.ld(),
            Ztr.data(), Ztr.ld(),
       1.0, Zbr.data(), Zbr.ld() );
} /** end PartialFactorize() */
207 |
|
|
208 |
|
|
209 |
|
|
210 |
|
|
211 |
|
/**
 *   two-sided UCVt                one-sided UBt
 *
 *      |  sl   sr                    |  sl   sr
 *   -------------                 -------------
 *   nl |  Ul                      nl |  Ul
 *   nr |       Ur                 nr |       Ur
 *
 *      |  sl   sr
 *   -------------
 *   sl |       Clr
 *   sr |  Crl
 *
 *      |  nl   nr                    |  nl   nr
 *   -------------                 -------------
 *   sl |  Vlt                     sl |       Brt
 *   sr |       Vrt                sr |  Blt
 *
 **/
231 |
|
void Factorize |
232 |
|
( |
233 |
|
/** Ul, nl-by-sl */ |
234 |
|
Data<T> &Ul, |
235 |
|
/** Ur, nr-by-sr */ |
236 |
|
Data<T> &Ur, |
237 |
|
/** Vl, nl-by-sr */ |
238 |
|
Data<T> &Vl, |
239 |
|
/** Vr, nr-by-sr */ |
240 |
|
Data<T> &Vr |
241 |
|
) |
242 |
|
{ |
243 |
|
assert( !isleaf ); |
244 |
|
//assert( Ul.row() == nl ); assert( Ul.col() == sl ); |
245 |
|
//assert( Ur.row() == nr ); assert( Ur.col() == sr ); |
246 |
|
//assert( Vl.row() == nl ); assert( Vl.col() == sl ); |
247 |
|
//assert( Vr.row() == nr ); assert( Vr.col() == sr ); |
248 |
|
|
249 |
|
/** even SYMMETRIC this routine uses LU factorization */ |
250 |
|
if ( issymmetric ) |
251 |
|
{ |
252 |
|
assert( Crl.row() == sr ); assert( Crl.col() == sl ); |
253 |
|
} |
254 |
|
else |
255 |
|
{ |
256 |
|
assert( Clr.row() == sl ); assert( Clr.col() == sr ); |
257 |
|
assert( Crl.row() == sr ); assert( Crl.col() == sl ); |
258 |
|
} |
259 |
|
|
260 |
|
/** |
261 |
|
* clean up and begin with Z = eye( sl + sr ) = | sl sr |
262 |
|
* ------------ |
263 |
|
* sl | Zrl Ztr |
264 |
|
* sr | Zbl Zbr |
265 |
|
**/ |
266 |
|
Z.resize( 0, 0 ); |
267 |
|
Z.resize( sl + sr, sl + sr, 0.0 ); |
268 |
|
for ( size_t i = 0; i < sl + sr; i ++ ) Z[ i * Z.row() + i ] = 1.0; |
269 |
|
|
270 |
|
|
271 |
|
|
272 |
|
if ( do_ulv_factorization ) |
273 |
|
{ |
274 |
|
/** |
275 |
|
* Z = I + UR * C * VR' = [ I URl * Clr * VRr' |
276 |
|
* URr * Crl * VRl' I ] |
277 |
|
**/ |
278 |
|
if ( issymmetric ) /** Cholesky */ |
279 |
|
{ |
280 |
|
/** Zbl = URr * Crl * VRl' */ |
281 |
|
hmlp::Data<T> Zbl = Crl; |
282 |
|
|
283 |
|
//printf( "Crl\n" ); |
284 |
|
//Crl.Print(); |
285 |
|
|
286 |
|
|
287 |
|
/** trmm */ |
288 |
|
xtrmm |
289 |
|
( |
290 |
|
"Right", "Upper", "Transpose", "Non-unit", |
291 |
|
Zbl.row(), Zbl.col(), |
292 |
|
1.0, Ul.data(), Ul.row(), |
293 |
|
Zbl.data(), Zbl.row() |
294 |
|
); |
295 |
|
//printf( "Ul.row() %lu Zbl.row() %lu Zbl.col() %lu\n", |
296 |
|
// Ul.row(), Zbl.row(), Zbl.col() ); |
297 |
|
|
298 |
|
/** trmm */ |
299 |
|
xtrmm |
300 |
|
( |
301 |
|
"Left", "Upper", "Non-transpose", "Non-unit", |
302 |
|
Zbl.row(), Zbl.col(), |
303 |
|
1.0, Ur.data(), Ur.row(), |
304 |
|
Zbl.data(), Zbl.row() |
305 |
|
); |
306 |
|
//printf( "Ur.row() %lu Zbl.row() %lu Zbl.col() %lu\n", |
307 |
|
// Ur.row(), Zbl.row(), Zbl.col() ); |
308 |
|
|
309 |
|
/** Zbl */ |
310 |
|
for ( size_t j = 0; j < sl; j ++ ) |
311 |
|
for ( size_t i = 0; i < sr; i ++ ) |
312 |
|
{ |
313 |
|
Z( sl + i, j ) = Zbl( i, j ); |
314 |
|
Z( j, sl + i ) = Zbl( i, j ); |
315 |
|
} |
316 |
|
|
317 |
|
/** LL' = potrf( Z ) */ |
318 |
|
if ( 1 ) |
319 |
|
{ |
320 |
|
xpotrf( "Lower", Z.row(), Z.data(), Z.row() ); |
321 |
|
//CheckCondition(); |
322 |
|
} |
323 |
|
else |
324 |
|
{ |
325 |
|
/** pivoting row indices */ |
326 |
|
ipiv.resize( Z.row(), 0 ); |
327 |
|
xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() ); |
328 |
|
} |
329 |
|
} |
330 |
|
else /** LU */ |
331 |
|
{ |
332 |
|
/** pivoting row indices */ |
333 |
|
ipiv.resize( Z.row(), 0 ); |
334 |
|
} |
335 |
|
} |
336 |
|
else /** Sherman-Morrison-Woodbury */ |
337 |
|
{ |
338 |
|
/** pivoting row indices */ |
339 |
|
ipiv.resize( Z.row(), 0 ); |
340 |
|
|
341 |
|
/** |
342 |
|
* Z = I + CVtU = [ I ClrVrtUr |
343 |
|
* CrlVltUl I ] |
344 |
|
**/ |
345 |
|
std::vector<T> VltUl( sl * sl, 0.0 ); |
346 |
|
std::vector<T> VrtUr( sr * sr, 0.0 ); |
347 |
|
|
348 |
|
/** VltUl */ |
349 |
|
xgemm( "T", "N", sl, sl, nl, |
350 |
|
1.0, Vl.data(), nl, |
351 |
|
Ul.data(), nl, |
352 |
|
0.0, VltUl.data(), sl ); |
353 |
|
|
354 |
|
/** VrtUr */ |
355 |
|
xgemm( "T", "N", sr, sr, nr, |
356 |
|
1.0, Vr.data(), nr, |
357 |
|
Ur.data(), nr, |
358 |
|
0.0, VrtUr.data(), sr ); |
359 |
|
|
360 |
|
/** CrlVltUl */ |
361 |
|
xgemm( "N", "N", sr, sl, sl, |
362 |
|
1.0, Crl.data(), sr, |
363 |
|
VltUl.data(), sl, |
364 |
|
0.0, Z.data() + sl, sl + sr ); |
365 |
|
|
366 |
|
|
367 |
|
if ( issymmetric ) |
368 |
|
{ |
369 |
|
/** Crl'VrtUr */ |
370 |
|
xgemm( "T", "N", sl, sr, sr, |
371 |
|
1.0, Crl.data(), sr, |
372 |
|
VrtUr.data(), sr, |
373 |
|
0.0, Z.data() + ( sl + sr ) * sl, sl + sr ); |
374 |
|
} |
375 |
|
else |
376 |
|
{ |
377 |
|
printf( "bug\n" ); exit( 1 ); |
378 |
|
/** ClrVrtUr */ |
379 |
|
xgemm( "N", "N", sl, sr, sr, |
380 |
|
1.0, Clr.data(), sl, |
381 |
|
VrtUr.data(), sr, |
382 |
|
0.0, Z.data() + ( sl + sr ) * sl, sl + sr ); |
383 |
|
} |
384 |
|
|
385 |
|
/** compute 1-norm of Z */ |
386 |
|
T nrm1 = 0.0; |
387 |
|
for ( size_t i = 0; i < Z.size(); i ++ ) |
388 |
|
nrm1 += std::abs( Z[ i ] ); |
389 |
|
|
390 |
|
/** LU factorization */ |
391 |
|
xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() ); |
392 |
|
|
393 |
|
/** record points of children factors */ |
394 |
|
this->Ul = &Ul; |
395 |
|
this->Ur = &Ur; |
396 |
|
this->Vl = &Vl; |
397 |
|
this->Vr = &Vr; |
398 |
|
|
399 |
|
/** compute 1-norm condition number */ |
400 |
|
T rcond1 = 0.0; |
401 |
|
hmlp::Data<T> work( Z.row(), 4 ); |
402 |
|
std::vector<int> iwork( Z.row() ); |
403 |
|
xgecon( "1", Z.row(), Z.data(), Z.row(), nrm1, |
404 |
|
&rcond1, work.data(), iwork.data() ); |
405 |
|
if ( 1.0 / rcond1 > 1E+6 ) |
406 |
|
printf( "Warning! large 1-norm condition number %3.1E\n", |
407 |
|
1.0 / rcond1 ); fflush( stdout ); |
408 |
|
} |
409 |
|
|
410 |
|
}; /** end Factorize() */ |
411 |
|
|
412 |
|
|
413 |
|
/**
 *  @brief Assemble the (sl+sr)-by-(sl+sr) system of an internal node from
 *         the children's blocks and hand it to PartialFactorize( Z ).
 *
 *  Z is zeroed and partitioned with an sl-by-sl top-left block:
 *    Ztl <- Zl, Zbr <- Zr,
 *    Zbl <- Ur * Crl * Ul'  (upper triangles of Ur and Ul are referenced),
 *    Ztr <- Zbl'.
 *
 *  @param Zl, Zr  blocks copied to the diagonal of Z
 *  @param Ul, Ur  triangular factors applied to Crl
 *  @param Vl, Vr  not referenced by this routine
 */
void PartialFactorize(
    View<T> &Zl, View<T> &Zr,
    Data<T> &Ul, Data<T> &Ur,
    Data<T> &Vl, Data<T> &Vr )
{
  /** Start from an all-zero Z. */
  Z.resize( 0, 0 );
  Z.resize( sl + sr, sl + sr, 0.0 );

  /** Matrix views of Z. */
  Zv.Set( false, Z );
  Zv.Partition2x2( Ztl, Ztr,
                   Zbl, Zbr, sl, sl, TOPLEFT );

  /** Zbl = Crl, then Zbl = Zbl * Ul', then Zbl = Ur * Zbl. */
  Zbl.CopyValuesFrom( Crl );
  xtrmm( "Right", "Upper", "Transpose", "Non-unit",
      Zbl.row(), Zbl.col(),
      1.0, Ul.data(), Ul.row(),
           Zbl.data(), Zbl.ld() );
  xtrmm( "Left", "Upper", "Non-transpose", "Non-unit",
      Zbl.row(), Zbl.col(),
      1.0, Ur.data(), Ur.row(),
           Zbl.data(), Zbl.ld() );

  /** Diagonal blocks come from the children. */
  Ztl.CopyValuesFrom( Zl );
  Zbr.CopyValuesFrom( Zr );

  /** Ztr = Zbl'. */
  for ( size_t col = 0; col < sl; col ++ )
    for ( size_t row = 0; row < sr; row ++ )
      Ztr( col, row ) = Zbl( row, col );

  /** Z is passed as its own input; the callee re-partitions it. */
  PartialFactorize( Z );
} /** end PartialFactorize() */
451 |
|
|
452 |
|
|
453 |
|
|
454 |
|
|
455 |
|
/** */ |
456 |
|
/**
 *  @brief Update [ bl; br ] in place with the low-rank coupling of this
 *         internal node:
 *
 *           t  = [ Crl' * Vr' * br; Crl * Vl' * bl ]
 *           bl = bl - Ul * t( 0:sl-1, : )
 *           br = br - Ur * t( sl:sl+sr-1, : )
 *
 *  The non-symmetric case prints a message and terminates the process.
 *
 *  NOTE(review): the update is applied with alpha = -1.0 although the
 *  original comments read "bl += Ul * xl" -- confirm the intended sign.
 */
void Multiply( View<T> &bl, View<T> &br )
{
  assert( !isleaf && bl.col() == br.col() );

  size_t nrhs = bl.col();
  size_t ldt  = sl + sr;

  /** Temporary buffers: stacked result and the two projections. */
  std::vector<T> ta( ldt * nrhs );
  std::vector<T> tl( sl * nrhs );
  std::vector<T> tr( sr * nrhs );

  /** tl = Vl' * bl */
  xgemm( "T", "N", sl, nrhs, nl,
      1.0, Vl->data(), nl,
           bl.data(), bl.ld(),
      0.0, tl.data(), sl );

  /** tr = Vr' * br */
  xgemm( "T", "N", sr, nrhs, nr,
      1.0, Vr->data(), nr,
           br.data(), br.ld(),
      0.0, tr.data(), sr );

  /** Bottom part of ta = Crl * tl */
  xgemm( "N", "N", sr, nrhs, sl,
      1.0, Crl.data(), sr,
           tl.data(), sl,
      0.0, ta.data() + sl, ldt );

  if ( issymmetric )
  {
    /** Top part of ta = Crl' * tr */
    xgemm( "T", "N", sl, nrhs, sr,
        1.0, Crl.data(), sr,
             tr.data(), sr,
        0.0, ta.data(), ldt );
  }
  else
  {
    printf( "bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
    /** Top part of ta = Clr * tr (unreachable). */
    xgemm( "N", "N", sl, nrhs, sr,
        1.0, Clr.data(), sl,
             tr.data(), sr,
        0.0, ta.data(), ldt );
  }

  /** bl -= Ul * ta( 0:sl-1, : ) */
  xgemm( "N", "N", nl, nrhs, sl,
      -1.0, Ul->data(), nl,
            ta.data(), ldt,
       1.0, bl.data(), bl.ld() );

  /** br -= Ur * ta( sl:sl+sr-1, : ) */
  xgemm( "N", "N", nr, nrhs, sr,
      -1.0, Ur->data(), nr,
            ta.data() + sl, ldt,
       1.0, br.data(), br.ld() );
}
513 |
|
|
514 |
|
/** |
515 |
|
* @brief Solver for leaf nodes |
516 |
|
*/ |
517 |
|
/**
 *  @brief Leaf solve: overwrite rhs with the solution of the linear system
 *         whose LU factors and pivots are stored in Z and ipiv.
 *
 *  Only valid for leaf nodes outside the ULV mode (both asserted).
 */
void Solve( View<T> &rhs )
{
  /** Preconditions. */
  assert( isleaf );
  assert( !do_ulv_factorization );
  assert( rhs.data() && Z.data() );
  assert( ipiv.data() );

  /** LU solver, applied to all columns of rhs at once. */
  xgetrs( "Non-transpose", rhs.row(), rhs.col(),
      Z.data(), Z.row(), ipiv.data(),
      rhs.data(), rhs.ld() );
} /** end Solve() */
535 |
|
|
536 |
|
|
537 |
|
|
538 |
|
/** |
539 |
|
* @brief b - U * inv( Z ) * C * V' * b |
540 |
|
*/ |
541 |
|
/**
 *  @brief Internal-node solve, b <- b - U * inv( Z ) * C * V' * b, applied
 *         in place to the stacked right-hand side [ bl; br ].
 *
 *  Requires the non-ULV mode and the children's bases recorded by
 *  Factorize() (both asserted). The non-symmetric case prints a message
 *  and terminates the process.
 *
 *  @param bl  nl-by-nrhs top part of the right-hand side
 *  @param br  nr-by-nrhs bottom part of the right-hand side
 */
void Solve( View<T> &bl, View<T> &br )
{
  size_t nrhs = bl.col();

  /** Preconditions. */
  assert( !do_ulv_factorization );
  assert( bl.col() == br.col() );
  assert( bl.row() == nl );
  assert( br.row() == nr );
  assert( Ul && Ur && Vl && Vr );

  /** Temporary buffers: stacked system vector and the two projections. */
  size_t ldt = sl + sr;
  vector<T> ta( ldt * nrhs );
  vector<T> tl( sl * nrhs );
  vector<T> tr( sr * nrhs );

  /** tl = Vl' * bl */
  xgemm( "T", "N", sl, nrhs, nl,
      1.0, Vl->data(), nl,
           bl.data(), bl.ld(),
      0.0, tl.data(), sl );

  /** tr = Vr' * br */
  xgemm( "T", "N", sr, nrhs, nr,
      1.0, Vr->data(), nr,
           br.data(), br.ld(),
      0.0, tr.data(), sr );

  /** Bottom part of ta = Crl * tl */
  xgemm( "N", "N", sr, nrhs, sl,
      1.0, Crl.data(), sr,
           tl.data(), sl,
      0.0, ta.data() + sl, ldt );

  if ( issymmetric )
  {
    /** Top part of ta = Crl' * tr */
    xgemm( "T", "N", sl, nrhs, sr,
        1.0, Crl.data(), sr,
             tr.data(), sr,
        0.0, ta.data(), ldt );
  }
  else
  {
    printf( "bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
    /** Top part of ta = Clr * tr (unreachable). */
    xgemm( "N", "N", sl, nrhs, sr,
        1.0, Clr.data(), sl,
             tr.data(), sr,
        0.0, ta.data(), ldt );
  }

  /** ta = inv( Z ) * ta */
  xgetrs( "N", sl + sr, nrhs,
      Z.data(), Z.row(), ipiv.data(),
      ta.data(), ldt );

  /** bl -= Ul * ta( 0:sl-1, : ) */
  xgemm( "N", "N", nl, nrhs, sl,
      -1.0, Ul->data(), nl,
            ta.data(), ldt,
       1.0, bl.data(), bl.ld() );

  /** br -= Ur * ta( sl:sl+sr-1, : ) */
  xgemm( "N", "N", nr, nrhs, sr,
      -1.0, Ur->data(), nr,
            ta.data() + sl, ldt,
       1.0, br.data(), br.ld() );
} /** end Solve() */
630 |
|
|
631 |
|
|
632 |
|
|
633 |
|
|
634 |
|
void Telescope |
635 |
|
( |
636 |
|
bool DO_INVERSE, |
637 |
|
/** n-by-s */ |
638 |
|
Data<T> &Pa, |
639 |
|
/** s-by-(sl+sr) */ |
640 |
|
Data<T> &Palr |
641 |
|
) |
642 |
|
{ |
643 |
|
assert( isleaf ); |
644 |
|
/** Initialize Pa */ |
645 |
|
Pa.resize( n, s, 0.0 ); |
646 |
|
|
647 |
|
/** create view and subviews for Pa */ |
648 |
|
//hmlp::View<T> Xa; |
649 |
|
|
650 |
|
//Xa.Set( Pa ); |
651 |
|
|
652 |
|
assert( Palr.row() == s ); assert( Palr.col() == n ); |
653 |
|
|
654 |
|
/** Pa = Palr' */ |
655 |
|
for ( size_t j = 0; j < Pa.col(); j ++ ) |
656 |
|
for ( size_t i = 0; i < Pa.row(); i ++ ) |
657 |
|
Pa( i, j ) = Palr( j, i ); |
658 |
|
|
659 |
|
if ( DO_INVERSE ) |
660 |
|
{ |
661 |
|
if ( do_ulv_factorization ) |
662 |
|
{ |
663 |
|
xtrsm( "Left", "Lower", "No transpose", "Non-unit", |
664 |
|
Pa.row(), Pa.col(), |
665 |
|
1.0, Z.data(), Z.row(), Pa.data(), Pa.row() ); |
666 |
|
} |
667 |
|
else |
668 |
|
{ |
669 |
|
assert( ipiv.size() ); |
670 |
|
/** LU solver */ |
671 |
|
xgetrs( "Non-transpose", |
672 |
|
n, s, Z.data(), n, ipiv.data(), Pa.data(), n ); |
673 |
|
} |
674 |
|
} |
675 |
|
|
676 |
|
|
677 |
|
//printf( "call solve from telescope\n" ); fflush( stdout ); |
678 |
|
//if ( DO_INVERSE ) Solve<true>( Xa ); |
679 |
|
//printf( "call solve from telescope (exist)\n" ); fflush( stdout ); |
680 |
|
|
681 |
|
}; /** end Telescope() */ |
682 |
|
|
683 |
|
|
684 |
|
/** RIGHT: V = [ P(:, 0:st-1) * Vl , P(:,st:st+sb-1) * Vr ] |
685 |
|
* LEFT: U = [ Ul * P(:, 0:st-1)'; Ur * P(:,st:st+sb-1) ] */ |
686 |
|
void Telescope |
687 |
|
( |
688 |
|
bool DO_INVERSE, |
689 |
|
/** n-by-s */ |
690 |
|
Data<T> &Pa, |
691 |
|
/** s-by-(sl+sr) */ |
692 |
|
Data<T> &Palr, |
693 |
|
/** nl-by-sl */ |
694 |
|
Data<T> &Pl, |
695 |
|
/** nr-by-sr */ |
696 |
|
Data<T> &Pr |
697 |
|
) |
698 |
|
{ |
699 |
|
assert( !isleaf ); |
700 |
|
assert( n == nl + nr ); |
701 |
|
assert( Pl.col() == sl ); |
702 |
|
assert( Pr.col() == sr ); |
703 |
|
assert( Palr.row() == s ); assert( Palr.col() == ( sl + sr ) ); |
704 |
|
|
705 |
|
/** Initialize Pa */ |
706 |
|
Pa.resize( 0, 0 ); |
707 |
|
|
708 |
|
/** create view and subviews for Pa */ |
709 |
|
//hmlp::View<T> Xa; |
710 |
|
|
711 |
|
//Xa.Set( Pa ); |
712 |
|
//assert( Xa.row() == Pa.row() ); |
713 |
|
//assert( Xa.col() == Pa.col() ); |
714 |
|
|
715 |
|
if ( do_ulv_factorization ) |
716 |
|
{ |
717 |
|
Pa.resize( sl + sr, s, 0.0 ); |
718 |
|
|
719 |
|
/** Pa = Palr' */ |
720 |
|
for ( size_t j = 0; j < Pa.col(); j ++ ) |
721 |
|
for ( size_t i = 0; i < Pa.row(); i ++ ) |
722 |
|
Pa[ j * Pa.row() + i ] = Palr[ i * Palr.row() + j ]; |
723 |
|
|
724 |
|
/** Pa( 0:sl-1, : ) = Pl * Palr( :, 0:sl-1 )' */ |
725 |
|
xtrmm( "Left", "Upper", "No Transpose", "Non-unit", sl, s, |
726 |
|
1.0, Pl.data(), Pl.row(), |
727 |
|
Pa.data(), Pa.row() ); |
728 |
|
// printf( "Pl.row() %lu Pa.row() %lu Pa.col() %lu\n", |
729 |
|
// Pl.row(), Pa.row(), Pa.col() ); |
730 |
|
|
731 |
|
/** Pa( sl:sl+sr-1, : ) = Pr * Palr( :, sl:sl+sr-1 )' */ |
732 |
|
xtrmm( "Left", "Upper", "No Transpose", "Non-unit", sr, s, |
733 |
|
1.0, Pr.data() , Pr.row(), |
734 |
|
Pa.data() + sl, Pa.row() ); |
735 |
|
// printf( "Pr.row() %lu Pa.row() %lu Pa.col() %lu\n", |
736 |
|
// Pr.row(), Pa.row(), Pa.col() ); |
737 |
|
|
738 |
|
/** inv( L ) * Pa */ |
739 |
|
if ( DO_INVERSE ) |
740 |
|
{ |
741 |
|
if ( 1 ) |
742 |
|
{ |
743 |
|
xtrsm( "Left", "Lower", "No transpose", "Non-unit", |
744 |
|
Pa.row(), Pa.col(), |
745 |
|
1.0, Z.data(), Z.row(), |
746 |
|
Pa.data(), Pa.row() ); |
747 |
|
} |
748 |
|
else |
749 |
|
{ |
750 |
|
xlaswp( Pa.col(), Pa.data(), Pa.row(), |
751 |
|
1, Pa.row(), ipiv.data(), 1 ); |
752 |
|
xtrsm( "Left", "Lower", "No transpose", "Unit", Pa.row(), Pa.col(), |
753 |
|
1.0, Z.data(), Z.row(), |
754 |
|
Pa.data(), Pa.row() ); |
755 |
|
} |
756 |
|
//printf( "Z.row() %lu Z.col() %lu\n", Z.row(), Z.col() ); |
757 |
|
} |
758 |
|
|
759 |
|
} |
760 |
|
else /** Shernman-Morrison-Woodbury */ |
761 |
|
{ |
762 |
|
Pa.resize( nl + nr, s, 0.0 ); |
763 |
|
|
764 |
|
///** */ |
765 |
|
//hmlp::View<T> Xl, Xr; |
766 |
|
|
767 |
|
///** Xa = [ Xl; Xr; ] */ |
768 |
|
//Xa.Partition2x1 |
769 |
|
//( |
770 |
|
// Xl, |
771 |
|
// Xr, nl |
772 |
|
//); |
773 |
|
|
774 |
|
//assert( Xl.row() == nl ); |
775 |
|
//assert( Xr.row() == nr ); |
776 |
|
//assert( Xl.col() == s ); |
777 |
|
//assert( Xr.col() == s ); |
778 |
|
|
779 |
|
|
780 |
|
/** Pa( 0:nl-1, : ) = Pl * Palr( :, 0:sl-1 )' */ |
781 |
|
xgemm( "N", "T", nl, s, sl, |
782 |
|
1.0, Pl.data(), nl, |
783 |
|
Palr.data(), s, |
784 |
|
0.0, Pa.data(), n ); |
785 |
|
/** Pa( nl:n-1, : ) = Pr * Palr( :, sl:sl+sr-1 )' */ |
786 |
|
xgemm( "N", "T", nr, s, sr, |
787 |
|
1.0, Pr.data(), nr, |
788 |
|
Palr.data() + s * sl, s, |
789 |
|
0.0, Pa.data() + nl, n ); |
790 |
|
|
791 |
|
|
792 |
|
|
793 |
|
//if ( DO_INVERSE ) Solve<true>( Xl, Xr ); |
794 |
|
//printf( "end inner solve from telescope\n" ); fflush( stdout ); |
795 |
|
|
796 |
|
if ( DO_INVERSE ) |
797 |
|
{ |
798 |
|
Data<T> x( sl + sr, s ); |
799 |
|
Data<T> xl( sl, s ); |
800 |
|
Data<T> xr( sr, s ); |
801 |
|
|
802 |
|
/** xl = Vlt * Pa( 0:nl-1, : ) */ |
803 |
|
xgemm( "T", "N", sl, s, nl, |
804 |
|
1.0, Vl->data(), nl, |
805 |
|
Pa.data(), n, |
806 |
|
0.0, xl.data(), sl ); |
807 |
|
/** xr = Vrt * Pa( nl:n-1, : ) */ |
808 |
|
xgemm( "T", "N", sr, s, nr, |
809 |
|
1.0, Vr->data(), nr, |
810 |
|
Pa.data() + nl, n, |
811 |
|
0.0, xr.data(), sr ); |
812 |
|
|
813 |
|
/** b = [ Crl' * xr; |
814 |
|
* Crl * xl; ] */ |
815 |
|
xgemm( "T", "N", sl, s, sr, |
816 |
|
1.0, Crl.data(), sr, |
817 |
|
xr.data(), sr, |
818 |
|
0.0, x.data(), sl + sr ); |
819 |
|
xgemm( "N", "N", sr, s, sl, |
820 |
|
1.0, Crl.data(), sr, |
821 |
|
xl.data(), sl, |
822 |
|
0.0, x.data() + sl, sl + sr ); |
823 |
|
|
824 |
|
/** b = inv( Z ) * b */ |
825 |
|
xgetrs( "N", x.row(), x.col(), |
826 |
|
Z.data(), Z.row(), ipiv.data(), |
827 |
|
x.data(), x.row() ); |
828 |
|
|
829 |
|
/** Pa( 0:nl-1, : ) -= Ul * b( 0:sl-1, : ) */ |
830 |
|
xgemm( "N", "N", nl, s, sl, |
831 |
|
-1.0, Ul->data(), nl, |
832 |
|
x.data(), sl + sr, |
833 |
|
1.0, Pa.data(), n ); |
834 |
|
/** Pa( nl:n-1, : ) -= Ur * b( sl:sl+sr-1, : ) */ |
835 |
|
xgemm( "N", "N", nr, s, sr, |
836 |
|
-1.0, Ur->data(), nr, |
837 |
|
x.data() + sl, sl + sr, |
838 |
|
1.0, Pa.data() + nl, n ); |
839 |
|
} /** end if ( DO_INVERSE ) */ |
840 |
|
} /** end if ( do_ulv_factorization )*/ |
841 |
|
}; |
842 |
|
|
843 |
|
/** */ |
844 |
|
void Orthogonalization() |
845 |
|
{ |
846 |
|
/** Initialize householder reflectors "tau". */ |
847 |
|
tau.resize( std::min( U.row(), U.col() ) ); |
848 |
|
/** Initialize work space for xgeqrf. */ |
849 |
|
Data<T> work( U.col() * 512, 1 ); |
850 |
|
/** QR factorization without column pivoting. */ |
851 |
|
xgeqrf( U.row(), U.col(), U.data(), U.row(), |
852 |
|
tau.data(), work.data(), work.size() ); |
853 |
|
/** Copy U to Q to generate the full orthonormal basis. */ |
854 |
|
Q = U; |
855 |
|
/** Increase the rank of Q to full rank. */ |
856 |
|
Q.resize( U.row(), U.row() ); |
857 |
|
/** Generate the full orthonormal basis Q. */ |
858 |
|
xorgqr( Q.row(), Q.col(), U.col(), Q.data(), Q.row(), tau.data(), |
859 |
|
work.data(), work.size() ); |
860 |
|
|
861 |
|
|
862 |
|
|
863 |
|
/** Create views Qv = [Q1, Q2] for Q. */ |
864 |
|
Qv.Set( false, Q ); |
865 |
|
Qv.Partition1x2( Q1, Q2, tau.size(), LEFT ); |
866 |
|
/** Sanity check for Q1'Q1 and Q2'Q2 and Q1'Q2. */ |
867 |
|
Data<T> C = Q; |
868 |
|
Data<T> D = Q; |
869 |
|
|
870 |
|
xgemm( "Transpose", "No Transpose", C.row(), C.col(), Q.row(), |
871 |
|
1.0, Q.data(), Q.row(), |
872 |
|
Q.data(), Q.row(), |
873 |
|
0.0, C.data(), C.row() ); |
874 |
|
|
875 |
|
xgemm( "No Transpose", "Transpose", D.row(), D.col(), Q.row(), |
876 |
|
1.0, Q.data(), Q.row(), |
877 |
|
Q.data(), Q.row(), |
878 |
|
0.0, D.data(), D.row() ); |
879 |
|
|
880 |
|
for ( size_t j = 0; j < Q.col(); j ++ ) |
881 |
|
{ |
882 |
|
for ( size_t i = 0; i < Q.row(); i ++ ) |
883 |
|
{ |
884 |
|
if ( i == j ) assert( std::fabs( C( i, j ) - 1 ) < 1E-5 ); |
885 |
|
else assert( std::fabs( C( i, j ) - 0 ) < 1E-5 ); |
886 |
|
} |
887 |
|
} |
888 |
|
for ( size_t j = 0; j < Q.col(); j ++ ) |
889 |
|
{ |
890 |
|
for ( size_t i = 0; i < Q.row(); i ++ ) |
891 |
|
{ |
892 |
|
if ( i == j ) assert( std::fabs( D( i, j ) - 1 ) < 1E-5 ); |
893 |
|
else assert( std::fabs( D( i, j ) - 0 ) < 1E-5 ); |
894 |
|
} |
895 |
|
} |
896 |
|
|
897 |
|
}; |
898 |
|
|
899 |
|
|
900 |
|
|
901 |
|
|
902 |
|
/** [Q2 Q1]' * B or B * [Q2 Q1] */ |
903 |
|
/**
 *  @brief One-sided change of basis with the reordered basis [ Q2, Q1 ]:
 *         B <- [ Q2, Q1 ]' * B (LEFT) or B <- B * [ Q2, Q1 ] (RIGHT).
 *
 *  Does nothing when Q is empty. B is overwritten in place; a deep copy of
 *  its previous content serves as the source operand.
 *
 *  @param side  LEFT or RIGHT
 *  @param B     matrix to transform
 *  @throws const char*  if side is neither LEFT nor RIGHT
 */
void ChangeBasis( SideType side, Data<T> &B )
{
  /** Early return if Q does not exist. */
  if ( !Q.size() ) return;

  /** Deep copy of B: the source of the products below. */
  Data<T> src = B;

  /** Matrix views of the source and of the destination. */
  View<T> srcv( false, src );
  View<T> dstv( false, B );
  View<T> Bl, Br, Bt, Bb;

  if ( side == LEFT )
  {
    /** Partition the destination as [ Bt; Bb ]; Bt gets Q2.col() rows. */
    dstv.Partition2x1( Bt,
                       Bb, Q2.col(), TOP );

    /** Bt = Q2' * src */
    xgemm( "Transpose", "No Transpose", Bt.row(), Bt.col(), Q2.row(),
        1.0, Q2.data(), Q2.ld(),
             srcv.data(), srcv.ld(),
        0.0, Bt.data(), Bt.ld() );

    /** Bb = Q1' * src */
    xgemm( "Transpose", "No Transpose", Bb.row(), Bb.col(), Q1.row(),
        1.0, Q1.data(), Q1.ld(),
             srcv.data(), srcv.ld(),
        0.0, Bb.data(), Bb.ld() );
  }
  else if ( side == RIGHT )
  {
    /** Partition the destination as [ Bl, Br ]; Bl gets Q2.col() columns. */
    dstv.Partition1x2( Bl, Br, Q2.col(), LEFT );

    /** Bl = src * Q2 */
    xgemm( "No Transpose", "No Transpose", Bl.row(), Bl.col(), Q2.row(),
        1.0, srcv.data(), srcv.ld(),
             Q2.data(), Q2.ld(),
        0.0, Bl.data(), Bl.ld() );

    /** Br = src * Q1 */
    xgemm( "No Transpose", "No Transpose", Br.row(), Br.col(), Q1.row(),
        1.0, srcv.data(), srcv.ld(),
             Q1.data(), Q1.ld(),
        0.0, Br.data(), Br.ld() );
  }
  else
  {
    /** Do nothing and throw an exception. */
    throw "Value of (SideType) side is not recognized.";
  }
} /** end ChangeBasis() */
967 |
|
|
968 |
|
|
969 |
|
/** [Q2 Q1]' * A * [Q2 Q1] */ |
970 |
|
/**
 *  @brief Two-sided change of basis, A <- [ Q2, Q1 ]' * A * [ Q2, Q1 ],
 *         performed as a LEFT transformation followed by a RIGHT one.
 */
void ChangeBasis( Data<T> &A )
{
  ChangeBasis( LEFT,  A );
  ChangeBasis( RIGHT, A );
} /** end ChangeBasis() */
975 |
|
|
976 |
|
|
977 |
|
|
978 |
|
void ULVForward() |
979 |
|
{ |
980 |
|
/** For internal nodes, B has been initialized by children. */ |
981 |
|
if ( isleaf ) B = bview.toData(); |
982 |
|
/** B = Q' * B */ |
983 |
|
ChangeBasis( LEFT, B ); |
984 |
|
/** P * Bf */ |
985 |
|
xlaswp( Bf.col(), Bf.data(), Bf.ld(), 1, Bf.row(), ipiv.data(), 1 ); |
986 |
|
/** Lff^{-1} * P * Bf, where Lff is the lower-triangular part of Ztl. */ |
987 |
|
xtrsm( "Left", "Lower", "No transpose", "Unit", Bf.row(), Bf.col(), |
988 |
|
1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() ); |
989 |
|
/** Bc -= Lcf * Bf, where Lcf is Zbl. */ |
990 |
|
xgemm( "No Transpose", "No Transpose", Bc.row(), Bc.col(), Bf.row(), |
991 |
|
-1.0, Zbl.data(), Zbl.ld(), Bf.data(), Bf.ld(), 1.0, Bc.data(), Bc.ld() ); |
992 |
|
//printf( "Bc %lux%lu Bp %lux%lu\n", Bc.row(), Bc.col(), Bp.row(), Bp.col() ); fflush( stdout ); |
993 |
|
/** Copy Bc to Bp (subview of parent's B). */ |
994 |
|
Bp.CopyValuesFrom( Bc ); |
995 |
|
}; /** end ULVForward() */ |
996 |
|
|
997 |
|
|
998 |
|
void ULVBackward() |
999 |
|
{ |
1000 |
|
/** Copy Bp (subview of parent's B) to Bc. */ |
1001 |
|
Bc.CopyValuesFrom( Bp ); |
1002 |
|
/** Bf -= Ufc * Bc, where Ufc is Ztr. */ |
1003 |
|
xgemm( "No Transpose", "No Transpose", Bf.row(), Bf.col(), Bc.row(), |
1004 |
|
-1.0, Ztr.data(), Ztr.ld(), Bc.data(), Bc.ld(), 1.0, Bf.data(), Bf.ld() ); |
1005 |
|
/** Lff^{-1} * P * Bf, where Lff is the lower-triangular part of Ztl. */ |
1006 |
|
xtrsm( "Left", "Upper", "No transpose", "Non-unit", Bf.row(), Bf.col(), |
1007 |
|
1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() ); |
1008 |
|
if ( Q.size() ) |
1009 |
|
{ |
1010 |
|
/** Create a temporary buffer for projection Q2 * Bf + Q1 * Bc. */ |
1011 |
|
Data<T> A = B; |
1012 |
|
xgemm( "No Transpose", "No Transpose", A.row(), A.col(), Bf.row(), |
1013 |
|
1.0, Q2.data(), Q2.ld(), Bf.data(), Bf.ld(), 0.0, A.data(), A.row() ); |
1014 |
|
xgemm( "No Transpose", "No Transpose", A.row(), A.col(), Bc.row(), |
1015 |
|
1.0, Q1.data(), Q1.ld(), Bc.data(), Bc.ld(), 1.0, A.data(), A.row() ); |
1016 |
|
/** Copy A back to B. */ |
1017 |
|
if ( isleaf ) bview.CopyValuesFrom( A ); |
1018 |
|
else Bv.CopyValuesFrom( A ); |
1019 |
|
} |
1020 |
|
}; /** end ULVBackward() */ |
1021 |
|
|
1022 |
|
|
1023 |
|
|
1024 |
|
|
1025 |
|
|
1026 |
|
|
1027 |
|
bool isleaf = false; |
1028 |
|
|
1029 |
|
bool isroot = false; |
1030 |
|
|
1031 |
|
size_t n = 0; |
1032 |
|
|
1033 |
|
size_t nl = 0; |
1034 |
|
|
1035 |
|
size_t nr = 0; |
1036 |
|
|
1037 |
|
size_t s = 0; |
1038 |
|
|
1039 |
|
size_t sl = 0; |
1040 |
|
|
1041 |
|
size_t sr = 0; |
1042 |
|
|
1043 |
|
|
1044 |
|
/** Reduced system Z = [ I VU if ( HODLR || p-HSS ) |
1045 |
|
* VU I ] */ |
1046 |
|
Data<T> Z; |
1047 |
|
View<T> Zv; |
1048 |
|
View<T> Ztl, Ztr, Zbl, Zbr; |
1049 |
|
|
1050 |
|
/** Partial pivoting order (used in GETRF). */ |
1051 |
|
vector<int> ipiv; |
1052 |
|
|
1053 |
|
/** n-by-s (SMW) or (sl+sr)-by-s (ULV) */ |
1054 |
|
Data<T> U, V; |
1055 |
|
|
1056 |
|
/** sr-by-sl and sl-by-sr, skeleton row and column basis. */ |
1057 |
|
Data<T> Crl, Clr; |
1058 |
|
|
1059 |
|
/** A corresponding view of the right hand side of this node. */
1060 |
|
View<T> bview; |
1061 |
|
|
1062 |
|
/** Pointers to children's factors */ |
1063 |
|
Data<T> *Ul = NULL; |
1064 |
|
Data<T> *Ur = NULL; |
1065 |
|
Data<T> *Vl = NULL; |
1066 |
|
Data<T> *Vr = NULL; |
1067 |
|
|
1068 |
|
/** Q, (sl+sr)-by-s (ULV) */ |
1069 |
|
Data<T> Q; |
1070 |
|
View<T> Qv, Q1, Q2; |
1071 |
|
|
1072 |
|
/** tau, sl+sr (used in xgeqrf( U ) of ULV) */ |
1073 |
|
vector<T> tau; |
1074 |
|
|
1075 |
|
/** Temporary buffer for the solve. */ |
1076 |
|
Data<T> B; |
1077 |
|
View<T> Bv, Bp, Bsibling, Bf, Bc; |
1078 |
|
|
1079 |
|
private: /** this class will be public inherit by gofmm::Data<T> */ |
1080 |
|
|
1081 |
|
bool issymmetric = true; |
1082 |
|
|
1083 |
|
bool do_ulv_factorization = false; |
1084 |
|
|
1085 |
|
}; /** end class Factor */ |
1086 |
|
|
1087 |
|
|
1088 |
|
/**
 *  @brief Gather the sizes of a tree node (and of its two children when the
 *         node is not a leaf) and hand them to node->data.SetupFactor().
 *
 *  @tparam NODE tree node type; this function reads setup, data, n, l,
 *               isleaf, lchild and rchild (treelist_id only when
 *               DEBUG_IGOFMM is defined)
 *  @tparam T    value type, not used inside this function
 *
 *  @param node  the node whose factor is being set up
 */
template<typename NODE, typename T>
void SetupFactor( NODE *node )
{
#ifdef DEBUG_IGOFMM
  printf( "begin SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
#endif

  /** Global options stored in the shared setup object. */
  bool issymmetric = node->setup->IsSymmetric();
  bool do_ulv_factorization = node->setup->do_ulv_factorization;

  /** Sizes of this node: number of indices and number of skeletons. */
  size_t n = node->n;
  size_t s = node->data.skels.size();

  /** Children sizes stay zero for a leaf node. */
  size_t nl = 0, nr = 0;
  size_t sl = 0, sr = 0;

  if ( !node->isleaf )
  {
    auto *left  = node->lchild;
    auto *right = node->rchild;
    nl = left->n;
    nr = right->n;
    sl = left->data.skels.size();
    sr = right->data.skels.size();
  }

  /** The fourth argument is !node->l (presumably "is root", with l being
   *  the tree level -- confirm against the node type). */
  node->data.SetupFactor( issymmetric, do_ulv_factorization,
      node->isleaf, !node->l, n, nl, nr, s, sl, sr );

#ifdef DEBUG_IGOFMM
  printf( "end SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
#endif

}; /** end void SetupFactor() */
1128 |
|
|
1129 |
|
|
1130 |
|
/** |
1131 |
|
* @brief |
1132 |
|
*/ |
1133 |
|
template<typename NODE, typename T> |
1134 |
|
class SetupFactorTask : public Task |
1135 |
|
{ |
1136 |
|
public: |
1137 |
|
|
1138 |
|
NODE *arg = NULL; |
1139 |
|
|
1140 |
|
void Set( NODE *user_arg ) |
1141 |
|
{ |
1142 |
|
arg = user_arg; |
1143 |
|
name = string( "sf" ); |
1144 |
|
label = to_string( arg->treelist_id ); |
1145 |
|
cost = 1.0; |
1146 |
|
}; |
1147 |
|
|
1148 |
|
void GetEventRecord() |
1149 |
|
{ |
1150 |
|
double flops = 0.0, mops = 0.0; |
1151 |
|
event.Set( label + name, flops, mops ); |
1152 |
|
}; |
1153 |
|
|
1154 |
|
void DependencyAnalysis() |
1155 |
|
{ |
1156 |
|
arg->DependencyAnalysis( W, this ); |
1157 |
|
this->TryEnqueue(); |
1158 |
|
}; |
1159 |
|
|
1160 |
|
void Execute( Worker* user_worker ) |
1161 |
|
{ |
1162 |
|
SetupFactor<NODE, T>( arg ); |
1163 |
|
}; |
1164 |
|
|
1165 |
|
}; /** end class SetupFactorTask */ |
1166 |
|
|
1167 |
|
|
1168 |
|
|
1169 |
|
|
1170 |
|
template<typename NODE> |
1171 |
|
void SolverTreeView( NODE *node ) |
1172 |
|
{ |
1173 |
|
auto &data = node->data; |
1174 |
|
auto *setup = node->setup; |
1175 |
|
auto &input = *(setup->input); |
1176 |
|
auto &output = *(setup->output); |
1177 |
|
/** Allocate working buffer for ULV solve. */ |
1178 |
|
if ( node->isleaf ) data.B.resize( data.n, input.col() ); |
1179 |
|
else data.B.resize( data.sl + data.sr, input.col() ); |
1180 |
|
|
1181 |
|
/** Partition B = [ Bf; Bc ] with matrix view. */ |
1182 |
|
data.Bv.Set( data.B ); |
1183 |
|
data.Bv.Partition2x1( data.Bf, |
1184 |
|
data.Bc, data.s, BOTTOM ); |
1185 |
|
|
1186 |
|
/** Create contigious matrix view for output at root level. */ |
1187 |
|
if ( !node->parent ) data.bview.Set( output ); |
1188 |
|
|
1189 |
|
/** Hierarchical tree view. */ |
1190 |
|
if ( !node->isleaf ) |
1191 |
|
{ |
1192 |
|
auto &ldata = node->lchild->data; |
1193 |
|
auto &rdata = node->rchild->data; |
1194 |
|
/** Partition b = [ bl; br; ] with matrix view. */ |
1195 |
|
data.bview.Partition2x1( ldata.bview, |
1196 |
|
rdata.bview, data.nl, TOP ); |
1197 |
|
data.Bv.Partition2x1( ldata.Bp, |
1198 |
|
rdata.Bp, data.sl, TOP ); |
1199 |
|
} |
1200 |
|
}; /** end SolverTreeView() */ |
1201 |
|
|
1202 |
|
|
1203 |
|
|
1204 |
|
|
1205 |
|
/** @brief Creates an hierarchical tree view for a matrix. */ |
1206 |
|
template<typename NODE> |
1207 |
|
class SolverTreeViewTask : public Task |
1208 |
|
{ |
1209 |
|
public: |
1210 |
|
|
1211 |
|
NODE *arg = NULL; |
1212 |
|
|
1213 |
|
void Set( NODE *user_arg ) |
1214 |
|
{ |
1215 |
|
arg = user_arg; |
1216 |
|
name = string( "TreeView" ); |
1217 |
|
label = to_string( arg->treelist_id ); |
1218 |
|
cost = 1.0; |
1219 |
|
}; |
1220 |
|
|
1221 |
|
void GetEventRecord() |
1222 |
|
{ |
1223 |
|
double flops = 0.0, mops = 0.0; |
1224 |
|
event.Set( label + name, flops, mops ); |
1225 |
|
}; |
1226 |
|
|
1227 |
|
/** Preorder dependencies (with a single source node) */ |
1228 |
|
void DependencyAnalysis() { arg->DependOnParent( this ); }; |
1229 |
|
|
1230 |
|
void Execute( Worker* user_worker ) { SolverTreeView( arg ); }; |
1231 |
|
|
1232 |
|
}; /** end class TreeViewTask */ |
1233 |
|
|
1234 |
|
|
1235 |
|
|
1236 |
|
/** |
1237 |
|
* @brief doward traversal to create matrix views, at the leaf |
1238 |
|
* level execute explicit permutation. |
1239 |
|
*/ |
1240 |
|
template<bool FORWARD, typename NODE> |
1241 |
|
class MatrixPermuteTask : public hmlp::Task |
1242 |
|
{ |
1243 |
|
public: |
1244 |
|
|
1245 |
|
NODE *arg; |
1246 |
|
|
1247 |
|
void Set( NODE *user_arg ) |
1248 |
|
{ |
1249 |
|
name = std::string( "MatrixPermutation" ); |
1250 |
|
arg = user_arg; |
1251 |
|
cost = 1.0; |
1252 |
|
}; |
1253 |
|
|
1254 |
|
void GetEventRecord() |
1255 |
|
{ |
1256 |
|
double flops = 0.0, mops = 0.0; |
1257 |
|
event.Set( label + name, flops, mops ); |
1258 |
|
}; |
1259 |
|
|
1260 |
|
/** depends on previous task */ |
1261 |
|
void DependencyAnalysis() |
1262 |
|
{ |
1263 |
|
if ( FORWARD ) |
1264 |
|
{ |
1265 |
|
arg->DependencyAnalysis( RW, this ); |
1266 |
|
} |
1267 |
|
else |
1268 |
|
{ |
1269 |
|
this->Enqueue(); |
1270 |
|
} |
1271 |
|
}; |
1272 |
|
|
1273 |
|
void Execute( Worker* user_worker ) |
1274 |
|
{ |
1275 |
|
//printf( "PermuteMatrix %lu\n", arg->treelist_id ); |
1276 |
|
auto *node = arg; |
1277 |
|
auto &gids = node->gids; |
1278 |
|
auto &input = *(node->setup->input); |
1279 |
|
auto &output = *(node->setup->output); |
1280 |
|
auto &A = node->data.bview; |
1281 |
|
|
1282 |
|
assert( A.row() == gids.size() ); |
1283 |
|
assert( A.col() == input.col() ); |
1284 |
|
|
1285 |
|
//for ( size_t i = 0; i < gids.size(); i ++ ) |
1286 |
|
// printf( "%lu ", gids[ i ] ); |
1287 |
|
//printf( "\n" ); |
1288 |
|
|
1289 |
|
/** perform permutation and output */ |
1290 |
|
for ( size_t j = 0; j < input.col(); j ++ ) |
1291 |
|
for ( size_t i = 0; i < gids.size(); i ++ ) |
1292 |
|
/** foward permutation */ |
1293 |
|
if ( FORWARD ) A( i, j ) = input( gids[ i ], j ); |
1294 |
|
/** inverse permutation */ |
1295 |
|
else input( gids[ i ], j ) = A( i, j ); |
1296 |
|
|
1297 |
|
//for ( size_t j = 0; j < 1; j ++ ) |
1298 |
|
// for ( size_t i = 0; i < gids.size(); i ++ ) |
1299 |
|
// printf( "%E ", A( i, j ) ); |
1300 |
|
//printf( "\n" ); |
1301 |
|
|
1302 |
|
//printf( "end PermuteMatrix %lu\n", arg->treelist_id ); |
1303 |
|
}; |
1304 |
|
|
1305 |
|
}; /** end class MatrixPermuteTask */ |
1306 |
|
|
1307 |
|
|
1308 |
|
|
1309 |
|
/**
 *  @brief Apply step on one tree node.
 *
 *  On a non-leaf node this calls node->data.Apply<true>() with the right
 *  hand side views of the two children. On a leaf it evaluates the diagonal
 *  block K( gids, gids ) and adds lambda to its diagonal.
 *
 *  @tparam NODE tree node type
 *  @tparam T    value type, not used inside this function
 *
 *  @param node  the tree node to process
 */
template<typename NODE, typename T>
void Apply( NODE *node )
{
  auto &data = node->data;
  auto &setup = node->setup;
  auto &K = *setup->K;

  if ( node->isleaf )
  {
    auto lambda = setup->lambda;
    auto &amap = node->gids;
    /** Evaluate the diagonal block. */
    auto Kaa = K( amap, amap );
    /** Apply the regularization. */
    for ( size_t i = 0; i < Kaa.row(); i ++ )
      Kaa[ i * Kaa.row() + i ] += lambda;
    /** NOTE(review): Kaa is a local that is not used after this point, so
     *  the leaf branch currently has no visible effect -- confirm intent. */
  }
  else
  {
    auto &bl = node->lchild->data.bview;
    auto &br = node->rchild->data.bview;
    /** data has a dependent type, so the member template needs the
     *  "template" disambiguator. Without it the call only parses because a
     *  namespace-scope template that is also named Apply is in scope. */
    data.template Apply<true>( bl, br );
  }

}; /** end Apply() */
1336 |
|
|
1337 |
|
|
1338 |
|
|
1339 |
|
//template<typename NODE, typename T> |
1340 |
|
//void ULVForwardSolve( NODE *node ) { node->data.ULVForward(); }; |
1341 |
|
|
1342 |
|
|
1343 |
|
|
1344 |
|
template<typename NODE, typename T> |
1345 |
|
class ULVForwardSolveTask : public Task |
1346 |
|
{ |
1347 |
|
public: |
1348 |
|
|
1349 |
|
NODE *arg = NULL; |
1350 |
|
|
1351 |
|
void Set( NODE *user_arg ) |
1352 |
|
{ |
1353 |
|
arg = user_arg; |
1354 |
|
name = string( "ulvforward" ); |
1355 |
|
label = to_string( arg->treelist_id ); |
1356 |
|
cost = 1.0; |
1357 |
|
}; |
1358 |
|
|
1359 |
|
//void DependencyAnalysis() |
1360 |
|
//{ |
1361 |
|
// arg->DependencyAnalysis( RW, this ); |
1362 |
|
// /** depend on two children */ |
1363 |
|
// if ( !arg->isleaf ) |
1364 |
|
// { |
1365 |
|
// arg->lchild->DependencyAnalysis( R, this ); |
1366 |
|
// arg->rchild->DependencyAnalysis( R, this ); |
1367 |
|
// } |
1368 |
|
// /** dispatch the task if there is no dependency */ |
1369 |
|
// this->TryEnqueue(); |
1370 |
|
//}; |
1371 |
|
|
1372 |
|
void DependencyAnalysis() { arg->DependOnChildren( this ); }; |
1373 |
|
|
1374 |
|
|
1375 |
|
void Execute( Worker* user_worker ) { arg->data.ULVForward(); }; |
1376 |
|
|
1377 |
|
}; /** end class ULVForwardSolveTask */ |
1378 |
|
|
1379 |
|
|
1380 |
|
|
1381 |
|
|
1382 |
|
template<typename NODE, typename T> |
1383 |
|
class ULVBackwardSolveTask : public Task |
1384 |
|
{ |
1385 |
|
public: |
1386 |
|
|
1387 |
|
NODE *arg; |
1388 |
|
|
1389 |
|
void Set( NODE *user_arg ) |
1390 |
|
{ |
1391 |
|
arg = user_arg; |
1392 |
|
name = string( "ulvbackward" ); |
1393 |
|
label = std::to_string( arg->treelist_id ); |
1394 |
|
cost = 1.0; |
1395 |
|
|
1396 |
|
//printf( "Set treelist_id %lu\n", arg->treelist_id ); fflush( stdout ); |
1397 |
|
}; |
1398 |
|
|
1399 |
|
//void DependencyAnalysis() |
1400 |
|
//{ |
1401 |
|
// /** depend on parent */ |
1402 |
|
// if ( arg->parent ) |
1403 |
|
// arg->parent->DependencyAnalysis( hmlp::ReadWriteType::R, this ); |
1404 |
|
// arg->DependencyAnalysis( hmlp::ReadWriteType::RW, this ); |
1405 |
|
// /** dispatch the task if there is no dependency */ |
1406 |
|
// this->TryEnqueue(); |
1407 |
|
//}; |
1408 |
|
|
1409 |
|
|
1410 |
|
void DependencyAnalysis() { arg->DependOnParent( this ); }; |
1411 |
|
|
1412 |
|
void Execute( Worker* user_worker ) { arg->data.ULVBackward(); }; |
1413 |
|
|
1414 |
|
}; /** end class ULVBackwardSolveTask */ |
1415 |
|
|
1416 |
|
|
1417 |
|
|
1418 |
|
|
1419 |
|
|
1420 |
|
|
1421 |
|
|
1422 |
|
|
1423 |
|
|
1424 |
|
|
1425 |
|
|
1426 |
|
|
1427 |
|
|
1428 |
|
|
1429 |
|
|
1430 |
|
|
1431 |
|
|
1432 |
|
|
1433 |
|
/**
 *  @brief Solve step on one tree node.
 *
 *  A leaf calls node->data.Solve( b ) with its own right hand side view;
 *  any other node calls node->data.Solve( bl, br ) with the views of its
 *  left and right children.
 *
 *  @tparam NODE tree node type; this function reads data, isleaf, lchild
 *               and rchild
 *  @tparam T    value type, not used inside this function
 *
 *  @param node  the tree node to process
 */
template<typename NODE, typename T>
void Solve( NODE *node )
{
  auto &data = node->data;

  /** TODO: need to decide to use LU or not */
  if ( node->isleaf )
  {
    auto &b = data.bview;
    data.Solve( b );
  }
  else
  {
    auto &bl = node->lchild->data.bview;
    auto &br = node->rchild->data.bview;
    data.Solve( bl, br );
  }

}; /** end Solve() */
1465 |
|
|
1466 |
|
|
1467 |
|
/** |
1468 |
|
* @brief |
1469 |
|
*/ |
1470 |
|
template<typename NODE, typename T> |
1471 |
|
class SolveTask : public Task |
1472 |
|
{ |
1473 |
|
public: |
1474 |
|
|
1475 |
|
NODE *arg = NULL; |
1476 |
|
|
1477 |
|
void Set( NODE *user_arg ) |
1478 |
|
{ |
1479 |
|
arg = user_arg; |
1480 |
|
name = string( "sl" ); |
1481 |
|
label = to_string( arg->treelist_id ); |
1482 |
|
cost = 1.0; |
1483 |
|
|
1484 |
|
//printf( "Set treelist_id %lu\n", arg->treelist_id ); fflush( stdout ); |
1485 |
|
}; |
1486 |
|
|
1487 |
|
void GetEventRecord() |
1488 |
|
{ |
1489 |
|
double flops = 0.0, mops = 0.0; |
1490 |
|
event.Set( label + name, flops, mops ); |
1491 |
|
}; |
1492 |
|
|
1493 |
|
void DependencyAnalysis() |
1494 |
|
{ |
1495 |
|
arg->DependencyAnalysis( RW, this ); |
1496 |
|
if ( !arg->isleaf ) |
1497 |
|
{ |
1498 |
|
arg->lchild->DependencyAnalysis( R, this ); |
1499 |
|
arg->rchild->DependencyAnalysis( R, this ); |
1500 |
|
} |
1501 |
|
}; |
1502 |
|
|
1503 |
|
void Execute( Worker* user_worker ) |
1504 |
|
{ |
1505 |
|
Solve<NODE, T>( arg ); |
1506 |
|
}; |
1507 |
|
|
1508 |
|
}; /** end class SolveTask */ |
1509 |
|
|
1510 |
|
|
1511 |
|
/** |
1512 |
|
* |
1513 |
|
*/ |
1514 |
|
template<typename T, typename TREE> |
1515 |
|
void Solve( TREE &tree, Data<T> &input ) |
1516 |
|
{ |
1517 |
|
using NODE = typename TREE::NODE; |
1518 |
|
|
1519 |
|
const bool AUTO_DEPENDENCY = true; |
1520 |
|
const bool USE_RUNTIME = true; |
1521 |
|
|
1522 |
|
/** copy input to output */ |
1523 |
|
auto *output = new Data<T>( input.row(), input.col() ); |
1524 |
|
|
1525 |
|
SolverTreeViewTask<NODE> treeviewtask; |
1526 |
|
MatrixPermuteTask<true, NODE> forwardpermutetask; |
1527 |
|
MatrixPermuteTask<false, NODE> inversepermutetask; |
1528 |
|
/** Sherman-Morrison-Woodbury */ |
1529 |
|
SolveTask<NODE, T> solvetask1; |
1530 |
|
/** ULV */ |
1531 |
|
ULVForwardSolveTask<NODE, T> ulvforwardsolvetask; |
1532 |
|
ULVBackwardSolveTask<NODE, T> ulvbackwardsolvetask; |
1533 |
|
|
1534 |
|
/** attach the pointer to the tree structure */ |
1535 |
|
tree.setup.input = &input; |
1536 |
|
tree.setup.output = output; |
1537 |
|
|
1538 |
|
if ( tree.setup.do_ulv_factorization ) |
1539 |
|
{ |
1540 |
|
/** clean up all dependencies on tree nodes */ |
1541 |
|
tree.DependencyCleanUp(); |
1542 |
|
tree.TraverseDown( treeviewtask ); |
1543 |
|
tree.TraverseLeafs( forwardpermutetask ); |
1544 |
|
tree.TraverseUp( ulvforwardsolvetask ); |
1545 |
|
tree.TraverseDown( ulvbackwardsolvetask ); |
1546 |
|
if ( USE_RUNTIME ) hmlp_run(); |
1547 |
|
|
1548 |
|
/** clean up all dependencies on tree nodes */ |
1549 |
|
tree.DependencyCleanUp(); |
1550 |
|
tree.TraverseLeafs( inversepermutetask ); |
1551 |
|
if ( USE_RUNTIME ) hmlp_run(); |
1552 |
|
} |
1553 |
|
else |
1554 |
|
{ |
1555 |
|
/** clean up all dependencies on tree nodes */ |
1556 |
|
tree.DependencyCleanUp(); |
1557 |
|
tree.TraverseDown( treeviewtask ); |
1558 |
|
tree.TraverseLeafs( forwardpermutetask ); |
1559 |
|
tree.TraverseUp( solvetask1 ); |
1560 |
|
if ( USE_RUNTIME ) hmlp_run(); |
1561 |
|
/** clean up all dependencies on tree nodes */ |
1562 |
|
tree.DependencyCleanUp(); |
1563 |
|
tree.TraverseLeafs( inversepermutetask ); |
1564 |
|
if ( USE_RUNTIME ) hmlp_run(); |
1565 |
|
} |
1566 |
|
|
1567 |
|
/** delete buffer space */ |
1568 |
|
delete output; |
1569 |
|
|
1570 |
|
}; /** end Solve() */ |
1571 |
|
|
1572 |
|
|
1573 |
|
|
1574 |
|
|
1575 |
|
|
1576 |
|
/** |
1577 |
|
* @brief Compute relative Forbenius error for two-sided |
1578 |
|
* interpolative decomposition. |
1579 |
|
*/ |
1580 |
|
template<typename NODE, typename T> |
1581 |
|
void LowRankError( NODE *node ) |
1582 |
|
{ |
1583 |
|
auto &data = node->data; |
1584 |
|
auto &setup = node->setup; |
1585 |
|
auto &K = *setup->K; |
1586 |
|
|
1587 |
|
if ( !node->isleaf ) |
1588 |
|
{ |
1589 |
|
auto Krl = K( node->rchild->gids, node->lchild->gids ); |
1590 |
|
|
1591 |
|
auto nrm2 = hmlp_norm( Krl.row(), Krl.col(), |
1592 |
|
Krl.data(), Krl.row() ); |
1593 |
|
|
1594 |
|
|
1595 |
|
hmlp::Data<T> VrCrl( data.nr, data.sl ); |
1596 |
|
|
1597 |
|
/** VrCrl = Vr * Crl */ |
1598 |
|
xgemm( "N", "N", data.nr, data.sl, data.sr, |
1599 |
|
1.0, data.Vr->data(), data.nr, |
1600 |
|
data.Crl.data(), data.sr, |
1601 |
|
0.0, VrCrl.data(), data.nr ); |
1602 |
|
|
1603 |
|
/** Krl - VrCrlVl' */ |
1604 |
|
xgemm( "N", "T", data.nr, data.nl, data.sl, |
1605 |
|
-1.0, VrCrl.data(), data.nr, |
1606 |
|
data.Vl->data(), data.nl, |
1607 |
|
1.0, Krl.data(), data.nr ); |
1608 |
|
|
1609 |
|
auto err = hmlp_norm( Krl.row(), Krl.col(), |
1610 |
|
Krl.data(), Krl.row() ); |
1611 |
|
|
1612 |
|
printf( "%4lu ||Krl -VrCrlVl|| %3.1E\n", |
1613 |
|
node->treelist_id, std::sqrt( err / nrm2 ) ); |
1614 |
|
} |
1615 |
|
|
1616 |
|
}; /** end LowRankError() */ |
1617 |
|
|
1618 |
|
|
1619 |
|
|
1620 |
|
/** |
1621 |
|
* @brief Factorizarion using LU and SMW |
1622 |
|
*/ |
1623 |
|
template<typename NODE, typename T> |
1624 |
|
void Factorize( NODE *node ) |
1625 |
|
{ |
1626 |
|
auto &data = node->data; |
1627 |
|
auto &setup = node->setup; |
1628 |
|
auto &K = *setup->K; |
1629 |
|
auto &proj = data.proj; |
1630 |
|
|
1631 |
|
auto do_ulv_factorization = setup->do_ulv_factorization; |
1632 |
|
|
1633 |
|
if ( node->isleaf ) |
1634 |
|
{ |
1635 |
|
auto lambda = setup->lambda; |
1636 |
|
auto &amap = node->gids; |
1637 |
|
|
1638 |
|
/** Evaluate the diagonal block. */ |
1639 |
|
Data<T> Kaa = K( amap, amap ); |
1640 |
|
|
1641 |
|
/** Apply the regularization */ |
1642 |
|
for ( size_t i = 0; i < Kaa.row(); i ++ ) Kaa( i, i ) += lambda; |
1643 |
|
|
1644 |
|
if ( do_ulv_factorization ) |
1645 |
|
{ |
1646 |
|
/** U = proj */ |
1647 |
|
data.Telescope( false, data.U, proj ); |
1648 |
|
/** QR factorization */ |
1649 |
|
data.Orthogonalization(); |
1650 |
|
/** LU factorization */ |
1651 |
|
data.PartialFactorize( Kaa ); |
1652 |
|
} |
1653 |
|
else |
1654 |
|
{ |
1655 |
|
/** LU factorization */ |
1656 |
|
data.Factorize( Kaa ); |
1657 |
|
/** U = inv( Kaa ) * proj' */ |
1658 |
|
data.Telescope( true, data.U, proj ); |
1659 |
|
/** V = proj' */ |
1660 |
|
data.Telescope( false, data.V, proj ); |
1661 |
|
} |
1662 |
|
} |
1663 |
|
else |
1664 |
|
{ |
1665 |
|
auto &Ul = node->lchild->data.U; |
1666 |
|
auto &Vl = node->lchild->data.V; |
1667 |
|
auto &Zl = node->lchild->data.Zbr; |
1668 |
|
auto &Ur = node->rchild->data.U; |
1669 |
|
auto &Vr = node->rchild->data.V; |
1670 |
|
auto &Zr = node->rchild->data.Zbr; |
1671 |
|
|
1672 |
|
/** Evluate the skeleton rows and columns. */ |
1673 |
|
auto &amap = node->lchild->data.skels; |
1674 |
|
auto &bmap = node->rchild->data.skels; |
1675 |
|
|
1676 |
|
/** Get the skeleton rows and columns */ |
1677 |
|
node->data.Crl = K( bmap, amap ); |
1678 |
|
|
1679 |
|
if ( do_ulv_factorization ) |
1680 |
|
{ |
1681 |
|
if ( !node->data.isroot ) |
1682 |
|
{ |
1683 |
|
data.Telescope( false, data.U, proj, Ul, Ur ); |
1684 |
|
data.Orthogonalization(); |
1685 |
|
} |
1686 |
|
data.PartialFactorize( Zl, Zr, Ul, Ur, Vl, Vr ); |
1687 |
|
} |
1688 |
|
else |
1689 |
|
{ |
1690 |
|
/** SMW factorization (LU or Cholesky) */ |
1691 |
|
data.Factorize( Ul, Ur, Vl, Vr ); |
1692 |
|
/** telescope U and V */ |
1693 |
|
if ( !node->data.isroot ) |
1694 |
|
{ |
1695 |
|
/** U = inv( I + UCV' ) * [ Ul; Ur ] * proj' */ |
1696 |
|
data.Telescope( true, data.U, proj, Ul, Ur ); |
1697 |
|
/** V = [ Vl; Vr ] * proj' */ |
1698 |
|
data.Telescope( false, data.V, proj, Vl, Vr ); |
1699 |
|
} |
1700 |
|
} |
1701 |
|
} |
1702 |
|
|
1703 |
|
|
1704 |
|
|
1705 |
|
|
1706 |
|
|
1707 |
|
|
1708 |
|
|
1709 |
|
|
1710 |
|
// /** SMW factorization (LU or Cholesky) */ |
1711 |
|
// data.Factorize<true>( Ul, Ur, Vl, Vr ); |
1712 |
|
// |
1713 |
|
// /** telescope U and V */ |
1714 |
|
// if ( !node->data.isroot ) |
1715 |
|
// { |
1716 |
|
// if ( do_ulv_factorization ) |
1717 |
|
// { |
1718 |
|
// data.Telescope( true, data.U, proj, Ul, Ur ); |
1719 |
|
// data.Orthogonalization(); |
1720 |
|
// } |
1721 |
|
// else |
1722 |
|
// { |
1723 |
|
// /** U = inv( I + UCV' ) * [ Ul; Ur ] * proj' */ |
1724 |
|
// data.Telescope( true, data.U, proj, Ul, Ur ); |
1725 |
|
// /** V = [ Vl; Vr ] * proj' */ |
1726 |
|
// data.Telescope( false, data.V, proj, Vl, Vr ); |
1727 |
|
// } |
1728 |
|
// } |
1729 |
|
// else |
1730 |
|
// { |
1731 |
|
// /** output Crl from children */ |
1732 |
|
// |
1733 |
|
// //size_t L = 3; |
1734 |
|
// |
1735 |
|
// auto *cl = node->lchild; |
1736 |
|
// auto *cr = node->rchild; |
1737 |
|
// auto *c1 = cl->lchild; |
1738 |
|
// auto *c2 = cl->rchild; |
1739 |
|
// auto *c3 = cr->lchild; |
1740 |
|
// auto *c4 = cr->rchild; |
1741 |
|
// |
1742 |
|
// //hmlp::Data<T> C21 = K( c2->data.skels, c1->data.skels ); |
1743 |
|
// //hmlp::Data<T> C31 = K( c3->data.skels, c1->data.skels ); |
1744 |
|
// //hmlp::Data<T> C41 = K( c4->data.skels, c1->data.skels ); |
1745 |
|
// //hmlp::Data<T> C32 = K( c3->data.skels, c2->data.skels ); |
1746 |
|
// //hmlp::Data<T> C42 = K( c4->data.skels, c2->data.skels ); |
1747 |
|
// //hmlp::Data<T> C43 = K( c4->data.skels, c3->data.skels ); |
1748 |
|
// |
1749 |
|
// //C21.WriteFile( "C21.m" ); |
1750 |
|
// //C31.WriteFile( "C31.m" ); |
1751 |
|
// //C41.WriteFile( "C41.m" ); |
1752 |
|
// //C32.WriteFile( "C32.m" ); |
1753 |
|
// //C42.WriteFile( "C42.m" ); |
1754 |
|
// //C43.WriteFile( "C43.m" ); |
1755 |
|
// |
1756 |
|
// |
1757 |
|
// //hmlp::Data<T> V11( c1->data.V.col(), c1->data.V.col() ); |
1758 |
|
// //hmlp::Data<T> V22( c2->data.V.col(), c2->data.V.col() ); |
1759 |
|
// //hmlp::Data<T> V33( c3->data.V.col(), c3->data.V.col() ); |
1760 |
|
// //hmlp::Data<T> V44( c4->data.V.col(), c4->data.V.col() ); |
1761 |
|
// |
1762 |
|
// //xgemm( "T", "N", c1->data.V.col(), c1->data.V.col(), c1->data.V.row(), |
1763 |
|
// // 1.0, c1->data.V.data(), c1->data.V.row(), |
1764 |
|
// // c1->data.V.data(), c1->data.V.row(), |
1765 |
|
// // 0.0, V11.data(), V11.row() ); |
1766 |
|
// |
1767 |
|
// //xgemm( "T", "N", c2->data.V.col(), c2->data.V.col(), c2->data.V.row(), |
1768 |
|
// // 1.0, c2->data.V.data(), c2->data.V.row(), |
1769 |
|
// // c2->data.V.data(), c2->data.V.row(), |
1770 |
|
// // 0.0, V22.data(), V22.row() ); |
1771 |
|
// |
1772 |
|
// //xgemm( "T", "N", c3->data.V.col(), c3->data.V.col(), c3->data.V.row(), |
1773 |
|
// // 1.0, c3->data.V.data(), c3->data.V.row(), |
1774 |
|
// // c3->data.V.data(), c3->data.V.row(), |
1775 |
|
// // 0.0, V33.data(), V33.row() ); |
1776 |
|
// |
1777 |
|
// //xgemm( "T", "N", c4->data.V.col(), c4->data.V.col(), c4->data.V.row(), |
1778 |
|
// // 1.0, c4->data.V.data(), c4->data.V.row(), |
1779 |
|
// // c4->data.V.data(), c4->data.V.row(), |
1780 |
|
// // 0.0, V44.data(), V44.row() ); |
1781 |
|
// |
1782 |
|
// //V11.WriteFile( "V11.m" ); |
1783 |
|
// //V22.WriteFile( "V22.m" ); |
1784 |
|
// //V33.WriteFile( "V33.m" ); |
1785 |
|
// //V44.WriteFile( "V44.m" ); |
1786 |
|
// } |
1787 |
|
// //printf( "end inner forward telescoping\n" ); fflush( stdout ); |
1788 |
|
// |
1789 |
|
// /** check the offdiagonal block VrCrlVl' accuracy */ |
1790 |
|
// if ( !do_ulv_factorization ) |
1791 |
|
// LowRankError<NODE, T>( node ); |
1792 |
|
// } |
1793 |
|
|
1794 |
|
}; /** end void Factorize() */ |
1795 |
|
|
1796 |
|
|
1797 |
|
|
1798 |
|
/** |
1799 |
|
* @brief |
1800 |
|
*/ |
1801 |
|
template<typename NODE, typename T> |
1802 |
|
class FactorizeTask : public Task |
1803 |
|
{ |
1804 |
|
public: |
1805 |
|
|
1806 |
|
NODE *arg = NULL; |
1807 |
|
|
1808 |
|
void Set( NODE *user_arg ) |
1809 |
|
{ |
1810 |
|
arg = user_arg; |
1811 |
|
name = string( "fa" ); |
1812 |
|
label = to_string( arg->treelist_id ); |
1813 |
|
// Need an accurate cost model. |
1814 |
|
cost = 1.0; |
1815 |
|
}; |
1816 |
|
|
1817 |
|
void GetEventRecord() |
1818 |
|
{ |
1819 |
|
double flops = 0.0, mops = 0.0; |
1820 |
|
event.Set( label + name, flops, mops ); |
1821 |
|
}; |
1822 |
|
|
1823 |
|
void DependencyAnalysis() { arg->DependOnChildren( this ); }; |
1824 |
|
|
1825 |
|
void Execute( Worker* user_worker ) { Factorize<NODE, T>( arg ); }; |
1826 |
|
|
1827 |
|
}; /** end class FactorizeTask */ |
1828 |
|
|
1829 |
|
|
1830 |
|
|
1831 |
|
|
1832 |
|
|
1833 |
|
|
1834 |
|
|
1835 |
|
|
1836 |
|
|
1837 |
|
|
1838 |
|
|
1839 |
|
|
1840 |
|
|
1841 |
|
/** @biref Top-level factorization routine. */ |
1842 |
|
template<typename T, typename TREE> |
1843 |
|
void Factorize( TREE &tree, T lambda ) |
1844 |
|
{ |
1845 |
|
using NODE = typename TREE::NODE; |
1846 |
|
|
1847 |
|
/** Clean up all dependencies on tree nodes. */ |
1848 |
|
tree.DependencyCleanUp(); |
1849 |
|
|
1850 |
|
/** Regularization parameter lambda. */ |
1851 |
|
tree.setup.lambda = lambda; |
1852 |
|
|
1853 |
|
/** Perform ULV factorization. */ |
1854 |
|
tree.setup.do_ulv_factorization = true; |
1855 |
|
|
1856 |
|
/** Setup */ |
1857 |
|
SetupFactorTask<NODE, T> setupfactortask; |
1858 |
|
tree.TraverseUp( setupfactortask ); |
1859 |
|
tree.ExecuteAllTasks(); |
1860 |
|
|
1861 |
|
/** Factorization */ |
1862 |
|
FactorizeTask<NODE, T> factorizetask; |
1863 |
|
tree.TraverseUp( factorizetask ); |
1864 |
|
tree.ExecuteAllTasks(); |
1865 |
|
|
1866 |
|
}; /** end Factorize() */ |
1867 |
|
|
1868 |
|
|
1869 |
|
|
1870 |
|
/** |
1871 |
|
* @brief Compute the average 2-norm error. That is given |
1872 |
|
* lambda and weights, |
1873 |
|
*/ |
1874 |
|
template<typename TREE, typename T> |
1875 |
|
void ComputeError( TREE &tree, T lambda, Data<T> weights, Data<T> potentials ) |
1876 |
|
{ |
1877 |
|
using NODE = typename TREE::NODE; |
1878 |
|
|
1879 |
|
|
1880 |
|
/** assure the dimension matches */ |
1881 |
|
assert( weights.row() == potentials.row() ); |
1882 |
|
assert( weights.col() == potentials.col() ); |
1883 |
|
|
1884 |
|
size_t n = weights.row(); |
1885 |
|
size_t nrhs = weights.col(); |
1886 |
|
|
1887 |
|
/** shift lambda and make it a column vector */ |
1888 |
|
Data<T> rhs( n, nrhs ); |
1889 |
|
for ( size_t j = 0; j < nrhs; j ++ ) |
1890 |
|
for ( size_t i = 0; i < n; i ++ ) |
1891 |
|
rhs( i, j ) = potentials( i, j ) + lambda * weights( i, j ); |
1892 |
|
|
1893 |
|
/** potentials = inv( K + lambda * I ) * potentials */ |
1894 |
|
Solve( tree, rhs ); |
1895 |
|
|
1896 |
|
|
1897 |
|
/** Compute relative error = sqrt( err / nrm2 ) for each rhs */ |
1898 |
|
printf( "========================================================\n" ); |
1899 |
|
printf( "Inverse accuracy report\n" ); |
1900 |
|
printf( "========================================================\n" ); |
1901 |
|
printf( "#rhs, max err, @, min err, @, relative \n" ); |
1902 |
|
printf( "========================================================\n" ); |
1903 |
|
size_t ntest = 10; |
1904 |
|
T total_err = 0.0; |
1905 |
|
for ( size_t j = 0; j < std::min( nrhs, ntest ); j ++ ) |
1906 |
|
{ |
1907 |
|
/** counters */ |
1908 |
|
T nrm2 = 0.0, err2 = 0.0; |
1909 |
|
T max2 = 0.0, min2 = std::numeric_limits<T>::max(); |
1910 |
|
/** indecies */ |
1911 |
|
size_t maxi = 0, mini = 0; |
1912 |
|
|
1913 |
|
for ( size_t i = 0; i < n; i ++ ) |
1914 |
|
{ |
1915 |
|
T sse = rhs( i, j ) - weights( i, j ); |
1916 |
|
assert( rhs( i, j ) == rhs( i, j ) ); |
1917 |
|
sse = sse * sse; |
1918 |
|
|
1919 |
|
nrm2 += weights( i, j ) * weights( i, j ); |
1920 |
|
err2 += sse; |
1921 |
|
|
1922 |
|
//printf( "%lu %3.1E\n", i, sse ); |
1923 |
|
|
1924 |
|
|
1925 |
|
if ( sse > max2 ) { max2 = sse; maxi = i; } |
1926 |
|
if ( sse < min2 ) { min2 = sse; mini = i; } |
1927 |
|
} |
1928 |
|
total_err += std::sqrt( err2 / nrm2 ); |
1929 |
|
|
1930 |
|
printf( "%4lu, %3.1E, %7lu, %3.1E, %7lu, %3.1E\n", |
1931 |
|
j, std::sqrt( max2 ), maxi, std::sqrt( min2 ), mini, |
1932 |
|
std::sqrt( err2 / nrm2 ) ); |
1933 |
|
} |
1934 |
|
printf( "========================================================\n" ); |
1935 |
|
printf( " avg over %2lu rhs, %3.1E \n", |
1936 |
|
std::min( nrhs, ntest ), total_err / std::min( nrhs, ntest ) ); |
1937 |
|
printf( "========================================================\n\n" ); |
1938 |
|
|
1939 |
|
}; /** end ComputeError() */ |
1940 |
|
|
1941 |
|
|
1942 |
|
|
1943 |
|
|
1944 |
|
|
1945 |
|
|
1946 |
|
|
1947 |
|
|
1948 |
|
}; /** end namespace gofmm */ |
1949 |
|
}; /** end namespace hmlp */ |
1950 |
|
|
1951 |
|
#endif /** define IGOFMM_HPP */ |