1 |
|
/** |
2 |
|
* HMLP (High-Performance Machine Learning Primitives) |
3 |
|
* |
4 |
|
* Copyright (C) 2014-2017, The University of Texas at Austin |
5 |
|
* |
6 |
|
* This program is free software: you can redistribute it and/or modify |
7 |
|
* it under the terms of the GNU General Public License as published by |
8 |
|
* the Free Software Foundation, either version 3 of the License, or |
9 |
|
* (at your option) any later version. |
10 |
|
* |
11 |
|
* This program is distributed in the hope that it will be useful, |
12 |
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of |
13 |
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
14 |
|
* GNU General Public License for more details. |
15 |
|
* |
16 |
|
* You should have received a copy of the GNU General Public License |
17 |
|
* along with this program. If not, see the LICENSE file. |
18 |
|
* |
19 |
|
**/ |
20 |
|
|
21 |
|
|
22 |
|
#ifndef HMLP_PACKING_HPP |
23 |
|
#define HMLP_PACKING_HPP |
24 |
|
|
25 |
|
#include <stdio.h> |
26 |
|
|
27 |
|
//#define DEBUG_PACKING 1 |
28 |
|
|
29 |
|
namespace hmlp |
30 |
|
{ |
31 |
|
|
32 |
|
|
33 |
|
|
34 |
|
/** |
35 |
|
* @brief This is the im2col_gpu() function from |
36 |
|
* |
37 |
|
* BVLC/caffe/blob/master/src/caffe/util/im2col.cpp. |
38 |
|
* |
39 |
|
* We slightly modify it. |
40 |
|
* |
41 |
|
*/ |
42 |
|
//template<typename T> |
43 |
|
//void im2col |
44 |
|
//( |
45 |
|
// const T* data_im, size_t channels, |
46 |
|
// size_t height, size_t width, |
47 |
|
// size_t kernel_h, size_t kernel_w, |
48 |
|
// size_t pad_h, size_t pad_w, |
49 |
|
// size_t stride_h, size_t stride_w, |
50 |
|
// size_t dilation_h, size_t dilation_w, |
51 |
|
// T* data_col |
52 |
|
//) |
53 |
|
//{ |
54 |
|
// size_t output_h = ( height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; |
55 |
|
// size_t output_w = ( width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; |
56 |
|
// size_t channel_size = height * width; |
57 |
|
// |
58 |
|
// /** loop over channel, data_im += channel_size */ |
59 |
|
// for ( int channel = channels; channel --; data_im += channel_size ) |
60 |
|
// { |
61 |
|
// for ( int kernel_row = 0; kernel_row < kernel_h; kernel_row ++ ) |
62 |
|
// { |
63 |
|
// for ( int kernel_col = 0; kernel_col < kernel_w; kernel_col ++) |
64 |
|
// { |
65 |
|
// int input_row = -pad_h + kernel_row * dilation_h; |
66 |
|
// for ( int output_rows = output_h; output_rows; output_rows--) |
67 |
|
// { |
68 |
|
// /** zero-padding */ |
69 |
|
// if ( !is_a_ge_zero_and_a_lt_b( input_row, height ) ) |
70 |
|
// { |
71 |
|
// for ( int output_cols = output_w; output_cols; output_cols -- ) |
72 |
|
// { |
73 |
|
// *(data_col++) = 0; |
74 |
|
// } |
75 |
|
// } |
76 |
|
// else |
77 |
|
// { |
78 |
|
// int input_col = -pad_w + kernel_col * dilation_w; |
79 |
|
// for ( int output_col = output_w; output_col; output_col-- ) |
80 |
|
// { |
81 |
|
// if ( is_a_ge_zero_and_a_lt_b( input_col, width ) ) |
82 |
|
// { |
83 |
|
// *(data_col++) = data_im[ input_row * width + input_col ]; |
84 |
|
// } |
85 |
|
// else |
86 |
|
// { |
87 |
|
// *(data_col++) = 0; |
88 |
|
// } |
89 |
|
// input_col += stride_w; |
90 |
|
// } |
91 |
|
// } |
92 |
|
// input_row += stride_h; |
93 |
|
// } |
94 |
|
// } |
95 |
|
// } |
96 |
|
// } |
97 |
|
// |
98 |
|
//}; /** end im2col() */ |
99 |
|
|
100 |
|
|
101 |
|
|
102 |
|
|
103 |
|
|
104 |
|
/**
 *  @brief Unfold every sliding window of an image into one row of packX,
 *         storing zeros for window entries that fall outside the image.
 *
 *  Window origins start at -p and advance by s along both axes. The
 *  window whose origin is ( x0, y0 ) is written to row
 *  ( ( y0 + p ) / s ) * nx + ( x0 + p ) / s of packX, where
 *  nx = ( w0 - w1 + 2 * p ) / s + 1. Inside a row the n entries walk the
 *  depth index fastest, then the window column (wrapping at w1), then the
 *  window row.
 *
 *  @param m      not referenced by this routine.
 *  @param n      number of entries written per window (row length of packX).
 *  @param packX  output buffer; row i starts at packX[ i * n ].
 *  @param X      input image, read as X[ y * w0 * d0 + x * d0 + z ].
 *  @param w0     image width.
 *  @param h0     image height.
 *  @param d0     image depth.
 *  @param s      stride between consecutive window origins.
 *  @param p      padding; origins start at -p.
 *  @param w1     window width.
 *  @param h1     window height (only used to bound the window origins).
 */
template<typename T>
inline void im2col
(
  int m, int n,
  T* packX,
  T* X,
  int w0, int h0, int d0, int s, int p,
  int w1, int h1
)
{
  /** Number of window origins along the x direction. */
  const int nx = ( w0 - w1 + 2 * p ) / s + 1;

  #pragma omp parallel for
  for ( int y0 = -p; y0 <= h0 - h1 + p; y0 += s )
  {
    for ( int x0 = -p; x0 <= w0 - w1 + p; x0 += s )
    {
      /** Row of packX that receives this window. */
      const int patch = ( ( y0 + p ) / s ) * nx + ( x0 + p ) / s;

      /** Position inside the window: depth, column, row. */
      int ch = 0;
      int dx = 0;
      int dy = 0;

      for ( int j = 0; j < n; ++ j )
      {
        const int col = x0 + dx;
        const int row = y0 + dy;

        if ( col < 0 || col >= w0 || row < 0 || row >= h0 )
        {
          /** zero-padding: this entry lies outside the image */
          packX[ patch * n + j ] = 0.0;
        }
        else
        {
          packX[ patch * n + j ] = X[ row * w0 * d0 + col * d0 + ch ];
        }

        /** Advance depth first, then column, then row. */
        ++ ch;
        if ( ch >= d0 )
        {
          ch = 0;
          ++ dx;
        }
        if ( dx >= w1 )
        {
          dx = 0;
          ++ dy;
        }
      }
    }
  }
} // end im2col()
153 |
|
|
154 |
|
|
155 |
|
|
156 |
|
|
157 |
|
/**
 *  @brief Pack m consecutive sliding windows of an image into a packed
 *         buffer with leading dimension FOLD. Notice that here X is d
 *         leading: it is read as X[ y * w0 * d0 + x * d0 + z ].
 *
 *  Entry j of window i is stored at packX[ j * FOLD + i ]. Window entries
 *  outside the image are written as zero. After each window the origin
 *  moves right by s; once x0 + w1 exceeds w0 + p it wraps to x0 = -p and
 *  moves down by s.
 *
 *  Neither ZEROPAD nor h1 is consulted by this routine, and slots
 *  i in [ m, FOLD ) of packX are left untouched.
 *
 *  @param m       number of windows to pack.
 *  @param n       number of entries packed per window.
 *  @param packX   output buffer.
 *  @param x0      x origin of the first window.
 *  @param y0      y origin of the first window.
 *  @param offset  linear starting position inside each window, decoded as
 *                 z = offset % d0, x = ( offset / d0 ) % w1,
 *                 y = ( offset / d0 ) / w1.
 *  @param X       input image.
 *  @param w0      image width.
 *  @param h0      image height.
 *  @param d0      image depth.
 *  @param s       stride between window origins.
 *  @param p       padding.
 *  @param w1      window width.
 *  @param h1      window height (unused here).
 */
template<int FOLD, bool ZEROPAD=true, typename T>
inline void pack2Dimg
(
  int m, int n,
  T* packX,
  int x0, int y0, int offset,
  T *X,
  int w0, int h0, int d0, int s, int p,
  int w1, int h1
)
{
  for ( int win = 0; win < m; ++ win )
  {
    /** Decode the starting ( depth, column, row ) inside the window. */
    int ch = offset % d0;
    int dx = ( offset / d0 ) % w1;
    int dy = ( offset / d0 ) / w1;

    for ( int j = 0; j < n; ++ j )
    {
      const int col = x0 + dx;
      const int row = y0 + dy;

      if ( col < 0 || col >= w0 || row < 0 || row >= h0 )
      {
        /** zero-padding: this entry lies outside the image */
        packX[ j * FOLD + win ] = 0.0;
      }
      else
      {
        packX[ j * FOLD + win ] = X[ row * w0 * d0 + col * d0 + ch ];
      }

      /** Advance depth first, then column, then row. */
      ++ ch;
      if ( ch >= d0 )
      {
        ch = 0;
        ++ dx;
      }
      if ( dx >= w1 )
      {
        dx = 0;
        ++ dy;
      }
    }

    /** Move to the next window, wrapping to the next row of windows. */
    x0 += s;
    if ( ( x0 + w1 ) > ( w0 + p ) )
    {
      x0 = -p;
      y0 += s;
    }
  }
} // end pack2Dimg()
215 |
|
|
216 |
|
|
217 |
|
|
218 |
|
|
219 |
|
/**
 *  @brief This is the default packing routine for GKMX, GSKS,
 *         GSKNN and STRASSEN. It packs X0 + gamma * X1 into packX.
 *
 *  The output is written as n consecutive groups of FOLD values. In
 *  group j, slot i < m holds X0[ e ] + gamma * X1[ e ] with
 *    e = ldx * xmap[ i ] + j   when TRANS is true,
 *    e = xmap[ i ] + j * ldx   when TRANS is false.
 *  Slots i in [ m, FOLD ) are zero when ZEROPAD is true; otherwise they
 *  repeat the entry addressed through xmap[ 0 ].
 *
 *  @param m      number of valid slots per group (expected m <= FOLD).
 *  @param n      number of groups to pack.
 *  @param X0     first source.
 *  @param X1     second source, scaled by gamma.
 *  @param ldx    leading dimension of X0 and X1.
 *  @param gamma  scalar applied to X1.
 *  @param xmap   index map; xmap[ 0 ] is also read when m < FOLD.
 *  @param packX  output buffer receiving n * FOLD values.
 */
template<bool TRANS, int FOLD, bool ZEROPAD=false, typename T>
inline void pack2D
(
  int m, int n,
  T *X0, T *X1, int ldx, T gamma, int *xmap, T *packX
)
{
  T *src0[ FOLD ];
  T *src1[ FOLD ];

  /**
   *  TRANS:  xmap selects a stride-ldx line, walked with unit stride.
   *  !TRANS: xmap selects a unit offset, walked with stride ldx.
   */
  const int scale = TRANS ? ldx : 1;
  const int step  = TRANS ? 1 : ldx;

  for ( int i = 0; i < m; ++ i )
  {
    src0[ i ] = X0 + scale * xmap[ i ];
    src1[ i ] = X1 + scale * xmap[ i ];
  }
  /** Padding slots alias the line selected by xmap[ 0 ]. */
  for ( int i = m; i < FOLD; ++ i )
  {
    src0[ i ] = X0 + scale * xmap[ 0 ];
    src1[ i ] = X1 + scale * xmap[ 0 ];
  }

  for ( int j = 0; j < n; ++ j )
  {
    for ( int i = 0; i < m; ++ i )
    {
      *packX = ( *src0[ i ] ) + gamma * ( *src1[ i ] );
      ++ packX;
      src0[ i ] += step;
      src1[ i ] += step;
    }
    for ( int i = m; i < FOLD; ++ i )
    {
      if ( ZEROPAD )
      {
        *packX = (T)0.0;
        ++ packX;
      }
      else
      {
        *packX = ( *src0[ i ] ) + gamma * ( *src1[ i ] );
        ++ packX;
        src0[ i ] += step;
        src1[ i ] += step;
      }
    }
  }
} // end pack2D()
312 |
|
|
313 |
|
|
314 |
|
/**
 *  @brief Pack X0 + gamma * X1 without an explicit index map: builds the
 *         identity map 0, 1, ..., FOLD - 1 and forwards to the pack2D()
 *         overload that takes xmap.
 *
 *  @param m      number of valid slots per group.
 *  @param n      number of groups to pack.
 *  @param X0     first source.
 *  @param X1     second source, scaled by gamma.
 *  @param ldx    leading dimension of X0 and X1.
 *  @param gamma  scalar applied to X1.
 *  @param packX  output buffer.
 */
template<bool TRANS, int FOLD, bool ZEROPAD=false, typename T>
inline void pack2D
(
  int m, int n,
  T *X0, T *X1, int ldx, T gamma, T *packX
)
{
  /** Identity index map: slot i reads line i. */
  int identity[ FOLD ];
  int slot = 0;
  while ( slot < FOLD )
  {
    identity[ slot ] = slot;
    ++ slot;
  }

  pack2D<TRANS, FOLD, ZEROPAD, T>( m, n, X0, X1, ldx, gamma, identity, packX );
} // end pack2D()
332 |
|
|
333 |
|
|
334 |
|
|
335 |
|
/**
 *  @brief Pack a single source X into packX through an index map.
 *
 *  The output is written as n consecutive groups of FOLD values. In
 *  group j, slot i < m holds X[ e ] with
 *    e = ldx * xmap[ i ] + j   when TRANS is true,
 *    e = xmap[ i ] + j * ldx   when TRANS is false.
 *  Slots i in [ m, FOLD ) are zero when ZEROPAD is true; otherwise they
 *  repeat the entry addressed through xmap[ 0 ].
 *
 *  @param m      number of valid slots per group (expected m <= FOLD).
 *  @param n      number of groups to pack.
 *  @param X      source.
 *  @param ldx    leading dimension of X.
 *  @param xmap   index map; xmap[ 0 ] is also read when m < FOLD.
 *  @param packX  output buffer receiving n * FOLD values.
 */
template<bool TRANS, int FOLD, bool ZEROPAD=false, typename T>
inline void pack2D
(
  int m, int n,
  T *X, int ldx, int *xmap, T *packX
)
{
  T *src[ FOLD ];

  /**
   *  TRANS:  xmap selects a stride-ldx line, walked with unit stride.
   *  !TRANS: xmap selects a unit offset, walked with stride ldx.
   */
  const int scale = TRANS ? ldx : 1;
  const int step  = TRANS ? 1 : ldx;

  for ( int i = 0; i < m; ++ i )
  {
    src[ i ] = X + scale * xmap[ i ];
  }
  /** Padding slots alias the line selected by xmap[ 0 ]. */
  for ( int i = m; i < FOLD; ++ i )
  {
    src[ i ] = X + scale * xmap[ 0 ];
  }

  for ( int j = 0; j < n; ++ j )
  {
    for ( int i = 0; i < m; ++ i )
    {
      *packX = *src[ i ];
      ++ packX;
      src[ i ] += step;
    }
    for ( int i = m; i < FOLD; ++ i )
    {
      if ( ZEROPAD )
      {
        *packX = (T)0.0;
        ++ packX;
      }
      else
      {
        *packX = *src[ i ];
        ++ packX;
        src[ i ] += step;
      }
    }
  }
}
402 |
|
|
403 |
|
/**
 *  @brief Pack a single source X without an explicit index map: builds
 *         the identity map 0, 1, ..., FOLD - 1 and forwards to the
 *         pack2D() overload that takes xmap.
 *
 *  @param m      number of valid slots per group.
 *  @param n      number of groups to pack.
 *  @param X      source.
 *  @param ldx    leading dimension of X.
 *  @param packX  output buffer.
 */
template<bool TRANS, int FOLD, bool ZEROPAD=false, typename T>
inline void pack2D
(
  int m, int n,
  T *X, int ldx, T *packX
)
{
  /** Identity index map: slot i reads line i. */
  int identity[ FOLD ];
  int slot = 0;
  while ( slot < FOLD )
  {
    identity[ slot ] = slot;
    ++ slot;
  }

  pack2D<TRANS, FOLD, ZEROPAD, T>( m, n, X, ldx, identity, packX );
}
421 |
|
|
422 |
|
|
423 |
|
|
424 |
|
|
425 |
|
/**
 *  @brief Pack k entries from each of PACK_MR lines of A into packA.
 *
 *  Line i < m starts at A + lda * amap[ i ]; the padding lines
 *  i in [ m, PACK_MR ) all start at A + lda * amap[ 0 ]. The output is k
 *  consecutive groups of PACK_MR values, where group q holds entry q of
 *  every line.
 *
 *  @param m      number of valid lines (expected m <= PACK_MR).
 *  @param k      number of entries packed per line.
 *  @param A      source.
 *  @param lda    leading dimension of A.
 *  @param amap   index map; amap[ 0 ] is also read when m < PACK_MR.
 *  @param packA  output buffer receiving k * PACK_MR values.
 */
template<int PACK_MR, typename TA>
inline void packA_kcxmc(
    int m, int k,
    TA *A, int lda, int *amap, TA *packA )
{
  TA *line[ PACK_MR ];

  int r = 0;
  for ( ; r < m; ++ r )
  {
    line[ r ] = A + lda * amap[ r ];
  }
  /** Padding lines replicate the line selected by amap[ 0 ]. */
  for ( r = m; r < PACK_MR; ++ r )
  {
    line[ r ] = A + lda * amap[ 0 ];
  }

  for ( int q = 0; q < k; ++ q )
  {
    for ( r = 0; r < PACK_MR; ++ r )
    {
      *packA = *line[ r ];
      ++ packA;
      ++ line[ r ];
    }
  }
}
445 |
|
|
446 |
|
/**
 *  @brief Pack k entries from each of PACK_NR lines of B into packB.
 *
 *  Line j < n starts at B + ldb * bmap[ j ]; the padding lines
 *  j in [ n, PACK_NR ) all start at B + ldb * bmap[ 0 ]. The output is k
 *  consecutive groups of PACK_NR values, where group q holds entry q of
 *  every line.
 *
 *  @param n      number of valid lines (expected n <= PACK_NR).
 *  @param k      number of entries packed per line.
 *  @param B      source.
 *  @param ldb    leading dimension of B.
 *  @param bmap   index map; bmap[ 0 ] is also read when n < PACK_NR.
 *  @param packB  output buffer receiving k * PACK_NR values.
 */
template<int PACK_NR, typename TB>
inline void packB_kcxnc(
    int n, int k,
    TB *B, int ldb, int *bmap, TB *packB )
{
  TB *line[ PACK_NR ];

  int c = 0;
  while ( c < n )
  {
    line[ c ] = B + ldb * bmap[ c ];
    ++ c;
  }
  /** Padding lines replicate the line selected by bmap[ 0 ]. */
  for ( c = n; c < PACK_NR; ++ c )
  {
    line[ c ] = B + ldb * bmap[ 0 ];
  }

  for ( int q = 0; q < k; ++ q )
  {
    for ( c = 0; c < PACK_NR; ++ c )
    {
      *packB = *line[ c ];
      ++ packB;
      ++ line[ c ];
    }
  }
}
467 |
|
|
468 |
|
/**
 *  @brief Pack rhs entries from each of n lines of w into packw,
 *         zero-filling the padding slots.
 *
 *  Line j starts at w + ldw * wmap[ j ]. The output is rhs consecutive
 *  groups of PACK_NR values: slots [ 0, n ) of group q hold entry q of
 *  each line, and slots [ n, PACK_NR ) are set to zero.
 *
 *  @param n      number of valid lines (expected n <= PACK_NR).
 *  @param rhs    number of entries packed per line.
 *  @param w      source.
 *  @param ldw    leading dimension of w.
 *  @param wmap   index map selecting the lines of w.
 *  @param packw  output buffer receiving rhs * PACK_NR values.
 */
template<int PACK_NR, typename TC>
inline void packw_rhsxnc(
    int n, int rhs,
    TC *w, int ldw, int *wmap, TC *packw )
{
  TC *line[ PACK_NR ];

  for ( int c = 0; c < n; ++ c )
  {
    line[ c ] = w + ldw * wmap[ c ];
  }

  for ( int q = 0; q < rhs; ++ q )
  {
    int c = 0;
    while ( c < n )
    {
      *packw = *line[ c ];
      ++ packw;
      ++ line[ c ];
      ++ c;
    }
    /** Zero the padding slots of this group. */
    for ( c = n; c < PACK_NR; ++ c )
    {
      *packw = 0.0;
      ++ packw;
    }
  }
}
493 |
|
|
494 |
|
/**
 *  @brief Pack rhs entries from each of m lines of u into packu,
 *         zero-filling the padding slots.
 *
 *  Line i starts at u + ldu * umap[ i ]. The output is rhs consecutive
 *  groups of PACK_MR values: slots [ 0, m ) of group q hold entry q of
 *  each line, and slots [ m, PACK_MR ) are set to zero.
 *
 *  The padding slots used to be skipped without being written, which left
 *  them with whatever the buffer held before. They are now zeroed, the
 *  same way packw_rhsxnc() treats its padding, so the packed buffer is
 *  fully defined after the call.
 *
 *  @param m      number of valid lines (expected m <= PACK_MR).
 *  @param rhs    number of entries packed per line.
 *  @param u      source.
 *  @param ldu    leading dimension of u.
 *  @param umap   index map selecting the lines of u.
 *  @param packu  output buffer receiving rhs * PACK_MR values.
 */
template<int PACK_MR, typename TC>
inline void packu_rhsxmc(
    int m, int rhs,
    TC *u, int ldu, int *umap, TC *packu )
{
  TC *u_pntr[ PACK_MR ];

  for ( int i = 0; i < m; i ++ ) u_pntr[ i ] = u + ldu * umap[ i ];

  for ( int p = 0; p < rhs; p ++ )
  {
    for ( int i = 0; i < m; i ++ )
    {
      *packu ++ = *u_pntr[ i ] ++;
    }
    /** Zero the padding slots instead of leaving them uninitialized. */
    for ( int i = m; i < PACK_MR; i ++ )
    {
      *packu ++ = 0.0;
    }
  }
}
518 |
|
|
519 |
|
|
520 |
|
|
521 |
|
}; // end namespace hmlp |
522 |
|
|
523 |
|
#endif // define HMLP_PACKING_HPP |