GCC Code Coverage Report
Directory: . Exec Total Coverage
File: frame/containers/KernelMatrix.hpp Lines: 0 188 0.0 %
Date: 2019-01-14 Branches: 0 308 0.0 %

Line Exec Source
1
/**
2
 *  HMLP (High-Performance Machine Learning Primitives)
3
 *
4
 *  Copyright (C) 2014-2017, The University of Texas at Austin
5
 *
6
 *  This program is free software: you can redistribute it and/or modify
7
 *  it under the terms of the GNU General Public License as published by
8
 *  the Free Software Foundation, either version 3 of the License, or
9
 *  (at your option) any later version.
10
 *
11
 *  This program is distributed in the hope that it will be useful,
12
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
 *  GNU General Public License for more details.
15
 *
16
 *  You should have received a copy of the GNU General Public License
17
 *  along with this program. If not, see the LICENSE file.
18
 *
19
 **/
20
21
#ifndef KERNELMATRIX_HPP
22
#define KERNELMATRIX_HPP
23
24
/** Using tgamma, M_PI, M_SQRT2 ... */
25
#include <cmath>
26
/** BLAS/LAPACK support. */
27
#include <base/blas_lapack.hpp>
28
/** KernelMatrix uses VirtualMatrix<T> as base. */
29
#include <containers/VirtualMatrix.hpp>
30
/** DistData is used to store the data points. */
31
#include <DistData.hpp>
32
33
using namespace std;
34
using namespace hmlp;
35
36
namespace hmlp
37
{
38
39
typedef enum
40
{
41
  GAUSSIAN,
42
  SIGMOID,
43
  POLYNOMIAL,
44
  LAPLACE,
45
  GAUSSIAN_VAR_BANDWIDTH,
46
  TANH,
47
  QUARTIC,
48
  MULTIQUADRATIC,
49
  EPANECHNIKOV,
50
  USER_DEFINE
51
} kernel_type;
52
53
template<typename T, typename TP>
54
struct kernel_s
55
{
56
  kernel_type type;
57
58
  /** Compute a single inner product. */
59
  static inline T innerProduct( const TP* x, const TP* y, size_t d )
60
  {
61
    T accumulator = 0.0;
62
    #pragma omp parallel for reduction(+:accumulator)
63
    for ( size_t i = 0; i < d; i ++ ) accumulator += x[ i ] * y[ i ];
64
    return accumulator;
65
  }
66
67
  /** Compute all pairwise inner products using GEMM_TN( 1.0, X, Y, 0.0, K ). */
68
  static inline void innerProducts( const TP* X, const TP* Y, size_t d, T* K, size_t m, size_t n )
69
  {
70
    /** This BLAS function is defined in frame/base/blas_lapack.hpp.  */
71
    xgemm( "Transpose", "No-transpose", (int)m, (int)n, (int)d, (T)1.0, X, (int)d, Y, (int)d, (T)0.0, K, (int)m );
72
  }
73
74
  /** Compute a single squared distance. */
75
  static inline T squaredDistance( const TP* x, const TP* y, size_t d )
76
  {
77
    T accumulator = 0.0;
78
    #pragma omp parallel for reduction(+:accumulator)
79
    for ( size_t i = 0; i < d; i ++ ) accumulator += ( x[ i ] - y[ i ] ) * ( x[ i ] - y[ i ] );
80
    return accumulator;
81
  }
82
83
  /** Compute all pairwise squared distances. */
84
  static inline void squaredDistances( const TP* X, const TP* Y, size_t d, T* K, size_t m, size_t n )
85
  {
86
    innerProducts( X, Y, d, K, m, n );
87
    vector<T> squaredNrmX( m, 0 ), squaredNrmY( n, 0 );
88
    /** This BLAS function is defined in frame/base/blas_lapack.hpp.  */
89
    #pragma omp parallel for
90
    for ( size_t i = 0; i < m; i ++ ) squaredNrmX[ i ] = xdot( d, X + i * d, 1, X + i * d, 1 );
91
    #pragma omp parallel for
92
    for ( size_t j = 0; j < n; j ++ ) squaredNrmY[ j ] = xdot( d, Y + j * d, 1, Y + j * d, 1 );
93
    #pragma omp parallel for collapse(2)
94
    for ( size_t j = 0; j < n; j ++ )
95
      for ( size_t i = 0; i < m; i ++ )
96
        K[ j * m + i ] = squaredNrmX[ i ] - 2 * K[ j * m + i ] + squaredNrmY[ j ];
97
  }
98
99
  inline T operator () ( const void* param, const TP* x, const TP* y, size_t d ) const
100
  {
101
    switch ( type )
102
    {
103
      case GAUSSIAN:
104
        return exp( scal * squaredDistance( x, y, d ) );
105
      case SIGMOID:
106
        return tanh( scal* innerProduct( x, y, d ) + cons );
107
      case LAPLACE:
108
        return 0;
109
      case QUARTIC:
110
        return 0;
111
      case USER_DEFINE:
112
        return user_element_function( param, x, y, d );
113
      default:
114
        printf( "invalid kernel type\n" );
115
        exit( 1 );
116
    } /** end switch ( type ) */
117
  };
118
119
  /** X should be at least d-by-m, and Y should be at least d-by-n. */
120
  inline void operator () ( const void* param, const TP* X, const TP* Y, size_t d, T* K, size_t m, size_t n ) const
121
  {
122
    switch ( type )
123
    {
124
      case GAUSSIAN:
125
        squaredDistances( X, Y, d, K, m, n );
126
        #pragma omp parallel for
127
        for ( size_t i = 0; i < m * n; i ++ ) K[ i ] = exp( scal * K[ i ] );
128
        break;
129
      case SIGMOID:
130
        innerProducts( X, Y, d, K, m, n );
131
        #pragma omp parallel for
132
        for ( size_t i = 0; i < m * n; i ++ ) K[ i ] = tanh( scal * K[ i ] + cons );
133
        break;
134
      case LAPLACE:
135
        break;
136
      case QUARTIC:
137
        break;
138
      case USER_DEFINE:
139
        user_matrix_function( param, X, Y, d, K, m, n );
140
        break;
141
      default:
142
        printf( "invalid kernel type\n" );
143
        exit( 1 );
144
    } /** end switch ( type ) */
145
  };
146
147
  T powe = 1;
148
  T scal = 1;
149
  T cons = 0;
150
  T *hi;
151
  T *hj;
152
  T *h;
153
154
  /** User-defined kernel functions. */
155
  T (*user_element_function)( const void* param, const TP* x, const TP* y, size_t d ) = nullptr;
156
  void (*user_matrix_function)( const void* param, const T* X, const T* Y, size_t d, T* K,  size_t m, size_t n ) = nullptr;
157
158
};
159
160
161
template<typename T, class Allocator = std::allocator<T>>
162
class KernelMatrix : public VirtualMatrix<T, Allocator>,
163
                     public ReadWrite
164
{
165
  public:
166
167
    /** (Default) constructor for non-symmetric kernel matrices. */
168
    KernelMatrix( size_t m_, size_t n_, size_t d_, kernel_s<T, T> &kernel_,
169
        Data<T> &sources_, Data<T> &targets_ )
170
      : VirtualMatrix<T>( m_, n_ ), d( d_ ),
171
        sources( sources_ ), targets( targets_ ),
172
        kernel( kernel_ ), all_dimensions( d_ )
173
    {
174
      this->is_symmetric = false;
175
      for ( size_t i = 0; i < d; i ++ ) all_dimensions[ i ] = i;
176
    };
177
178
    /** (Default) constructor for symmetric kernel matrices. */
179
    KernelMatrix( size_t m, size_t n, size_t d, kernel_s<T, T>& kernel, Data<T> &sources )
180
      : KernelMatrix( m, n, d, kernel, sources, sources )
181
    {
182
      assert( m == n );
183
      this->is_symmetric = true;
184
    };
185
186
    KernelMatrix( Data<T> &sources )
187
      : sources( sources ),
188
        targets( sources ),
189
        VirtualMatrix<T>( sources.col(), sources.col() ),
190
        all_dimensions( sources.row() )
191
    {
192
      this->is_symmetric = true;
193
      this->d = sources.row();
194
      this->kernel.type = GAUSSIAN;
195
      this->kernel.scal = -0.5;
196
      for ( size_t i = 0; i < d; i ++ ) all_dimensions[ i ] = i;
197
    };
198
199
    /** (Default) destructor. */
200
    ~KernelMatrix() {};
201
202
		/** ESSENTIAL: override the virtual function */
203
    virtual T operator()( size_t i, size_t j ) override
204
    {
205
      return kernel( nullptr, targets.columndata( i ), sources.columndata( j ), d );
206
		};
207
208
    /** (Overwrittable) ESSENTIAL: return K( I, J ) */
209
    virtual Data<T> operator() ( const vector<size_t>& I, const vector<size_t>& J ) override
210
    {
211
      Data<T> KIJ( I.size(), J.size() );
212
			/** Early return if possible. */
213
			if ( !I.size() || !J.size() ) return KIJ;
214
			/** Request for coordinates: A (targets), B (sources). */
215
      Data<T> X = ( is_symmetric ) ? sources( all_dimensions, I ) : targets( all_dimensions, I );
216
      Data<T> Y = sources( all_dimensions, J );
217
      /** Evaluate KIJ using legacy interface. */
218
      kernel( nullptr, X.data(), Y.data(), d, KIJ.data(), I.size(), J.size() );
219
      /** Return K( I, J ). */
220
      return KIJ;
221
    };
222
223
224
    virtual Data<T> GeometryDistances( const vector<size_t>& I, const vector<size_t>& J ) override
225
    {
226
      Data<T> KIJ( I.size(), J.size() );
227
			/** Early return if possible. */
228
			if ( !I.size() || !J.size() ) return KIJ;
229
			/** Request for coordinates: A (targets), B (sources). */
230
      Data<T> A, B;
231
      /** For symmetry matrices, extract from sources. */
232
      if ( is_symmetric )
233
      {
234
        A = sources( all_dimensions, I );
235
        B = sources( all_dimensions, J );
236
      }
237
      /** For non-symmetry matrices, extract from targets and sources. */
238
      else
239
      {
240
        A = targets( all_dimensions, I );
241
        B = sources( all_dimensions, J );
242
      }
243
      /** Compute inner products. */
244
      xgemm( "Transpose", "No-transpose", I.size(), J.size(), d,
245
        -2.0, A.data(), A.row(),
246
              B.data(), B.row(),
247
         0.0, KIJ.data(), KIJ.row() );
248
      /** Compute square 2-norms. */
249
      vector<T> A2( I.size() ), B2( J.size() );
250
251
      #pragma omp parallel for
252
      for ( size_t i = 0; i < I.size(); i ++ )
253
      {
254
        A2[ i ] = xdot( d, A.columndata( i ), 1, A.columndata( i ), 1 );
255
      } /** end omp parallel for */
256
257
      #pragma omp parallel for
258
      for ( size_t j = 0; j < J.size(); j ++ )
259
      {
260
        B2[ j ] = xdot( d, B.columndata( j ), 1, B.columndata( j ), 1 );
261
      } /** end omp parallel for */
262
263
      /** Add square norms to inner products to get square distances. */
264
      #pragma omp parallel for
265
      for ( size_t j = 0; j < J.size(); j ++ )
266
        for ( size_t i = 0; i < I.size(); i ++ )
267
          KIJ( i, j ) += A2[ i ] + B2[ j ];
268
      /** Return all pair-wise distances. */
269
      return KIJ;
270
    }; /** end GeometryDistances() */
271
272
273
    /** get the diagonal of KII, i.e. diag( K( I, I ) ) */
274
    Data<T> Diagonal( vector<size_t> &I )
275
    {
276
      /**
277
       *  return values
278
       */
279
      Data<T> DII( I.size(), 1, 0.0 );
280
281
      /** at this moment we already have the corrdinates on this process */
282
      switch ( kernel.type )
283
      {
284
        case GAUSSIAN:
285
        {
286
          for ( size_t i = 0; i < I.size(); i ++ ) DII[ i ] = 1.0;
287
          break;
288
        }
289
        default:
290
        {
291
          printf( "invalid kernel type\n" );
292
          exit( 1 );
293
          break;
294
        }
295
      }
296
297
      return DII;
298
299
    };
300
301
    /** important sampling */
302
    pair<T, size_t> ImportantSample( size_t j )
303
    {
304
      size_t i = std::rand() % this->col();
305
      pair<T, size_t> sample( (*this)( i, j ), i );
306
      return sample;
307
    };
308
309
    void Print()
310
    {
311
      for ( size_t j = 0; j < this->col(); j ++ )
312
      {
313
        printf( "%8lu ", j );
314
      }
315
      printf( "\n" );
316
      for ( size_t i = 0; i < this->row(); i ++ )
317
      {
318
        for ( size_t j = 0; j < this->col(); j ++ )
319
        {
320
          printf( "% 3.1E ", (*this)( i, j ) );
321
        }
322
        printf( "\n" );
323
      }
324
    }; /** end Print() */
325
326
    /** Return number of attributes. */
327
    size_t dim() { return d; };
328
329
    /** flops required for Kab */
330
    double flops( size_t na, size_t nb )
331
    {
332
      double flopcount = 0.0;
333
334
      switch ( kernel.type )
335
      {
336
        case GAUSSIAN:
337
        {
338
          flopcount = na * nb * ( 2.0 * d + 35.0 );
339
          break;
340
        }
341
        default:
342
        {
343
          printf( "invalid kernel type\n" );
344
          exit( 1 );
345
          break;
346
        }
347
      }
348
      return flopcount;
349
    };
350
351
  private:
352
353
    bool is_symmetric = true;
354
355
    size_t d = 0;
356
357
    Data<T> &sources;
358
359
    Data<T> &targets;
360
361
    /** legacy data structure */
362
    kernel_s<T, T> kernel;
363
    /** [ 0, 1, ..., d-1 ] */
364
    vector<size_t> all_dimensions;
365
366
}; /** end class KernelMatrix */
367
368
369
370
/**
 *  Distributed kernel matrix: K( i, j ) = k( target_i, source_j ), with the
 *  points distributed over an MPI communicator in [STAR, CBLK] (cyclic
 *  column) layout.  Remote columns are cached per rank in the [STAR, USER]
 *  tables sources_user / targets_user (local essential tree).  NOTE: the
 *  collective routines below must be entered by every rank -- no early
 *  returns -- to avoid deadlock.
 */
template<typename T, typename TP, class Allocator = std::allocator<TP>>
class DistKernelMatrix : public DistVirtualMatrix<T, Allocator>,
                         public ReadWrite
{
  public:

    /** (Default) unsymmetric kernel matrix */
    DistKernelMatrix
    (
      size_t m, size_t n, size_t d,
      /** by default we assume sources are distributed in [STAR, CBLK] */
      DistData<STAR, CBLK, TP> &sources,
      /** by default we assume targets are distributed in [STAR, CBLK] */
      DistData<STAR, CBLK, TP> &targets,
      mpi::Comm comm
    )
    /** NOTE(review): this list is out of declaration order; members are
     *  actually initialized base-first, then in declaration order. */
    : all_dimensions( d ), sources_user( sources ), targets_user( targets ),
      DistVirtualMatrix<T>( m, n, comm )
    {
      this->is_symmetric = false;
      this->d = d;
      this->sources = &sources;
      this->targets = &targets;
      /** all_dimensions = [ 0, 1, ..., d-1 ] selects every coordinate. */
      for ( size_t i = 0; i < d; i ++ ) all_dimensions[ i ] = i;
    };

    /** Unsymmetric kernel matrix with legacy kernel_s<T> */
    DistKernelMatrix
    (
      size_t m, size_t n, size_t d,
      kernel_s<T, TP> &kernel,
      /** by default we assume sources are distributed in [STAR, CBLK] */
      DistData<STAR, CBLK, TP> &sources,
      /** by default we assume targets are distributed in [STAR, CBLK] */
      DistData<STAR, CBLK, TP> &targets,
      mpi::Comm comm
    )
    : DistKernelMatrix( m, n, d, sources, targets, comm )
    {
      this->kernel = kernel;
    };

    /** (Default) symmetric kernel matrix */
    DistKernelMatrix
    (
      size_t n, size_t d,
      /** by default we assume sources are distributed in [STAR, CBLK] */
      DistData<STAR, CBLK, TP> &sources,
      mpi::Comm comm
    )
    /** targets_user is built d-by-0: unused in the symmetric case. */
    : all_dimensions( d ), sources_user( sources ), targets_user( d, 0, comm ),
      DistVirtualMatrix<T>( n, n, comm )
    {
      this->is_symmetric = true;
      this->d = d;
      this->sources = &sources;
      for ( size_t i = 0; i < d; i ++ ) all_dimensions[ i ] = i;
    };

    /** Symmetric kernel matrix with legacy kernel_s<T> */
    DistKernelMatrix
    (
      size_t n, size_t d,
      kernel_s<T, TP> &kernel,
      /** by default we assume sources are distributed in [STAR, CBLK] */
      DistData<STAR, CBLK, TP> &sources,
      mpi::Comm comm
    )
    : DistKernelMatrix( n, d, sources, comm )
    {
      this->kernel = kernel;
    };

    /** Symmetric Gaussian kernel (scal = -0.5) built directly from data;
     *  n = sources.col() points of dimension d = sources.row(). */
    DistKernelMatrix( DistData<STAR, CBLK, TP> &sources, mpi::Comm comm )
    : DistKernelMatrix( sources.col(), sources.row(), sources, comm )
    {
      this->kernel.type = GAUSSIAN;
      this->kernel.scal = -0.5;
    };

    /** (Default) destructor */
    ~DistKernelMatrix() {};

    /** (Overridable) Request a single entry Kij; both points must already
     *  be present in the local caches (sources_user / targets_user). */
    virtual T operator () ( size_t i, size_t j ) override
    {
      TP* x = ( is_symmetric ) ? sources_user.columndata( i ) : targets_user.columndata( i );
      TP* y = sources_user.columndata( j );
      return kernel( nullptr, x, y, d );
    }; /** end operator () */

    /** (Overridable) ESSENTIAL: return the submatrix K( I, J ). */
    virtual Data<T> operator() ( const vector<size_t>& I, const vector<size_t>& J ) override
    {
      Data<T> KIJ( I.size(), J.size() );
      /** Early return if possible. */
      if ( !I.size() || !J.size() ) return KIJ;
      /** Request for coordinates: X (targets), Y (sources). */
      Data<T> X = ( is_symmetric ) ? sources_user( all_dimensions, I ) : targets_user( all_dimensions, I );
      Data<T> Y = sources_user( all_dimensions, J );
      /** Evaluate KIJ using the blocked legacy interface. */
      kernel( nullptr, X.data(), Y.data(), d, KIJ.data(), I.size(), J.size() );
      return KIJ;
    };

    /** Pairwise squared distances between cached targets( I ) and sources( J ). */
    virtual Data<T> GeometryDistances( const vector<size_t>& I, const vector<size_t>& J ) override
    {
      Data<T> KIJ( I.size(), J.size() );
      /** Early return if possible. */
      if ( !I.size() || !J.size() ) return KIJ;
      /** Request for coordinates: A (targets), B (sources). */
      Data<TP> A, B;
      /** For symmetric matrices, extract from sources. */
      if ( is_symmetric )
      {
        A = sources_user( all_dimensions, I );
        B = sources_user( all_dimensions, J );
      }
      /** For non-symmetric matrices, extract from targets and sources. */
      else
      {
        A = targets_user( all_dimensions, I );
        B = sources_user( all_dimensions, J );
      }
      /** KIJ = -2 A^T B (scaled inner products). */
      xgemm( "Transpose", "No-transpose", I.size(), J.size(), d,
        -2.0, A.data(), A.row(),
              B.data(), B.row(),
         0.0, KIJ.data(), KIJ.row() );
      /** Compute squared 2-norms of each column. */
      vector<TP> A2( I.size() ), B2( J.size() );

      #pragma omp parallel for
      for ( size_t i = 0; i < I.size(); i ++ )
      {
        A2[ i ] = xdot( d, A.columndata( i ), 1, A.columndata( i ), 1 );
      } /** end omp parallel for */

      #pragma omp parallel for
      for ( size_t j = 0; j < J.size(); j ++ )
      {
        B2[ j ] = xdot( d, B.columndata( j ), 1, B.columndata( j ), 1 );
      } /** end omp parallel for */

      /** |a - b|^2 = |a|^2 - 2 <a, b> + |b|^2. */
      #pragma omp parallel for
      for ( size_t j = 0; j < J.size(); j ++ )
        for ( size_t i = 0; i < I.size(); i ++ )
          KIJ( i, j ) += A2[ i ] + B2[ j ];
      /** Return all pair-wise distances. */
      return KIJ;
    }; /** end GeometryDistances() */

    /** Get the diagonal of KII, i.e. diag( K( I, I ) ). */
    Data<T> Diagonal( vector<size_t> &I )
    {
      /** MPI */
      int size = this->Comm_size();
      int rank = this->Comm_rank();

      /**
       *  return values
       *
       *  NOTICE: even KIJ can be an 0-by-0 matrix for this MPI rank,
       *  yet early return is not allowed. All MPI process must execute
       *  all collaborative communication routines to avoid deadlock.
       */
      Data<T> DII( I.size(), 1, 0.0 );

      /** At this moment we already have the coordinates on this process. */
      switch ( kernel.type )
      {
        case GAUSSIAN:
        {
          /** k( x, x ) = exp( scal * 0 ) = 1 on the diagonal. */
          for ( size_t i = 0; i < I.size(); i ++ ) DII[ i ] = 1.0;
          break;
        }
        case LAPLACE:
        {
          /** NOTE(review): assumes a unit diagonal -- verify against the
           *  (currently unimplemented) LAPLACE kernel in kernel_s. */
          for ( size_t i = 0; i < I.size(); i ++ ) DII[ i ] = 1.0;
          break;
        }
        case QUARTIC:
        {
          /** NOTE(review): assumes a unit diagonal -- verify against the
           *  (currently unimplemented) QUARTIC kernel in kernel_s. */
          for ( size_t i = 0; i < I.size(); i ++ ) DII[ i ] = 1.0;
          break;
        }
        default:
        {
          printf( "invalid kernel type\n" );
          exit( 1 );
          break;
        }
      }

      return DII;
    }; /** end Diagonal() */

    /** Important sampling: pick a random locally-cached column index i.
     *  NOTE(review): unlike KernelMatrix, the sampled value is 0, not
     *  K( i, j ) -- presumably only the index matters here; confirm. */
    pair<T, size_t> ImportantSample( size_t j )
    {
      size_t i = std::rand() % this->col();
      /** Rejection-sample until i is a column this rank has cached. */
      while( !sources_user.HasColumn( i ) ) i = std::rand() % this->col();
      assert( sources_user.HasColumn( i ) );
      pair<T, size_t> sample( 0, i );
      return sample;
    };

    /** Print the whole matrix (column header, then one row per line). */
    void Print()
    {
      for ( size_t j = 0; j < this->col(); j ++ )
      {
        printf( "%8lu ", j );
      }
      printf( "\n" );
      for ( size_t i = 0; i < this->row(); i ++ )
      {
        for ( size_t j = 0; j < this->col(); j ++ )
        {
          printf( "% 3.1E ", (*this)( i, j ) );
        }
        printf( "\n" );
      }
    }; /** end Print() */

    /** Return number of attributes (point dimension d). */
    size_t dim() { return d; };

    /** Flops required to evaluate an na-by-nb block K( a, b ). */
    double flops( size_t na, size_t nb )
    {
      double flopcount = 0.0;

      switch ( kernel.type )
      {
        case GAUSSIAN:
        {
          /** 2d flops per squared distance plus ~35 for exp(). */
          flopcount = na * nb * ( 2.0 * d + 35.0 );
          break;
        }
        default:
        {
          printf( "invalid kernel type\n" );
          exit( 1 );
          break;
        }
      }
      return flopcount;
    };

    /** Point-to-point: send the coordinates of columns ids to rank dest. */
    void SendIndices( vector<size_t> ids, int dest, mpi::Comm comm )
    {
      auto param = sources_user( all_dimensions, ids );
      /** Two messages on tag 90: the indices, then the d-by-|ids| data. */
      mpi::SendVector(   ids, dest, 90, comm );
      mpi::SendVector( param, dest, 90, comm );
    };

    /** Point-to-point: receive columns from rank src and cache them. */
    void RecvIndices( int src, mpi::Comm comm, mpi::Status *status )
    {
      vector<size_t> ids;
      Data<TP> param;
      mpi::RecvVector(   ids, src, 90, comm, status );
      mpi::RecvVector( param, src, 90, comm, status );
      assert( param.size() == ids.size() * dim() );
      /** Reshape the flat buffer into d-by-|ids| before insertion. */
      param.resize( dim(), param.size() / dim() );
      /** Insert into hash table */
      sources_user.InsertColumns( ids, param );
    };

    /** Bcast cids from sender for K( :, cids ) evaluation.
     *  Collective: every rank in comm must call this. */
    void BcastIndices( vector<size_t> ids, int root, mpi::Comm comm )
    {
      int rank; mpi::Comm_rank( comm, &rank );
      /** Bcast size of cids from root */
      size_t recv_size = ids.size();
      mpi::Bcast( &recv_size, 1, root, comm );
      /** Resize to receive cids and parameters */
      Data<TP> param;
      if ( rank == root )
      {
         param = sources_user( all_dimensions, ids );
      }
      else
      {
         ids.resize( recv_size );
         param.resize( dim(), recv_size );
      }
      /** Bcast cids and parameters from root */
      mpi::Bcast( ids.data(), recv_size, root, comm );
      mpi::Bcast( param.data(), dim() * recv_size, root, comm );
      /** Insert into hash table */
      sources_user.InsertColumns( ids, param );
    };

    /** All-to-all exchange: ids[ p ] lists the columns this rank wants from
     *  rank p; received coordinates are cached in sources_user.
     *  Collective: every rank must call with ids.size() == comm size. */
    void RequestIndices( const vector<vector<size_t>>& ids ) override
    {
      auto comm = this->GetComm();
      auto rank = this->GetCommRank();
      auto size = this->GetCommSize();

      /** NOTE(review): signed/unsigned comparison -- ids.size() is size_t. */
      assert( ids.size() == size );

      vector<vector<size_t>> recv_ids( size );
      vector<vector<TP>>     send_para( size );
      vector<vector<TP>>     recv_para( size );

      /** Send out cids request to each rank. */
      mpi::AlltoallVector( ids, recv_ids, comm );

      /** Pack the coordinates each peer asked us for. */
      for ( int p = 0; p < size; p ++ )
      {
        Data<TP> para = sources_user( all_dimensions, recv_ids[ p ] );
        send_para[ p ].insert( send_para[ p ].end(), para.begin(), para.end() );
        //send_para[ p ] = sources_user( all_dimensions, recv_ids[ p ] );
      }

      /** Exchange out parameters. */
      mpi::AlltoallVector( send_para, recv_para, comm );

      /** Cache everything received from other ranks (skip self). */
      for ( int p = 0; p < size; p ++ )
      {
        assert( recv_para[ p ].size() == dim() * ids[ p ].size() );
        if ( p != rank && ids[ p ].size() )
        {
          Data<TP> para;
          para.insert( para.end(), recv_para[ p ].begin(), recv_para[ p ].end() );
          para.resize( dim(), para.size() / dim() );
          sources_user.InsertColumns( ids[ p ], para );
        }
      }
    };

  private:

    /** True when targets alias sources (one point set). */
    bool is_symmetric = false;

    /** Number of attributes (dimension of each point). */
    size_t d = 0;

    /** Legacy data structure for kernel matrices. */
    kernel_s<T, TP> kernel;

    /** Pointers to user provided data points in block cylic distribution. */
    DistData<STAR, CBLK, TP> *sources = NULL;
    DistData<STAR, CBLK, TP> *targets = NULL;

    /** Per-rank coordinate caches for the local essential tree [LET]. */
    DistData<STAR, USER, TP> sources_user;
    DistData<STAR, USER, TP> targets_user;

    /** [ 0, 1, ..., d-1 ] -- used to request full coordinates. */
    vector<size_t> all_dimensions;

}; /** end class DistKernelMatrix */
728
729
}; /** end namespace hmlp */
730
731
#endif /** define KERNELMATRIX_HPP */