GCC Code Coverage Report
Directory: . Exec Total Coverage
File: gofmm/igofmm.hpp Lines: 0 490 0.0 %
Date: 2019-01-14 Branches: 0 1738 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
22
23
#ifndef IGOFMM_HPP
24
#define IGOFMM_HPP
25
26
27
/** Use STL and HMLP namespaces. */
28
using namespace std;
29
using namespace hmlp;
30
31
32
33
34
namespace hmlp
35
{
36
namespace gofmm
37
{
38
39
40
/**
41
 *
42
 * for each level
43
 *   for each alpha
44
 *
45
 *     if ( leaf )
46
 *
47
 *       LL' = Chol( Kaa )
48
 *       U   = inv( L ) * proj' or
49
 *       U'  = proj * inv( L' )
50
 *       QR  = qr( U ) or
51
 *       LQ' = lq( U' )
52
 *
53
 *     else
54
 *
55
 *       LL' = Chol( I + k.R * C * k.R' )
56
 *
57
 *       if ( not root )
58
 *
59
 *         U   = inv( L ) * [ l.R        * proj' or
60
 *                               r.R ]
61
 *         U'  = proj' *  [ l.R        * inv( L' )
62
 *                             r.R ]
63
 *         QR  = qr( U ) or
64
 *         LQ' = lq( U' )
65
 *
66
 **/
67
68
69
template<typename T>
70
class Factor
71
{
72
  public:
73
74
    /** Default constructor; every member relies on its in-class initializer. */
    Factor()
    {
    };
75
76
    void SetupFactor
77
    (
78
      bool issymmetric, bool do_ulv_factorization,
79
      bool isleaf, bool isroot,
80
      /** n == nl + nr (left + right) */
81
      size_t n, size_t nl, size_t nr,
82
      /** s <= sl + sr */
83
      size_t s, size_t sl, size_t sr
84
    )
85
    {
86
      this->issymmetric = issymmetric;
87
      this->do_ulv_factorization = do_ulv_factorization;
88
      this->isleaf = isleaf;
89
      this->isroot = isroot;
90
      this->n = n; this->nl = nl; this->nr = nr;
91
      this->s = s; this->sl = sl; this->sr = sr;
92
    };
93
94
    void SetupFactor
95
    (
96
      bool issymmetric, bool do_ulv_factorization,
97
      bool isleaf, bool isroot,
98
      size_t n, size_t nl, size_t nr,
99
      size_t s, size_t sl, size_t sr,
100
      /** n-by-?; its rank depends on mu sibling */
101
      Data<T> &U,
102
      /** ?-by-n; its rank depends on my sibling */
103
      Data<T> &V
104
    )
105
    {
106
      SetupFactor( issymmetric, do_ulv_factorization,
107
          isleaf, isroot, n, nl, nr, s, sl, sr );
108
    };
109
110
    bool DoULVFactorization()
111
    {
112
      return do_ulv_factorization;
113
    };
114
115
    bool IsSymmetric() { return issymmetric; };
116
117
118
119
120
121
    void CheckCondition()
122
    {
123
      assert( do_ulv_factorization && issymmetric );
124
      T max_diag = 0.0;
125
      T min_diag = 0.0;
126
127
      for ( size_t i = 0; i < Z.row(); i ++ )
128
      {
129
        T abs_diag =  std::abs( Z( i, i ) );
130
131
        if ( !i )
132
        {
133
          max_diag = abs_diag;
134
          min_diag = abs_diag;
135
        }
136
137
        if ( abs_diag > max_diag ) max_diag = abs_diag;
138
        if ( abs_diag < min_diag ) min_diag = abs_diag;
139
      }
140
141
      printf( "condiditon( Z ): min_diag %3.1E max_diag %3.1E ratio %3.1E\n",
142
          min_diag, max_diag, min_diag / max_diag );
143
    };
144
145
    /**
     *  @brief Leaf factorization: copy Kaa into Z, compute its pivoted LU
     *         factorization in place, and warn if the estimated 1-norm
     *         condition number exceeds 1E+6.
     *
     *  @param Kaa  n-by-n diagonal block of this leaf (asserted).
     *
     *  On return Z holds the LU factors and ipiv the pivoting order.
     */
    void Factorize( Data<T> &Kaa )
    {
      assert( isleaf );
      assert( Kaa.row() == n ); assert( Kaa.col() == n );

      /** Initialize with Kaa. */
      Z = Kaa;

      /** Record the partial pivoting order. */
      ipiv.resize( n, 0 );

      /**
       *  Compute the 1-norm of Z, i.e. the maximum absolute column sum.
       *  xgecon( "1", ... ) expects the 1-norm of the ORIGINAL matrix, so
       *  this must happen before xgetrf overwrites Z with its LU factors.
       */
      T nrm1 = 0.0;
      for ( size_t j = 0; j < Z.col(); j ++ )
      {
        T colsum = 0.0;
        for ( size_t i = 0; i < Z.row(); i ++ )
          colsum += std::abs( Z( i, j ) );
        if ( colsum > nrm1 ) nrm1 = colsum;
      }

      /** Pivoted LU factorization. */
      xgetrf( n, n, Z.data(), n, ipiv.data() );

      /** Estimate the reciprocal 1-norm condition number. */
      T rcond1 = 0.0;
      Data<T> work( Z.row(), 4 );
      vector<int> iwork( Z.row() );
      xgecon( "1", Z.row(), Z.data(), Z.row(), nrm1,
          &rcond1, work.data(), iwork.data() );
      if ( 1.0 / rcond1 > 1E+6 )
        printf( "Warning! large 1-norm condition number %3.1E, nrm1( Z ) %3.1E\n",
            1.0 / rcond1, nrm1 );
    }; /** end Factorize() */
173
174
175
    /**
176
     *  Kaa = [ P     [ L11      [ I     [ U11 U12
177
     *            I ]   L21  I ]     C ]         I ]
178
     */
179
    void PartialFactorize( Data<T> &A )
180
    {
181
      /** Similar transformation ( Q' * Z * Q ). */
182
      Z = A;
183
      ChangeBasis( Z );
184
185
      /** Create matrix views for Z. */
186
      Zv.Set( false, Z );
187
      Zv.Partition2x2( Ztl, Ztr,
188
                       Zbl, Zbr, s, s, BOTTOMRIGHT );
189
190
      //printf( "Ztl %lux%lu Ztr %lux%lu\n", Ztl.row(), Ztl.col(), Ztr.row(), Ztr.col() ); fflush( stdout );
191
      //printf( "Zbl %lux%lu Zbr %lux%lu\n", Zbl.row(), Zbl.col(), Zbr.row(), Zbr.col() ); fflush( stdout );
192
193
      /** Initialize pivoting rows. */
194
      ipiv.resize( Ztl.row(), 0 );
195
      /** [Ztl, Ztr] = PLU */
196
      xgetrf( Ztl.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
197
      /** Zbl * U^{-1} */
198
      xtrsm( "Right", "Upper", "No transpose", "Non-unit", Zbl.row(), Zbl.col(),
199
          1.0,  Ztl.data(), Ztl.ld(), Zbl.data(), Zbl.ld() );
200
      /** Update Schur complement Zbr. */
201
      xgemm( "No transpose", "No transpose", Zbr.row(), Zbr.col(), Ztl.col(),
202
          -1.0, Zbl.data(), Zbl.ld(),
203
                Ztr.data(), Ztr.ld(),
204
           1.0, Zbr.data(), Zbr.ld() );
205
206
    }; /** end PartialFactorize() */
207
208
209
210
211
    /**
212
     *   two-sided UCVt   one-sided UBt
213
     *
214
     *      | sl   sr        | sl   sr
215
     *   -------------    -------------
216
     *   nl | Ul          nl | Ul
217
     *   nr |      Ur     nr |      Ur
218
     *
219
     *      | sl   sr
220
     *   -------------
221
     *   sl |     Clr
222
     *   sr | Crl
223
     *
224
     *      | nl   nr        | nl   nr
225
     *   -------------    -------------
226
     *   sl | Vlt         sl |      Brt
227
     *   sr |      Vrt    sr | Blt
228
     *
229
     *
230
     **/
231
    void Factorize
232
    (
233
      /** Ul,  nl-by-sl */
234
      Data<T> &Ul,
235
      /** Ur,  nr-by-sr */
236
      Data<T> &Ur,
237
      /** Vl,  nl-by-sr */
238
      Data<T> &Vl,
239
      /** Vr,  nr-by-sr */
240
      Data<T> &Vr
241
    )
242
    {
243
      assert( !isleaf );
244
      //assert( Ul.row() == nl ); assert( Ul.col() == sl );
245
      //assert( Ur.row() == nr ); assert( Ur.col() == sr );
246
      //assert( Vl.row() == nl ); assert( Vl.col() == sl );
247
      //assert( Vr.row() == nr ); assert( Vr.col() == sr );
248
249
      /** even SYMMETRIC this routine uses LU factorization */
250
      if ( issymmetric )
251
      {
252
        assert( Crl.row() == sr ); assert( Crl.col() == sl );
253
      }
254
      else
255
      {
256
        assert( Clr.row() == sl ); assert( Clr.col() == sr );
257
        assert( Crl.row() == sr ); assert( Crl.col() == sl );
258
      }
259
260
      /**
261
       *  clean up and begin with Z = eye( sl + sr ) =     | sl  sr
262
       *                                                ------------
263
       *                                                sl | Zrl Ztr
264
       *                                                sr | Zbl Zbr
265
       **/
266
      Z.resize( 0, 0 );
267
      Z.resize( sl + sr, sl + sr, 0.0 );
268
      for ( size_t i = 0; i < sl + sr; i ++ ) Z[ i * Z.row() + i ] = 1.0;
269
270
271
272
      if ( do_ulv_factorization )
273
      {
274
        /**
275
         *  Z = I + UR * C * VR' = [                 I  URl * Clr * VRr'
276
         *                            URr * Crl * VRl'                 I ]
277
         **/
278
        if ( issymmetric ) /** Cholesky */
279
        {
280
          /** Zbl = URr * Crl * VRl' */
281
          hmlp::Data<T> Zbl = Crl;
282
283
          //printf( "Crl\n" );
284
          //Crl.Print();
285
286
287
          /** trmm */
288
          xtrmm
289
          (
290
            "Right", "Upper", "Transpose", "Non-unit",
291
            Zbl.row(), Zbl.col(),
292
            1.0,  Ul.data(),  Ul.row(),
293
                 Zbl.data(), Zbl.row()
294
          );
295
          //printf( "Ul.row() %lu Zbl.row() %lu Zbl.col() %lu\n",
296
          //    Ul.row(), Zbl.row(), Zbl.col() );
297
298
          /** trmm */
299
          xtrmm
300
          (
301
            "Left", "Upper", "Non-transpose", "Non-unit",
302
            Zbl.row(), Zbl.col(),
303
            1.0,  Ur.data(),  Ur.row(),
304
                 Zbl.data(), Zbl.row()
305
          );
306
          //printf( "Ur.row() %lu Zbl.row() %lu Zbl.col() %lu\n",
307
          //    Ur.row(), Zbl.row(), Zbl.col() );
308
309
          /** Zbl */
310
          for ( size_t j = 0; j < sl; j ++ )
311
            for ( size_t i = 0; i < sr; i ++ )
312
            {
313
              Z( sl + i, j ) = Zbl( i, j );
314
              Z( j, sl + i ) = Zbl( i, j );
315
            }
316
317
          /** LL' = potrf( Z ) */
318
          if ( 1 )
319
          {
320
            xpotrf( "Lower", Z.row(), Z.data(), Z.row() );
321
            //CheckCondition();
322
          }
323
          else
324
          {
325
            /** pivoting row indices */
326
            ipiv.resize( Z.row(), 0 );
327
            xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
328
          }
329
        }
330
        else /** LU */
331
        {
332
          /** pivoting row indices */
333
          ipiv.resize( Z.row(), 0 );
334
        }
335
      }
336
      else /** Sherman-Morrison-Woodbury */
337
      {
338
        /** pivoting row indices */
339
        ipiv.resize( Z.row(), 0 );
340
341
        /**
342
         *  Z = I + CVtU =  [        I  ClrVrtUr
343
         *                    CrlVltUl         I ]
344
         **/
345
        std::vector<T> VltUl( sl * sl, 0.0 );
346
        std::vector<T> VrtUr( sr * sr, 0.0 );
347
348
        /** VltUl */
349
        xgemm( "T", "N", sl, sl, nl,
350
            1.0,    Vl.data(), nl,
351
            Ul.data(), nl,
352
            0.0, VltUl.data(), sl );
353
354
        /** VrtUr */
355
        xgemm( "T", "N", sr, sr, nr,
356
            1.0,    Vr.data(), nr,
357
            Ur.data(), nr,
358
            0.0, VrtUr.data(), sr );
359
360
        /** CrlVltUl */
361
        xgemm( "N", "N", sr, sl, sl,
362
            1.0,   Crl.data(), sr,
363
            VltUl.data(), sl,
364
            0.0,     Z.data() + sl, sl + sr );
365
366
367
        if ( issymmetric )
368
        {
369
          /** Crl'VrtUr */
370
          xgemm( "T", "N", sl, sr, sr,
371
              1.0,   Crl.data(), sr,
372
              VrtUr.data(), sr,
373
              0.0,     Z.data() + ( sl + sr ) * sl, sl + sr );
374
        }
375
        else
376
        {
377
          printf( "bug\n" ); exit( 1 );
378
          /** ClrVrtUr */
379
          xgemm( "N", "N", sl, sr, sr,
380
              1.0,   Clr.data(), sl,
381
              VrtUr.data(), sr,
382
              0.0,     Z.data() + ( sl + sr ) * sl, sl + sr );
383
        }
384
385
        /** compute 1-norm of Z */
386
        T nrm1 = 0.0;
387
        for ( size_t i = 0; i < Z.size(); i ++ )
388
          nrm1 += std::abs( Z[ i ] );
389
390
        /** LU factorization */
391
        xgetrf( Z.row(), Z.col(), Z.data(), Z.row(), ipiv.data() );
392
393
        /** record points of children factors */
394
        this->Ul = &Ul;
395
        this->Ur = &Ur;
396
        this->Vl = &Vl;
397
        this->Vr = &Vr;
398
399
        /** compute 1-norm condition number */
400
        T rcond1 = 0.0;
401
        hmlp::Data<T> work( Z.row(), 4 );
402
        std::vector<int> iwork( Z.row() );
403
        xgecon( "1", Z.row(), Z.data(), Z.row(), nrm1,
404
            &rcond1, work.data(), iwork.data() );
405
        if ( 1.0 / rcond1 > 1E+6 )
406
          printf( "Warning! large 1-norm condition number %3.1E\n",
407
              1.0 / rcond1 ); fflush( stdout );
408
      }
409
410
    }; /** end Factorize() */
411
412
413
    void PartialFactorize(
414
      /** Zl,  nl-by-nl,  Zr,  nr-by-nr */
415
      View<T> &Zl, View<T> &Zr,
416
      /** Ul,  nl-by-sl,  Ur,  nr-by-sr */
417
      Data<T> &Ul, Data<T> &Ur,
418
      /** Vl,  nl-by-sr,  Vr,  nr-by-sr */
419
      Data<T> &Vl, Data<T> &Vr )
420
    {
421
      Z.resize( 0, 0 );
422
      Z.resize( sl + sr, sl + sr, 0.0 );
423
424
      /** Create matrix views for Z. */
425
      Zv.Set( false, Z );
426
      Zv.Partition2x2( Ztl, Ztr,
427
                       Zbl, Zbr, sl, sl, TOPLEFT );
428
429
      //printf( "Ztl %lux%lu Ztr %lux%lu\n", Ztl.row(), Ztl.col(), Ztr.row(), Ztr.col() ); fflush( stdout );
430
      //printf( "Zbl %lux%lu Zbr %lux%lu\n", Zbl.row(), Zbl.col(), Zbr.row(), Zbr.col() ); fflush( stdout );
431
432
433
      Zbl.CopyValuesFrom( Crl );
434
      /** trmm */
435
      xtrmm( "Right", "Upper",     "Transpose", "Non-unit", Zbl.row(), Zbl.col(),
436
        1.0,  Ul.data(),  Ul.row(), Zbl.data(), Zbl.ld() );
437
      /** trmm */
438
      xtrmm(  "Left", "Upper", "Non-transpose", "Non-unit", Zbl.row(), Zbl.col(),
439
        1.0,  Ur.data(),  Ur.row(), Zbl.data(), Zbl.ld() );
440
441
      Ztl.CopyValuesFrom( Zl );
442
      Zbr.CopyValuesFrom( Zr );
443
444
      for ( size_t j = 0; j < sl; j ++ )
445
        for ( size_t i = 0; i < sr; i ++ )
446
          Ztr( j, i ) = Zbl( i, j );
447
448
      PartialFactorize( Z );
449
450
    }; /** end PartialFactorize() */
451
452
453
454
455
    /** */
456
    void Multiply( View<T> &bl, View<T> &br )
457
    {
458
      assert( !isleaf && bl.col() == br.col() );
459
460
      size_t nrhs = bl.col();
461
462
      std::vector<T> ta( ( sl + sr ) * nrhs );
463
      std::vector<T> tl(      sl * nrhs );
464
      std::vector<T> tr(      sr * nrhs );
465
466
      /** Vl' * bl */
467
      xgemm( "T", "N", sl, nrhs, nl,
468
          1.0, Vl->data(), nl,
469
                bl.data(), bl.ld(),
470
          0.0,  tl.data(), sl );
471
      /** Vr' * br */
472
      xgemm( "T", "N", sr, nrhs, nr,
473
          1.0, Vr->data(), nr,
474
                br.data(), br.ld(),
475
          0.0,  tr.data(), sr );
476
477
      /** Crl * Vl' * bl */
478
      xgemm( "N", "N", sr, nrhs, sl,
479
          1.0, Crl.data(), sr,
480
                tl.data(), sl,
481
          0.0,  ta.data() + sl, sl + sr );
482
483
      if ( issymmetric )
484
      {
485
        /** Crl' * Vr' * br */
486
        xgemm( "T", "N", sl, nrhs, sr,
487
            1.0, Crl.data(), sr,
488
                  tr.data(), sr,
489
            0.0,  ta.data(), sl + sr );
490
      }
491
      else
492
      {
493
        printf( "bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
494
        /** Clr * Vr' * br */
495
        xgemm( "N", "N", sl, nrhs, sr,
496
            1.0, Clr.data(), sl,
497
                  tr.data(), sr,
498
            0.0,  ta.data(), sl + sr );
499
      }
500
501
      /** bl += Ul * xl */
502
      xgemm( "N", "N", nl, nrhs, sl,
503
        -1.0, Ul->data(), nl,
504
               ta.data(), sl + sr,
505
         1.0,  bl.data(), bl.ld() );
506
507
      /** br += Ur * xr */
508
      xgemm( "N", "N", nr, nrhs, sr,
509
         -1.0, Ur->data(), nr,
510
                ta.data() + sl, sl + sr,
511
          1.0,  br.data(), br.ld() );
512
    };
513
514
    /**
515
     *  @brief Solver for leaf nodes
516
     */
517
    void Solve( View<T> &rhs )
518
    {
519
      /** assure this is a leaf node */
520
      assert( isleaf );
521
      assert( !do_ulv_factorization );
522
      assert( rhs.data() && Z.data() );
523
      assert( ipiv.data() );
524
525
      //rhs.Print();
526
527
      size_t nrhs = rhs.col();
528
529
      /** LU solver */
530
      xgetrs( "Non-transpose", rhs.row(), nrhs,
531
          Z.data(), Z.row(), ipiv.data(),
532
          rhs.data(), rhs.ld() );
533
534
    }; /** end Solve() */
535
536
537
538
    /**
539
     *  @brief b - U * inv( Z ) * C * V' * b
540
     */
541
    void Solve( View<T> &bl, View<T> &br )
542
    {
543
      size_t nrhs = bl.col();
544
545
      //bl.Print();
546
      //br.Print();
547
548
      /** assertion */
549
      assert( !do_ulv_factorization );
550
      assert( bl.col() == br.col() );
551
      assert( bl.row() == nl );
552
      assert( br.row() == nr );
553
      assert( Ul && Ur && Vl && Vr );
554
555
      /** buffer */
556
//      hmlp::Data<T> ta( sl + sr, nrhs );
557
//      hmlp::Data<T> tl(      sl, nrhs );
558
//      hmlp::Data<T> tr(      sr, nrhs );
559
560
      vector<T> ta( ( sl + sr ) * nrhs );
561
      vector<T> tl(      sl * nrhs );
562
      vector<T> tr(      sr * nrhs );
563
564
565
      ///** views of buffer */
566
      //hmlp::View<T> xa( ta ), xl, xr;
567
568
      ///** xa = [ xl; xr; ] */
569
      //xa.Partition2x1
570
      //(
571
      //  xl,
572
      //  xr, sl
573
      //);
574
575
576
        /** Vl' * bl */
577
        xgemm( "T", "N", sl, nrhs, nl,
578
            1.0, Vl->data(), nl,
579
                  bl.data(), bl.ld(),
580
            0.0,  tl.data(), sl );
581
        /** Vr' * br */
582
        xgemm( "T", "N", sr, nrhs, nr,
583
            1.0, Vr->data(), nr,
584
                  br.data(), br.ld(),
585
            0.0,  tr.data(), sr );
586
587
588
        /** Crl * Vl' * bl */
589
        xgemm( "N", "N", sr, nrhs, sl,
590
            1.0, Crl.data(), sr,
591
            tl.data(), sl,
592
            0.0,  ta.data() + sl, sl + sr );
593
594
        if ( issymmetric )
595
        {
596
          /** Crl' * Vr' * br */
597
          xgemm( "T", "N", sl, nrhs, sr,
598
              1.0, Crl.data(), sr,
599
              tr.data(), sr,
600
              0.0,  ta.data(), sl + sr );
601
        }
602
        else
603
        {
604
          printf( "bug here !!!!!\n" ); fflush( stdout ); exit( 1 );
605
          /** Clr * Vr' * br */
606
          xgemm( "N", "N", sl, nrhs, sr,
607
              1.0, Clr.data(), sl,
608
              tr.data(), sr,
609
              0.0,  ta.data(), sl + sr );
610
        }
611
612
        /** inv( Z ) * x */
613
        xgetrs( "N", sl + sr, nrhs,
614
            Z.data(), Z.row(), ipiv.data(),
615
            ta.data(), sl + sr );
616
617
      /** bl -= Ul * xl */
618
      xgemm( "N", "N", nl, nrhs, sl,
619
          -1.0, Ul->data(), nl,
620
          ta.data(), sl + sr,
621
          1.0,  bl.data(), bl.ld() );
622
623
      /** br -= Ur * xr */
624
      xgemm( "N", "N", nr, nrhs, sr,
625
          -1.0, Ur->data(), nr,
626
          ta.data() + sl, sl + sr,
627
          1.0,  br.data(), br.ld() );
628
629
    }; /** end Solve() */
630
631
632
633
634
    void Telescope
635
    (
636
      bool DO_INVERSE,
637
      /** n-by-s */
638
      Data<T> &Pa,
639
      /** s-by-(sl+sr) */
640
      Data<T> &Palr
641
    )
642
    {
643
      assert( isleaf );
644
      /** Initialize Pa */
645
      Pa.resize( n, s, 0.0 );
646
647
      /** create view and subviews for Pa */
648
      //hmlp::View<T> Xa;
649
650
      //Xa.Set( Pa );
651
652
      assert( Palr.row() == s ); assert( Palr.col() == n );
653
654
      /** Pa = Palr' */
655
      for ( size_t j = 0; j < Pa.col(); j ++ )
656
        for ( size_t i = 0; i < Pa.row(); i ++ )
657
          Pa( i, j ) = Palr( j, i );
658
659
      if ( DO_INVERSE )
660
      {
661
        if ( do_ulv_factorization )
662
        {
663
          xtrsm( "Left", "Lower", "No transpose", "Non-unit",
664
              Pa.row(), Pa.col(),
665
              1.0,  Z.data(),  Z.row(), Pa.data(), Pa.row() );
666
        }
667
        else
668
        {
669
          assert( ipiv.size() );
670
          /** LU solver */
671
          xgetrs( "Non-transpose",
672
              n, s, Z.data(), n, ipiv.data(), Pa.data(), n );
673
        }
674
      }
675
676
677
      //printf( "call solve from telescope\n" ); fflush( stdout );
678
      //if ( DO_INVERSE ) Solve<true>( Xa );
679
      //printf( "call solve from telescope (exist)\n" ); fflush( stdout );
680
681
    }; /** end Telescope() */
682
683
684
    /** RIGHT: V = [ P(:, 0:st-1) * Vl , P(:,st:st+sb-1) * Vr ]
685
     *  LEFT:  U = [ Ul * P(:, 0:st-1)'; Ur * P(:,st:st+sb-1) ] */
686
    void Telescope
687
    (
688
      bool DO_INVERSE,
689
      /** n-by-s */
690
      Data<T> &Pa,
691
      /** s-by-(sl+sr) */
692
      Data<T> &Palr,
693
      /** nl-by-sl */
694
      Data<T> &Pl,
695
      /** nr-by-sr */
696
      Data<T> &Pr
697
    )
698
    {
699
      assert( !isleaf );
700
      assert( n == nl + nr );
701
      assert( Pl.col() == sl );
702
      assert( Pr.col() == sr );
703
      assert( Palr.row() == s  ); assert( Palr.col() == ( sl + sr ) );
704
705
      /** Initialize Pa */
706
      Pa.resize( 0, 0 );
707
708
      /** create view and subviews for Pa */
709
      //hmlp::View<T> Xa;
710
711
      //Xa.Set( Pa );
712
      //assert( Xa.row() == Pa.row() );
713
      //assert( Xa.col() == Pa.col() );
714
715
      if ( do_ulv_factorization )
716
      {
717
        Pa.resize( sl + sr, s, 0.0 );
718
719
        /** Pa = Palr' */
720
        for ( size_t j = 0; j < Pa.col(); j ++ )
721
          for ( size_t i = 0; i < Pa.row(); i ++ )
722
            Pa[ j * Pa.row() + i ] = Palr[ i * Palr.row() + j ];
723
724
        /** Pa( 0:sl-1, : ) = Pl * Palr( :, 0:sl-1 )' */
725
        xtrmm( "Left", "Upper", "No Transpose", "Non-unit", sl, s,
726
           1.0,   Pl.data(), Pl.row(),
727
                  Pa.data(), Pa.row() );
728
        //  printf( "Pl.row() %lu Pa.row() %lu Pa.col() %lu\n",
729
        //      Pl.row(), Pa.row(), Pa.col() );
730
731
        /** Pa( sl:sl+sr-1, : ) = Pr * Palr( :, sl:sl+sr-1 )' */
732
        xtrmm( "Left", "Upper", "No Transpose", "Non-unit", sr, s,
733
           1.0,   Pr.data()     , Pr.row(),
734
                  Pa.data() + sl, Pa.row() );
735
        //  printf( "Pr.row() %lu Pa.row() %lu Pa.col() %lu\n",
736
        //      Pr.row(), Pa.row(), Pa.col() );
737
738
        /** inv( L ) * Pa */
739
        if ( DO_INVERSE )
740
        {
741
          if ( 1 )
742
          {
743
            xtrsm( "Left", "Lower", "No transpose", "Non-unit",
744
                Pa.row(), Pa.col(),
745
                1.0,  Z.data(),  Z.row(),
746
                     Pa.data(), Pa.row() );
747
          }
748
          else
749
          {
750
            xlaswp( Pa.col(), Pa.data(), Pa.row(),
751
                1, Pa.row(), ipiv.data(), 1 );
752
            xtrsm( "Left", "Lower", "No transpose", "Unit", Pa.row(), Pa.col(),
753
                1.0,  Z.data(),  Z.row(),
754
                     Pa.data(), Pa.row() );
755
          }
756
          //printf( "Z.row() %lu Z.col() %lu\n", Z.row(), Z.col() );
757
        }
758
759
      }
760
      else /** Shernman-Morrison-Woodbury */
761
      {
762
        Pa.resize( nl + nr, s, 0.0 );
763
764
        ///** */
765
        //hmlp::View<T> Xl, Xr;
766
767
        ///** Xa = [ Xl; Xr; ] */
768
        //Xa.Partition2x1
769
        //(
770
        //  Xl,
771
        //  Xr, nl
772
        //);
773
774
        //assert( Xl.row() == nl );
775
        //assert( Xr.row() == nr );
776
        //assert( Xl.col() == s );
777
        //assert( Xr.col() == s );
778
779
780
        /** Pa( 0:nl-1, : ) = Pl * Palr( :, 0:sl-1 )' */
781
        xgemm( "N", "T", nl, s, sl,
782
            1.0,   Pl.data(), nl,
783
                 Palr.data(), s,
784
            0.0,   Pa.data(), n );
785
        /** Pa( nl:n-1, : ) = Pr * Palr( :, sl:sl+sr-1 )' */
786
        xgemm( "N", "T", nr, s, sr,
787
            1.0,   Pr.data(), nr,
788
                 Palr.data() + s * sl, s,
789
            0.0,   Pa.data() + nl, n );
790
791
792
793
        //if ( DO_INVERSE ) Solve<true>( Xl, Xr );
794
        //printf( "end inner solve from telescope\n" ); fflush( stdout );
795
796
        if ( DO_INVERSE )
797
        {
798
            Data<T> x( sl + sr, s );
799
            Data<T> xl( sl, s );
800
            Data<T> xr( sr, s );
801
802
            /** xl = Vlt * Pa( 0:nl-1, : ) */
803
            xgemm( "T", "N", sl, s, nl,
804
                1.0, Vl->data(), nl,
805
                Pa.data(), n,
806
                0.0,  xl.data(), sl );
807
            /** xr = Vrt * Pa( nl:n-1, : ) */
808
            xgemm( "T", "N", sr, s, nr,
809
                1.0, Vr->data(), nr,
810
                Pa.data() + nl, n,
811
                0.0,  xr.data(), sr );
812
813
            /** b = [ Crl' * xr;
814
             *        Crl  * xl; ] */
815
            xgemm( "T", "N", sl, s, sr,
816
                1.0, Crl.data(), sr,
817
                xr.data(), sr,
818
                0.0,   x.data(), sl + sr );
819
            xgemm( "N", "N", sr, s, sl,
820
                1.0, Crl.data(), sr,
821
                xl.data(), sl,
822
                0.0,   x.data() + sl, sl + sr );
823
824
            /** b = inv( Z ) * b */
825
            xgetrs( "N", x.row(), x.col(),
826
                Z.data(), Z.row(), ipiv.data(),
827
                x.data(), x.row() );
828
829
            /** Pa( 0:nl-1, : ) -= Ul * b( 0:sl-1, : ) */
830
            xgemm( "N", "N", nl, s, sl,
831
                -1.0, Ul->data(), nl,
832
                x.data(), sl + sr,
833
                1.0,  Pa.data(), n );
834
            /** Pa( nl:n-1, : ) -= Ur * b( sl:sl+sr-1, : ) */
835
            xgemm( "N", "N", nr, s, sr,
836
                -1.0, Ur->data(), nr,
837
                x.data() + sl, sl + sr,
838
                1.0,  Pa.data() + nl, n );
839
        } /** end if ( DO_INVERSE ) */
840
      } /** end if ( do_ulv_factorization )*/
841
    };
842
843
    /** */
844
    void Orthogonalization()
845
    {
846
      /** Initialize householder reflectors "tau". */
847
      tau.resize( std::min( U.row(), U.col() ) );
848
      /** Initialize work space for xgeqrf. */
849
      Data<T> work( U.col() * 512, 1 );
850
      /** QR factorization without column pivoting. */
851
      xgeqrf( U.row(), U.col(), U.data(), U.row(),
852
          tau.data(), work.data(), work.size() );
853
      /** Copy U to Q to generate the full orthonormal basis. */
854
      Q = U;
855
      /** Increase the rank of Q to full rank. */
856
      Q.resize( U.row(), U.row() );
857
      /** Generate the full orthonormal basis Q. */
858
      xorgqr( Q.row(), Q.col(), U.col(), Q.data(), Q.row(), tau.data(),
859
          work.data(), work.size() );
860
861
862
863
      /** Create views Qv = [Q1, Q2] for Q. */
864
      Qv.Set( false, Q );
865
      Qv.Partition1x2( Q1, Q2, tau.size(), LEFT );
866
      /** Sanity check for Q1'Q1 and Q2'Q2 and Q1'Q2. */
867
      Data<T> C = Q;
868
      Data<T> D = Q;
869
870
      xgemm( "Transpose", "No Transpose", C.row(), C.col(), Q.row(),
871
          1.0, Q.data(), Q.row(),
872
               Q.data(), Q.row(),
873
          0.0, C.data(), C.row() );
874
875
      xgemm( "No Transpose", "Transpose", D.row(), D.col(), Q.row(),
876
          1.0, Q.data(), Q.row(),
877
               Q.data(), Q.row(),
878
          0.0, D.data(), D.row() );
879
880
      for ( size_t j = 0; j < Q.col(); j ++ )
881
      {
882
        for ( size_t i = 0; i < Q.row(); i ++ )
883
        {
884
          if ( i == j ) assert( std::fabs( C( i, j ) - 1 ) < 1E-5 );
885
          else          assert( std::fabs( C( i, j ) - 0 ) < 1E-5 );
886
        }
887
      }
888
      for ( size_t j = 0; j < Q.col(); j ++ )
889
      {
890
        for ( size_t i = 0; i < Q.row(); i ++ )
891
        {
892
          if ( i == j ) assert( std::fabs( D( i, j ) - 1 ) < 1E-5 );
893
          else          assert( std::fabs( D( i, j ) - 0 ) < 1E-5 );
894
        }
895
      }
896
897
    };
898
899
900
901
902
    /** [Q2 Q1]' * B or B * [Q2 Q1] */
903
    void ChangeBasis( SideType side, Data<T> &B )
904
    {
905
      /** Early return if Q does not exist. */
906
      if ( !Q.size() ) return;
907
908
      /** Create a deep copy of B. */
909
      Data<T> A = B;
910
911
      /** Create matrix views for A and B. */
912
      View<T> Av( false, A );
913
      View<T> Bv( false, B );
914
      View<T> Bl, Br, Bt, Bb;
915
916
      /** Enumerate case "LEFT", "RIGHT", and execptions. */
917
      switch ( side )
918
      {
919
        case LEFT:
920
        {
921
          /** Partition Bv = [ Bt; Bb ]. */
922
          Bv.Partition2x1( Bt,
923
                           Bb,     Q2.col(),  TOP );
924
          //printf("Bt %lux%lu Bb %lux%lu\n", Bt.row(), Bt.col(),
925
          //                                  Bb.row(), Bb.col() ); fflush( stdout );
926
927
          /** Bt = Q2' * A */
928
          xgemm( "Transpose", "No Transpose", Bt.row(), Bt.col(), Q2.row(),
929
              1.0, Q2.data(), Q2.ld(),
930
                   Av.data(), Av.ld(),
931
              0.0, Bt.data(), Bt.ld() );
932
          /** Bb = Q1' * A */
933
          xgemm( "Transpose", "No Transpose", Bb.row(), Bb.col(), Q1.row(),
934
              1.0, Q1.data(), Q1.ld(),
935
                   Av.data(), Av.ld(),
936
              0.0, Bb.data(), Bb.ld() );
937
          break;
938
        }
939
        case RIGHT:
940
        {
941
          /** Partition Bv = [ Bl, Br ]. */
942
          Bv.Partition1x2( Bl, Br, Q2.col(), LEFT );
943
944
          //printf("Bl %lux%lu Br %lux%lu\n", Bl.row(), Bl.col(),
945
          //                                  Br.row(), Br.col() ); fflush( stdout );
946
947
          /** Bl = A * Q2 */
948
          xgemm( "No Transpose", "No Transpose", Bl.row(), Bl.col(), Q2.row(),
949
              1.0, Av.data(), Av.ld(),
950
                   Q2.data(), Q2.ld(),
951
              0.0, Bl.data(), Bl.ld() );
952
          /** Br = A * Q1 */
953
          xgemm( "No Transpose", "No Transpose", Br.row(), Br.col(), Q1.row(),
954
              1.0, Av.data(), Av.ld(),
955
                   Q1.data(), Q1.ld(),
956
              0.0, Br.data(), Br.ld() );
957
          break;
958
        }
959
        default:
960
        {
961
          /** Do nothing and throw exception. */
962
          throw "Value of (SideType) side is not recognized.";
963
        }
964
      }
965
      //printf( "end ChangeBasis\n" ); fflush( stdout );
966
    }; /** changeBasis() */
967
968
969
    /** [Q2 Q1]' * A * [Q2 Q1] */
970
    void ChangeBasis( Data<T> &A )
971
    {
972
      ChangeBasis(  LEFT, A );
973
      ChangeBasis( RIGHT, A );
974
    }; /** changeBasis() */
975
976
977
978
    void ULVForward()
979
    {
980
      /** For internal nodes, B has been initialized by children. */
981
      if ( isleaf ) B = bview.toData();
982
      /** B = Q' * B */
983
      ChangeBasis( LEFT, B );
984
      /** P * Bf */
985
      xlaswp( Bf.col(), Bf.data(), Bf.ld(), 1, Bf.row(), ipiv.data(), 1 );
986
      /** Lff^{-1} * P * Bf, where Lff is the lower-triangular part of Ztl. */
987
      xtrsm( "Left", "Lower", "No transpose", "Unit", Bf.row(), Bf.col(),
988
          1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() );
989
      /** Bc -= Lcf * Bf, where Lcf is Zbl. */
990
      xgemm( "No Transpose", "No Transpose", Bc.row(), Bc.col(), Bf.row(),
991
          -1.0, Zbl.data(), Zbl.ld(), Bf.data(), Bf.ld(), 1.0, Bc.data(), Bc.ld() );
992
      //printf( "Bc %lux%lu Bp %lux%lu\n", Bc.row(), Bc.col(), Bp.row(), Bp.col() ); fflush( stdout );
993
      /** Copy Bc to Bp (subview of parent's B). */
994
      Bp.CopyValuesFrom( Bc );
995
    }; /** end ULVForward() */
996
997
998
    void ULVBackward()
999
    {
1000
      /** Copy Bp (subview of parent's B) to Bc. */
1001
      Bc.CopyValuesFrom( Bp );
1002
      /** Bf -= Ufc * Bc, where Ufc is Ztr. */
1003
      xgemm( "No Transpose", "No Transpose", Bf.row(), Bf.col(), Bc.row(),
1004
          -1.0, Ztr.data(), Ztr.ld(), Bc.data(), Bc.ld(), 1.0, Bf.data(), Bf.ld() );
1005
      /** Lff^{-1} * P * Bf, where Lff is the lower-triangular part of Ztl. */
1006
      xtrsm( "Left", "Upper", "No transpose", "Non-unit", Bf.row(), Bf.col(),
1007
          1.0, Ztl.data(), Ztl.ld(), Bf.data(), Bf.ld() );
1008
      if ( Q.size() )
1009
      {
1010
        /** Create a temporary buffer for projection Q2 * Bf + Q1 * Bc. */
1011
        Data<T> A = B;
1012
        xgemm( "No Transpose", "No Transpose", A.row(), A.col(), Bf.row(),
1013
            1.0, Q2.data(), Q2.ld(), Bf.data(), Bf.ld(), 0.0, A.data(), A.row() );
1014
        xgemm( "No Transpose", "No Transpose", A.row(), A.col(), Bc.row(),
1015
            1.0, Q1.data(), Q1.ld(), Bc.data(), Bc.ld(), 1.0, A.data(), A.row() );
1016
        /** Copy A back to B. */
1017
        if ( isleaf ) bview.CopyValuesFrom( A );
1018
        else Bv.CopyValuesFrom( A );
1019
      }
1020
    }; /** end ULVBackward() */
1021
1022
1023
1024
1025
1026
1027
    bool isleaf = false;
1028
1029
    bool isroot = false;
1030
1031
    size_t n = 0;
1032
1033
    size_t nl = 0;
1034
1035
    size_t nr = 0;
1036
1037
    size_t s = 0;
1038
1039
    size_t sl = 0;
1040
1041
    size_t sr = 0;
1042
1043
1044
    /** Reduced system Z = [ I  VU   if ( HODLR || p-HSS )
1045
     *                       VU  I ] */
1046
    Data<T> Z;
1047
    View<T> Zv;
1048
    View<T> Ztl, Ztr, Zbl, Zbr;
1049
1050
    /** Partial pivoting order (used in GETRF). */
1051
    vector<int> ipiv;
1052
1053
    /** n-by-s (SMW) or (sl+sr)-by-s (ULV) */
1054
    Data<T> U, V;
1055
1056
    /** sr-by-sl and sl-by-sr, skeleton row and column basis. */
1057
    Data<T> Crl, Clr;
1058
1059
    /** A corresponding view of the right hand side of this node. */
1060
    View<T> bview;
1061
1062
    /** Pointers to children's factors */
1063
    Data<T> *Ul = NULL;
1064
    Data<T> *Ur = NULL;
1065
    Data<T> *Vl = NULL;
1066
    Data<T> *Vr = NULL;
1067
1068
    /** Q, (sl+sr)-by-s (ULV) */
1069
    Data<T> Q;
1070
    View<T> Qv, Q1, Q2;
1071
1072
    /** tau, sl+sr (used in xgeqrf( U ) of ULV) */
1073
    vector<T> tau;
1074
1075
    /** Temporary buffer for the solve. */
1076
    Data<T> B;
1077
    View<T> Bv, Bp, Bsibling, Bf, Bc;
1078
1079
  private: /** this class will be public inherit by gofmm::Data<T> */
1080
1081
    bool issymmetric = true;
1082
1083
    bool do_ulv_factorization = false;
1084
1085
}; /** end class Factor */
1086
1087
1088
/**
 *  @brief Collect the factorization mode and the sizes of a tree node
 *         (and of its two children, if it is not a leaf) and pass them
 *         to node->data.SetupFactor().
 *
 *  The node counts as root when node->l evaluates to false. For leaf
 *  nodes the children sizes nl, nr, sl, sr are all 0.
 */
template<typename NODE, typename T>
void SetupFactor( NODE *node )
{
#ifdef DEBUG_IGOFMM
  printf( "begin SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
#endif

  /** Factorization mode comes from the shared setup object. */
  bool issymmetric = node->setup->IsSymmetric();
  bool do_ulv_factorization = node->setup->do_ulv_factorization;

  /** Sizes of this node. */
  size_t n = node->n;
  size_t s = node->data.skels.size();

  /** Sizes of the children; they stay zero for leaf nodes. */
  size_t nl = 0, nr = 0;
  size_t sl = 0, sr = 0;
  if ( !node->isleaf )
  {
    nl = node->lchild->n;
    nr = node->rchild->n;
    sl = node->lchild->data.skels.size();
    sr = node->rchild->data.skels.size();
  }

  node->data.SetupFactor
  (
    issymmetric, do_ulv_factorization,
    node->isleaf, !node->l,
    n, nl, nr,
    s, sl, sr
  );

#ifdef DEBUG_IGOFMM
  printf( "end SetupFactor %lu\n", node->treelist_id ); fflush( stdout );
#endif
}; /** end void SetupFactor() */
1128
1129
1130
/**
1131
 *  @brief
1132
 */
1133
template<typename NODE, typename T>
1134
class SetupFactorTask : public Task
1135
{
1136
  public:
1137
1138
    NODE *arg = NULL;
1139
1140
    void Set( NODE *user_arg )
1141
    {
1142
      arg = user_arg;
1143
      name = string( "sf" );
1144
      label = to_string( arg->treelist_id );
1145
      cost = 1.0;
1146
    };
1147
1148
    void GetEventRecord()
1149
    {
1150
      double flops = 0.0, mops = 0.0;
1151
      event.Set( label + name, flops, mops );
1152
    };
1153
1154
    void DependencyAnalysis()
1155
    {
1156
      arg->DependencyAnalysis( W, this );
1157
      this->TryEnqueue();
1158
    };
1159
1160
    void Execute( Worker* user_worker )
1161
    {
1162
      SetupFactor<NODE, T>( arg );
1163
    };
1164
1165
}; /** end class SetupFactorTask */
1166
1167
1168
1169
1170
template<typename NODE>
1171
void SolverTreeView( NODE *node )
1172
{
1173
  auto &data   = node->data;
1174
  auto *setup  = node->setup;
1175
  auto &input  = *(setup->input);
1176
  auto &output = *(setup->output);
1177
  /** Allocate working buffer for ULV solve. */
1178
  if ( node->isleaf ) data.B.resize( data.n, input.col() );
1179
  else data.B.resize( data.sl + data.sr, input.col() );
1180
1181
  /** Partition B = [ Bf; Bc ] with matrix view. */
1182
  data.Bv.Set( data.B );
1183
  data.Bv.Partition2x1( data.Bf,
1184
                        data.Bc,  data.s, BOTTOM );
1185
1186
  /** Create contigious matrix view for output at root level. */
1187
  if ( !node->parent ) data.bview.Set( output );
1188
1189
  /** Hierarchical tree view. */
1190
  if ( !node->isleaf )
1191
  {
1192
    auto &ldata = node->lchild->data;
1193
    auto &rdata = node->rchild->data;
1194
    /** Partition b = [ bl; br; ] with matrix view. */
1195
    data.bview.Partition2x1( ldata.bview,
1196
                             rdata.bview, data.nl, TOP );
1197
    data.Bv.Partition2x1( ldata.Bp,
1198
                          rdata.Bp, data.sl, TOP );
1199
  }
1200
}; /** end SolverTreeView() */
1201
1202
1203
1204
1205
/** @brief Creates an hierarchical tree view for a matrix. */
1206
template<typename NODE>
1207
class SolverTreeViewTask : public Task
1208
{
1209
  public:
1210
1211
    NODE *arg = NULL;
1212
1213
    void Set( NODE *user_arg )
1214
    {
1215
      arg = user_arg;
1216
      name = string( "TreeView" );
1217
      label = to_string( arg->treelist_id );
1218
      cost = 1.0;
1219
    };
1220
1221
    void GetEventRecord()
1222
    {
1223
      double flops = 0.0, mops = 0.0;
1224
      event.Set( label + name, flops, mops );
1225
    };
1226
1227
    /** Preorder dependencies (with a single source node) */
1228
    void DependencyAnalysis() { arg->DependOnParent( this ); };
1229
1230
    void Execute( Worker* user_worker ) { SolverTreeView( arg ); };
1231
1232
}; /** end class TreeViewTask */
1233
1234
1235
1236
/**
1237
 *  @brief doward traversal to create matrix views, at the leaf
1238
 *         level execute explicit permutation.
1239
 */
1240
template<bool FORWARD, typename NODE>
1241
class MatrixPermuteTask : public hmlp::Task
1242
{
1243
  public:
1244
1245
    NODE *arg;
1246
1247
    void Set( NODE *user_arg )
1248
    {
1249
      name = std::string( "MatrixPermutation" );
1250
      arg = user_arg;
1251
      cost = 1.0;
1252
    };
1253
1254
    void GetEventRecord()
1255
    {
1256
      double flops = 0.0, mops = 0.0;
1257
      event.Set( label + name, flops, mops );
1258
    };
1259
1260
    /** depends on previous task */
1261
    void DependencyAnalysis()
1262
    {
1263
      if ( FORWARD )
1264
      {
1265
        arg->DependencyAnalysis( RW, this );
1266
      }
1267
      else
1268
      {
1269
        this->Enqueue();
1270
      }
1271
    };
1272
1273
    void Execute( Worker* user_worker )
1274
    {
1275
      //printf( "PermuteMatrix %lu\n", arg->treelist_id );
1276
      auto *node   = arg;
1277
      auto &gids   = node->gids;
1278
      auto &input  = *(node->setup->input);
1279
      auto &output = *(node->setup->output);
1280
      auto &A      = node->data.bview;
1281
1282
      assert( A.row() == gids.size() );
1283
      assert( A.col() == input.col() );
1284
1285
      //for ( size_t i = 0; i < gids.size(); i ++ )
1286
      //  printf( "%lu ", gids[ i ] );
1287
      //printf( "\n" );
1288
1289
      /** perform permutation and output */
1290
      for ( size_t j = 0; j < input.col(); j ++ )
1291
        for ( size_t i = 0; i < gids.size(); i ++ )
1292
          /** foward  permutation */
1293
          if ( FORWARD ) A( i, j ) = input( gids[ i ], j );
1294
          /** inverse permutation */
1295
          else           input( gids[ i ], j ) = A( i, j );
1296
1297
      //for ( size_t j = 0; j < 1; j ++ )
1298
      //  for ( size_t i = 0; i < gids.size(); i ++ )
1299
      //    printf( "%E ", A( i, j ) );
1300
      //printf( "\n" );
1301
1302
      //printf( "end PermuteMatrix %lu\n", arg->treelist_id );
1303
    };
1304
1305
}; /** end class MatrixPermuteTask */
1306
1307
1308
1309
/**
 *  @brief Apply step for one tree node.
 *
 *  On a leaf the regularized diagonal block K( gids, gids ) + lambda * I
 *  is formed. On an inner node the member template node->data.Apply<true>()
 *  is invoked on the two children's right-hand-side views.
 *
 *  @param node  the tree node to process.
 */
template<typename NODE, typename T>
void Apply( NODE *node )
{
  auto &data = node->data;
  auto &setup = node->setup;
  auto &K = *setup->K;

  if ( node->isleaf )
  {
    auto lambda = setup->lambda;
    auto &amap = node->gids;
    /** Evaluate the diagonal block. */
    auto Kaa = K( amap, amap );
    /** Apply the regularization to the diagonal entries. */
    for ( size_t i = 0; i < Kaa.row(); i ++ )
      Kaa[ i * Kaa.row() + i ] += lambda;
    /** NOTE(review): Kaa is not used after this point, so the leaf case
     *  has no visible effect here -- confirm whether a multiplication
     *  with the node's right-hand side is missing. */
  }
  else
  {
    auto &bl = node->lchild->data.bview;
    auto &br = node->rchild->data.bview;
    /**
     *  The type of data depends on NODE, so the template keyword is
     *  required; without it "<" is parsed as a less-than operator.
     */
    data.template Apply<true>( bl, br );
  }
}; /** end Apply() */
1336
1337
1338
1339
//template<typename NODE, typename T>
1340
//void ULVForwardSolve( NODE *node ) { node->data.ULVForward(); };
1341
1342
1343
1344
template<typename NODE, typename T>
1345
class ULVForwardSolveTask : public Task
1346
{
1347
  public:
1348
1349
    NODE *arg = NULL;
1350
1351
    void Set( NODE *user_arg )
1352
    {
1353
      arg = user_arg;
1354
      name = string( "ulvforward" );
1355
      label = to_string( arg->treelist_id );
1356
      cost = 1.0;
1357
    };
1358
1359
    //void DependencyAnalysis()
1360
    //{
1361
    //  arg->DependencyAnalysis( RW, this );
1362
    //  /** depend on two children */
1363
    //  if ( !arg->isleaf )
1364
    //  {
1365
    //    arg->lchild->DependencyAnalysis( R, this );
1366
    //    arg->rchild->DependencyAnalysis( R, this );
1367
    //  }
1368
    //  /** dispatch the task if there is no dependency */
1369
    //  this->TryEnqueue();
1370
    //};
1371
1372
    void DependencyAnalysis() { arg->DependOnChildren( this ); };
1373
1374
1375
    void Execute( Worker* user_worker ) { arg->data.ULVForward(); };
1376
1377
}; /** end class ULVForwardSolveTask */
1378
1379
1380
1381
1382
template<typename NODE, typename T>
1383
class ULVBackwardSolveTask : public Task
1384
{
1385
  public:
1386
1387
    NODE *arg;
1388
1389
    void Set( NODE *user_arg )
1390
    {
1391
      arg = user_arg;
1392
      name = string( "ulvbackward" );
1393
      label = std::to_string( arg->treelist_id );
1394
      cost = 1.0;
1395
1396
      //printf( "Set treelist_id %lu\n", arg->treelist_id ); fflush( stdout );
1397
    };
1398
1399
    //void DependencyAnalysis()
1400
    //{
1401
    //  /** depend on parent */
1402
    //  if ( arg->parent )
1403
    //    arg->parent->DependencyAnalysis( hmlp::ReadWriteType::R, this );
1404
    //  arg->DependencyAnalysis( hmlp::ReadWriteType::RW, this );
1405
    //  /** dispatch the task if there is no dependency */
1406
    //  this->TryEnqueue();
1407
    //};
1408
1409
1410
    void DependencyAnalysis() { arg->DependOnParent( this ); };
1411
1412
    void Execute( Worker* user_worker ) { arg->data.ULVBackward(); };
1413
1414
}; /** end class ULVBackwardSolveTask */
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
/**
 *  @brief Solve step of one tree node.
 *
 *  A leaf calls node->data.Solve() on its own right-hand-side view; an
 *  inner node calls it on the views of its two children.
 *
 *  @param node  the tree node to process.
 */
template<typename NODE, typename T>
void Solve( NODE *node )
{
  auto &factor = node->data;

  /** TODO: need to decide to use LU or not */
  if ( !node->isleaf )
  {
    auto &left_rhs  = node->lchild->data.bview;
    auto &right_rhs = node->rchild->data.bview;
    factor.Solve( left_rhs, right_rhs );
    return;
  }

  /** Leaf: solve directly on this node's view. */
  auto &own_rhs = factor.bview;
  factor.Solve( own_rhs );
}; /** end Solve() */
1465
1466
1467
/**
1468
 *  @brief
1469
 */
1470
template<typename NODE, typename T>
1471
class SolveTask : public Task
1472
{
1473
  public:
1474
1475
    NODE *arg = NULL;
1476
1477
    void Set( NODE *user_arg )
1478
    {
1479
      arg = user_arg;
1480
      name = string( "sl" );
1481
      label = to_string( arg->treelist_id );
1482
      cost = 1.0;
1483
1484
      //printf( "Set treelist_id %lu\n", arg->treelist_id ); fflush( stdout );
1485
    };
1486
1487
    void GetEventRecord()
1488
    {
1489
      double flops = 0.0, mops = 0.0;
1490
      event.Set( label + name, flops, mops );
1491
    };
1492
1493
    void DependencyAnalysis()
1494
    {
1495
      arg->DependencyAnalysis( RW, this );
1496
      if ( !arg->isleaf )
1497
      {
1498
        arg->lchild->DependencyAnalysis( R, this );
1499
        arg->rchild->DependencyAnalysis( R, this );
1500
      }
1501
    };
1502
1503
    void Execute( Worker* user_worker )
1504
    {
1505
      Solve<NODE, T>( arg );
1506
    };
1507
1508
}; /** end class SolveTask */
1509
1510
1511
/**
1512
 *
1513
 */
1514
template<typename T, typename TREE>
1515
void Solve( TREE &tree, Data<T> &input )
1516
{
1517
  using NODE = typename TREE::NODE;
1518
1519
  const bool AUTO_DEPENDENCY = true;
1520
  const bool USE_RUNTIME     = true;
1521
1522
  /** copy input to output */
1523
  auto *output = new Data<T>( input.row(), input.col() );
1524
1525
  SolverTreeViewTask<NODE>             treeviewtask;
1526
  MatrixPermuteTask<true,  NODE> forwardpermutetask;
1527
  MatrixPermuteTask<false, NODE> inversepermutetask;
1528
  /** Sherman-Morrison-Woodbury */
1529
  SolveTask<NODE, T>      solvetask1;
1530
  /** ULV */
1531
  ULVForwardSolveTask<NODE, T>   ulvforwardsolvetask;
1532
  ULVBackwardSolveTask<NODE, T>  ulvbackwardsolvetask;
1533
1534
  /** attach the pointer to the tree structure */
1535
  tree.setup.input  = &input;
1536
  tree.setup.output = output;
1537
1538
  if ( tree.setup.do_ulv_factorization )
1539
  {
1540
    /** clean up all dependencies on tree nodes */
1541
    tree.DependencyCleanUp();
1542
    tree.TraverseDown( treeviewtask );
1543
    tree.TraverseLeafs( forwardpermutetask );
1544
    tree.TraverseUp( ulvforwardsolvetask );
1545
    tree.TraverseDown( ulvbackwardsolvetask );
1546
    if ( USE_RUNTIME ) hmlp_run();
1547
1548
    /** clean up all dependencies on tree nodes */
1549
    tree.DependencyCleanUp();
1550
    tree.TraverseLeafs( inversepermutetask );
1551
    if ( USE_RUNTIME ) hmlp_run();
1552
  }
1553
  else
1554
  {
1555
    /** clean up all dependencies on tree nodes */
1556
    tree.DependencyCleanUp();
1557
    tree.TraverseDown( treeviewtask );
1558
    tree.TraverseLeafs( forwardpermutetask );
1559
    tree.TraverseUp( solvetask1 );
1560
    if ( USE_RUNTIME ) hmlp_run();
1561
    /** clean up all dependencies on tree nodes */
1562
    tree.DependencyCleanUp();
1563
    tree.TraverseLeafs( inversepermutetask );
1564
    if ( USE_RUNTIME ) hmlp_run();
1565
  }
1566
1567
  /** delete buffer space */
1568
  delete output;
1569
1570
}; /** end Solve() */
1571
1572
1573
1574
1575
1576
/**
1577
 *  @brief Compute relative Forbenius error for two-sided
1578
 *  interpolative decomposition.
1579
 */
1580
template<typename NODE, typename T>
1581
void LowRankError( NODE *node )
1582
{
1583
  auto &data = node->data;
1584
  auto &setup = node->setup;
1585
  auto &K = *setup->K;
1586
1587
  if ( !node->isleaf )
1588
  {
1589
    auto Krl = K( node->rchild->gids, node->lchild->gids );
1590
1591
    auto nrm2 = hmlp_norm( Krl.row(),  Krl.col(),
1592
                           Krl.data(), Krl.row() );
1593
1594
1595
    hmlp::Data<T> VrCrl( data.nr, data.sl );
1596
1597
    /** VrCrl = Vr * Crl */
1598
    xgemm( "N", "N", data.nr, data.sl, data.sr,
1599
        1.0, data.Vr->data(), data.nr,
1600
             data.Crl.data(), data.sr,
1601
        0.0, VrCrl.data(), data.nr );
1602
1603
    /** Krl - VrCrlVl' */
1604
    xgemm( "N", "T", data.nr, data.nl, data.sl,
1605
       -1.0, VrCrl.data(), data.nr,
1606
             data.Vl->data(), data.nl,
1607
        1.0, Krl.data(), data.nr );
1608
1609
    auto err = hmlp_norm( Krl.row(),  Krl.col(),
1610
                          Krl.data(), Krl.row() );
1611
1612
    printf( "%4lu ||Krl -VrCrlVl|| %3.1E\n",
1613
        node->treelist_id, std::sqrt( err / nrm2 ) );
1614
  }
1615
1616
}; /** end LowRankError() */
1617
1618
1619
1620
/**
1621
 *  @brief Factorizarion using LU and SMW
1622
 */
1623
template<typename NODE, typename T>
1624
void Factorize( NODE *node )
1625
{
1626
  auto &data = node->data;
1627
  auto &setup = node->setup;
1628
  auto &K = *setup->K;
1629
  auto &proj = data.proj;
1630
1631
  auto do_ulv_factorization = setup->do_ulv_factorization;
1632
1633
  if ( node->isleaf )
1634
  {
1635
    auto lambda = setup->lambda;
1636
    auto &amap = node->gids;
1637
1638
    /** Evaluate the diagonal block. */
1639
    Data<T> Kaa = K( amap, amap );
1640
1641
    /** Apply the regularization */
1642
    for ( size_t i = 0; i < Kaa.row(); i ++ ) Kaa( i, i ) += lambda;
1643
1644
    if ( do_ulv_factorization )
1645
    {
1646
      /** U = proj */
1647
      data.Telescope( false, data.U, proj );
1648
      /** QR factorization */
1649
      data.Orthogonalization();
1650
      /** LU factorization */
1651
      data.PartialFactorize( Kaa );
1652
    }
1653
    else
1654
    {
1655
      /** LU factorization */
1656
      data.Factorize( Kaa );
1657
      /** U = inv( Kaa ) * proj' */
1658
      data.Telescope( true, data.U, proj );
1659
      /** V = proj' */
1660
      data.Telescope( false, data.V, proj );
1661
    }
1662
  }
1663
  else
1664
  {
1665
    auto &Ul = node->lchild->data.U;
1666
    auto &Vl = node->lchild->data.V;
1667
    auto &Zl = node->lchild->data.Zbr;
1668
    auto &Ur = node->rchild->data.U;
1669
    auto &Vr = node->rchild->data.V;
1670
    auto &Zr = node->rchild->data.Zbr;
1671
1672
    /** Evluate the skeleton rows and columns. */
1673
    auto &amap = node->lchild->data.skels;
1674
    auto &bmap = node->rchild->data.skels;
1675
1676
    /** Get the skeleton rows and columns */
1677
    node->data.Crl = K( bmap, amap );
1678
1679
    if ( do_ulv_factorization )
1680
    {
1681
      if ( !node->data.isroot )
1682
      {
1683
        data.Telescope( false, data.U, proj, Ul, Ur );
1684
        data.Orthogonalization();
1685
      }
1686
      data.PartialFactorize( Zl, Zr, Ul, Ur, Vl, Vr );
1687
    }
1688
    else
1689
    {
1690
      /** SMW factorization (LU or Cholesky) */
1691
      data.Factorize( Ul, Ur, Vl, Vr );
1692
      /** telescope U and V */
1693
      if ( !node->data.isroot )
1694
      {
1695
        /** U = inv( I + UCV' ) * [ Ul; Ur ] * proj' */
1696
        data.Telescope(  true, data.U, proj, Ul, Ur );
1697
        /** V = [ Vl; Vr ] * proj' */
1698
        data.Telescope( false, data.V, proj, Vl, Vr );
1699
      }
1700
    }
1701
  }
1702
1703
1704
1705
1706
1707
1708
1709
1710
//    /** SMW factorization (LU or Cholesky) */
1711
//    data.Factorize<true>( Ul, Ur, Vl, Vr );
1712
//
1713
//    /** telescope U and V */
1714
//    if ( !node->data.isroot )
1715
//    {
1716
//      if ( do_ulv_factorization )
1717
//      {
1718
//        data.Telescope( true, data.U, proj, Ul, Ur );
1719
//        data.Orthogonalization();
1720
//      }
1721
//      else
1722
//      {
1723
//        /** U = inv( I + UCV' ) * [ Ul; Ur ] * proj' */
1724
//        data.Telescope( true, data.U, proj, Ul, Ur );
1725
//        /** V = [ Vl; Vr ] * proj' */
1726
//        data.Telescope( false, data.V, proj, Vl, Vr );
1727
//      }
1728
//    }
1729
//    else
1730
//    {
1731
//      /** output Crl from children */
1732
//
1733
//      //size_t L = 3;
1734
//
1735
//      auto *cl = node->lchild;
1736
//      auto *cr = node->rchild;
1737
//      auto *c1 = cl->lchild;
1738
//      auto *c2 = cl->rchild;
1739
//      auto *c3 = cr->lchild;
1740
//      auto *c4 = cr->rchild;
1741
//
1742
//      //hmlp::Data<T> C21 = K( c2->data.skels, c1->data.skels );
1743
//      //hmlp::Data<T> C31 = K( c3->data.skels, c1->data.skels );
1744
//      //hmlp::Data<T> C41 = K( c4->data.skels, c1->data.skels );
1745
//      //hmlp::Data<T> C32 = K( c3->data.skels, c2->data.skels );
1746
//      //hmlp::Data<T> C42 = K( c4->data.skels, c2->data.skels );
1747
//      //hmlp::Data<T> C43 = K( c4->data.skels, c3->data.skels );
1748
//
1749
//      //C21.WriteFile( "C21.m" );
1750
//      //C31.WriteFile( "C31.m" );
1751
//      //C41.WriteFile( "C41.m" );
1752
//      //C32.WriteFile( "C32.m" );
1753
//      //C42.WriteFile( "C42.m" );
1754
//      //C43.WriteFile( "C43.m" );
1755
//
1756
//
1757
//      //hmlp::Data<T> V11( c1->data.V.col(), c1->data.V.col() );
1758
//      //hmlp::Data<T> V22( c2->data.V.col(), c2->data.V.col() );
1759
//      //hmlp::Data<T> V33( c3->data.V.col(), c3->data.V.col() );
1760
//      //hmlp::Data<T> V44( c4->data.V.col(), c4->data.V.col() );
1761
//
1762
//      //xgemm( "T", "N", c1->data.V.col(), c1->data.V.col(), c1->data.V.row(),
1763
//      //    1.0, c1->data.V.data(), c1->data.V.row(),
1764
//      //         c1->data.V.data(), c1->data.V.row(),
1765
//      //    0.0,        V11.data(), V11.row() );
1766
//
1767
//      //xgemm( "T", "N", c2->data.V.col(), c2->data.V.col(), c2->data.V.row(),
1768
//      //    1.0, c2->data.V.data(), c2->data.V.row(),
1769
//      //         c2->data.V.data(), c2->data.V.row(),
1770
//      //    0.0,        V22.data(), V22.row() );
1771
//
1772
//      //xgemm( "T", "N", c3->data.V.col(), c3->data.V.col(), c3->data.V.row(),
1773
//      //    1.0, c3->data.V.data(), c3->data.V.row(),
1774
//      //         c3->data.V.data(), c3->data.V.row(),
1775
//      //    0.0,        V33.data(), V33.row() );
1776
//
1777
//      //xgemm( "T", "N", c4->data.V.col(), c4->data.V.col(), c4->data.V.row(),
1778
//      //    1.0, c4->data.V.data(), c4->data.V.row(),
1779
//      //         c4->data.V.data(), c4->data.V.row(),
1780
//      //    0.0,        V44.data(), V44.row() );
1781
//
1782
//      //V11.WriteFile( "V11.m" );
1783
//      //V22.WriteFile( "V22.m" );
1784
//      //V33.WriteFile( "V33.m" );
1785
//      //V44.WriteFile( "V44.m" );
1786
//    }
1787
//    //printf( "end inner forward telescoping\n" ); fflush( stdout );
1788
//
1789
//    /** check the offdiagonal block VrCrlVl' accuracy */
1790
//    if ( !do_ulv_factorization )
1791
//      LowRankError<NODE, T>( node );
1792
//  }
1793
1794
}; /** end void Factorize() */
1795
1796
1797
1798
/**
1799
 *  @brief
1800
 */
1801
template<typename NODE, typename T>
1802
class FactorizeTask : public Task
1803
{
1804
  public:
1805
1806
    NODE *arg = NULL;
1807
1808
    void Set( NODE *user_arg )
1809
    {
1810
      arg = user_arg;
1811
      name = string( "fa" );
1812
      label = to_string( arg->treelist_id );
1813
      // Need an accurate cost model.
1814
      cost = 1.0;
1815
    };
1816
1817
    void GetEventRecord()
1818
    {
1819
      double flops = 0.0, mops = 0.0;
1820
      event.Set( label + name, flops, mops );
1821
    };
1822
1823
    void DependencyAnalysis() { arg->DependOnChildren( this ); };
1824
1825
    void Execute( Worker* user_worker ) { Factorize<NODE, T>( arg ); };
1826
1827
}; /** end class FactorizeTask */
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
/** @biref Top-level factorization routine. */
1842
template<typename T, typename TREE>
1843
void Factorize( TREE &tree, T lambda )
1844
{
1845
  using NODE = typename TREE::NODE;
1846
1847
  /** Clean up all dependencies on tree nodes. */
1848
  tree.DependencyCleanUp();
1849
1850
  /** Regularization parameter lambda. */
1851
  tree.setup.lambda = lambda;
1852
1853
  /** Perform ULV factorization. */
1854
  tree.setup.do_ulv_factorization = true;
1855
1856
  /** Setup  */
1857
  SetupFactorTask<NODE, T> setupfactortask;
1858
  tree.TraverseUp( setupfactortask );
1859
  tree.ExecuteAllTasks();
1860
1861
  /** Factorization */
1862
  FactorizeTask<NODE, T> factorizetask;
1863
  tree.TraverseUp( factorizetask );
1864
  tree.ExecuteAllTasks();
1865
1866
}; /** end Factorize() */
1867
1868
1869
1870
/**
1871
 *  @brief Compute the average 2-norm error. That is given
1872
 *         lambda and weights,
1873
 */
1874
template<typename TREE, typename T>
1875
void ComputeError( TREE &tree, T lambda, Data<T> weights, Data<T> potentials )
1876
{
1877
  using NODE = typename TREE::NODE;
1878
1879
1880
  /** assure the dimension matches */
1881
  assert( weights.row() == potentials.row() );
1882
  assert( weights.col() == potentials.col() );
1883
1884
  size_t n    = weights.row();
1885
  size_t nrhs = weights.col();
1886
1887
  /** shift lambda and make it a column vector */
1888
  Data<T> rhs( n, nrhs );
1889
  for ( size_t j = 0; j < nrhs; j ++ )
1890
    for ( size_t i = 0; i < n; i ++ )
1891
      rhs( i, j ) = potentials( i, j ) + lambda * weights( i, j );
1892
1893
  /** potentials = inv( K + lambda * I ) * potentials */
1894
  Solve( tree, rhs );
1895
1896
1897
  /** Compute relative error = sqrt( err / nrm2 ) for each rhs */
1898
  printf( "========================================================\n" );
1899
  printf( "Inverse accuracy report\n" );
1900
  printf( "========================================================\n" );
1901
  printf( "#rhs,  max err,        @,  min err,        @,  relative \n" );
1902
  printf( "========================================================\n" );
1903
  size_t ntest = 10;
1904
  T total_err  = 0.0;
1905
  for ( size_t j = 0; j < std::min( nrhs, ntest ); j ++ )
1906
  {
1907
    /** counters */
1908
    T nrm2 = 0.0, err2 = 0.0;
1909
    T max2 = 0.0, min2 = std::numeric_limits<T>::max();
1910
    /** indecies */
1911
    size_t maxi = 0, mini = 0;
1912
1913
    for ( size_t i = 0; i < n; i ++ )
1914
    {
1915
      T sse = rhs( i, j ) - weights( i, j );
1916
      assert( rhs( i, j ) == rhs( i, j ) );
1917
      sse = sse * sse;
1918
1919
      nrm2 += weights( i, j ) * weights( i, j );
1920
      err2 += sse;
1921
1922
      //printf( "%lu %3.1E\n", i, sse );
1923
1924
1925
      if ( sse > max2 ) { max2 = sse; maxi = i; }
1926
      if ( sse < min2 ) { min2 = sse; mini = i; }
1927
    }
1928
    total_err += std::sqrt( err2 / nrm2 );
1929
1930
    printf( "%4lu,  %3.1E,  %7lu,  %3.1E,  %7lu,   %3.1E\n",
1931
        j, std::sqrt( max2 ), maxi, std::sqrt( min2 ), mini,
1932
        std::sqrt( err2 / nrm2 ) );
1933
  }
1934
  printf( "========================================================\n" );
1935
  printf( "                             avg over %2lu rhs,   %3.1E \n",
1936
      std::min( nrhs, ntest ), total_err / std::min( nrhs, ntest ) );
1937
  printf( "========================================================\n\n" );
1938
1939
}; /** end ComputeError() */
1940
1941
1942
1943
1944
1945
1946
1947
1948
}; /** end namespace gofmm */
1949
}; /** end namespace hmlp */
1950
1951
#endif /** define IGOFMM_HPP */