31 #include <type_traits> 38 #define HANDLE_ERROR( err ) (hmlp::handleError( err, __FILE__, __LINE__ )) 51 void handleError( hmlpError_t error,
const char* file,
int line );
58 template<
int ALIGN_SIZE,
typename T>
62 #ifdef HMLP_MIC_AVX512 63 int err = hbw_posix_memalign( (
void**)&ptr, (
size_t)ALIGN_SIZE, size * m * n );
65 int err = posix_memalign( (
void**)&ptr, (
size_t)ALIGN_SIZE, size * m * n );
70 printf(
"hmlp_malloc(): posix_memalign() failures\n" );
78 template<
int ALIGN_SIZE,
typename T>
81 return hmlp_malloc<ALIGN_SIZE, T>( n, 1,
sizeof(T) );
90 #ifdef HMLP_MIC_AVX512 91 if ( ptr ) hbw_free( ptr );
93 if ( ptr ) free( ptr );
98 void hmlp_print_binary( T number )
102 for (
int i = 31; i >= 0; i -- )
104 if ( i % 5 ) printf(
" " );
105 else printf(
"%d", i / 5 );
110 for (
size_t i = 0; i <
sizeof(T) * 4; i ++ )
112 if ( number & 1 ) binary[ 31 - i ] =
'1';
113 else binary[ 31 - i ] =
'0';
115 if ( i == 31 )
break;
132 printf(
"%s\n", binary );
144 hmlpOperation_t transX,
157 if ( transX == HMLP_OP_N )
159 *dst_buff = &src_buff[ ( m / x * i ) + ( n / y * j ) * lda ];
163 *dst_buff = &src_buff[ ( m / x * i ) * lda + ( n / y * j ) ];
176 for (
int j = 0; j < n; j ++ )
178 for (
int i = 0; i < m; i ++ )
180 nrm2 += A[ j * lda + i ] * A[ j * lda + i ];
183 return std::sqrt( nrm2 );
187 template<
typename TA,
typename TB>
188 TB hmlp_relative_error
198 std::tuple<int, int, TB> max_error( -1, -1, 0.0 );
199 for (
int j = 0; j < n; j ++ )
201 for (
int i = 0; i < m; i ++ )
203 TA a = A[ j * lda + i ];
204 TB b = B[ j * ldb + i ];
208 if ( r * r > std::get<2>( max_error ) )
210 max_error = std::make_tuple( i, j, r );
214 nrm2 = std::sqrt( nrm2 ) / std::sqrt( nrmB );
218 printf(
"relative error % .2E maxinum elemenwise error % .2E ( %d, %d )\n",
220 std::get<2>( max_error ),
221 std::get<0>( max_error ), std::get<1>( max_error ) );
228 template<
typename TA,
typename TB>
229 TB hmlp_relative_error
232 TA *A,
int lda,
int loa,
233 TB *B,
int ldb,
int lob,
238 for (
int b = 0; b < batchSize; b ++ )
250 printf(
"average relative error % .2E\n", err / batchSize );
266 std::tuple<int, int> err_location( -1, -1 );
268 for (
int j = 0; j < n; j ++ )
270 for (
int i = 0; i < m; i ++ )
272 T a = A[ j * lda + i ];
273 T b = B[ j * ldb + i ];
276 err_location = std::make_tuple( i, j );
283 printf(
"total error count %d\n", error_count );
293 T *A,
int lda,
int loa,
294 T *B,
int ldb,
int lob,
299 for (
int b = 0; b < batchSize; b ++ )
311 printf(
"total error count %d\n", error_count );
318 template<
bool IGNOREZERO=false,
bool COLUMNINDEX=true,
typename T>
319 void hmlp_printmatrix(
int m,
int n, T *A,
int lda )
323 for (
int j = 0; j < n; j ++ )
325 if ( j % 5 == 0 || j == 0 || j == n - 1 )
327 if ( is_same<T, pair<double, size_t>>::value || is_same<T, pair<float, size_t>>::value )
329 printf(
"col[%10d] ", j );
333 printf(
"col[%4d] ", j );
338 if ( is_same<T, pair<double, size_t>>::value || is_same<T, pair<float, size_t>>::value )
349 if ( is_same<T, pair<double, size_t>>::value || is_same<T, pair<float, size_t>>::value )
351 printf(
"===============================================================================\n" );
355 printf(
"===========================================================\n" );
359 for (
int i = 0; i < m; i ++ )
361 for (
int j = 0; j < n; j ++ )
363 if ( is_same<T, pair<double, size_t>>::value )
365 auto* A_pair =
reinterpret_cast<pair<double, size_t>*
>( A );
366 printf(
"(% .1E,%5lu)", (
double) A_pair[ j * lda + i ].first, A_pair[ j * lda + i ].second );
368 else if ( is_same<T, pair<float, size_t>>::value )
370 auto* A_pair =
reinterpret_cast<pair<float, size_t>*
>( A );
371 printf(
"(% .1E,%5lu)", (
double) A_pair[ j * lda + i ].first, A_pair[ j * lda + i ].second );
373 else if ( is_same<T, double>::value )
375 auto* A_double =
reinterpret_cast<double*
>( A );
376 if ( std::fabs( A_double[ j * lda + i ] ) < 1E-15 )
382 printf(
"% .4E ", (
double) A_double[ j * lda + i ] );
385 else if ( is_same<T, double>::value )
387 auto* A_float =
reinterpret_cast<float*
>( A );
388 if ( std::fabs( A_float[ j * lda + i ] ) < 1E-15 )
394 printf(
"% .4E ", (
double) A_float[ j * lda + i ] );
405 static inline int hmlp_ceildiv(
int x,
int y )
407 return ( x + y - 1 ) / y;
410 static inline int hmlp_read_nway_from_env(
const char* env )
413 char* str = getenv( env );
416 number = strtol( str, NULL, 10 );
426 inline void swap( T *x,
int i,
int j )
449 while ( 2 * s + 1 < n )
454 if ( D[ j ] < D[ j + 1 ] ) j ++;
456 if ( D[ s ] < D[ j ] )
459 swap<int>( I, s, j );
467 inline void heap_select
479 for ( i = 0; i < m; i ++ )
481 if ( x[ i ] > D[ 0 ] )
489 heap_adjust<T>( D, 0, r, I );
499 std::pair<T, size_t> *NN
502 while ( 2 * s + 1 < n )
504 size_t j = 2 * s + 1;
507 if ( NN[ j ].first < NN[ j + 1 ].first ) j ++;
509 if ( NN[ s ].first < NN[ j ].first )
511 std::swap( NN[ s ], NN[ j ] );
522 std::pair<T, size_t> *Query,
523 std::pair<T, size_t> *NN
526 for (
size_t i = 0; i < n; i ++ )
528 if ( Query[ i ].first > NN[ 0 ].first )
534 NN[ 0 ] = Query[ i ];
535 HeapAdjust<T>( 0, k, NN );
554 for ( i = 0; i < n - 1; i ++ )
556 for ( j = 0; j < n - 1 - i; j ++ )
558 if ( D[ j ] > D[ j + 1 ] )
560 swap<T>( D, j, j + 1 );
561 swap<int>( I, j, j + 1 );
579 _max = std::numeric_limits<double>::min();
580 _min = std::numeric_limits<double>::max();
592 void Update(
double query )
597 _max = std::max( _max, query );
598 _min = std::min( _min, query );
605 printf(
"num %5lu min %.1E max %.1E avg %.1E\n", _num, _min, _max, _avg );
void handleError(hmlpError_t error, const char *file, int line)
Definition: util.cpp:33
void bubble_sort(int n, T *D, int *I)
A bubble sort for reference.
Definition: util.hpp:546
T * hmlp_malloc(int m, int n, int size)
The default function to allocate memory for HMLP. Memory allocated by this function is aligned...
Definition: util.hpp:59
void HeapSelect(size_t n, size_t k, std::pair< T, size_t > *Query, std::pair< T, size_t > *NN)
Definition: util.hpp:520
void hmlp_acquire_mpart(hmlpOperation_t transX, int m, int n, T *src_buff, int lda, int x, int y, int i, int j, T **dst_buff)
Split an m x n matrix into an x-by-y grid of subblocks and get the subblock at the ith block row and jth block column. (for STRASSEN) ...
Definition: util.hpp:143
void hmlp_free(T *ptr)
Free the aligned memory.
Definition: util.hpp:88
void heap_adjust(T *D, int s, int n, int *I)
This function is called after the root of the heap is replaced by a new candidate. We need to readjust the heap so that the heap condition is satisfied again.
Definition: util.hpp:440
void swap(T *x, int i, int j)
A swap function. Just in case we do not have one. (for GSKNN)
Definition: util.hpp:426