1 |
|
/** |
2 |
|
* HMLP (High-Performance Machine Learning Primitives) |
3 |
|
* |
4 |
|
* Copyright (C) 2014-2017, The University of Texas at Austin |
5 |
|
* |
6 |
|
* This program is free software: you can redistribute it and/or modify |
7 |
|
* it under the terms of the GNU General Public License as published by |
8 |
|
* the Free Software Foundation, either version 3 of the License, or |
9 |
|
* (at your option) any later version. |
10 |
|
* |
11 |
|
* This program is distributed in the hope that it will be useful, |
12 |
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of |
13 |
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
14 |
|
* GNU General Public License for more details. |
15 |
|
* |
16 |
|
* You should have received a copy of the GNU General Public License |
17 |
|
* along with this program. If not, see the LICENSE file. |
18 |
|
* |
19 |
|
**/ |
20 |
|
|
21 |
|
|
22 |
|
|
23 |
|
|
24 |
|
#ifndef CNN_HPP |
25 |
|
#define CNN_HPP |
26 |
|
|
27 |
|
#include <cstddef>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>
30 |
|
|
31 |
|
// #define DEBUG_CONV2D 1 |
32 |
|
|
33 |
|
namespace hmlp |
34 |
|
{ |
35 |
|
namespace cnn |
36 |
|
{ |
37 |
|
|
38 |
|
/** |
39 |
|
* |
40 |
|
*/ |
41 |
|
template< |
42 |
|
int KC, int MR, int NR, int PACK_MR, int PACK_NR, |
43 |
|
typename SEMIRINGKERNEL, |
44 |
|
typename TA, typename TB, typename TC, typename TV> |
45 |
|
void rank_k_macro_kernel |
46 |
|
( |
47 |
|
Worker &thread, |
48 |
|
int ic, int jc, int pc, |
49 |
|
int m, int n, int k, |
50 |
|
TA *packA, |
51 |
|
TB *packB, |
52 |
|
TV *C, int ldc, |
53 |
|
SEMIRINGKERNEL semiringkernel |
54 |
|
) |
55 |
|
{ |
56 |
|
thread_communicator &ic_comm = *thread.ic_comm; |
57 |
|
|
58 |
|
auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() ); |
59 |
|
auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() ); |
60 |
|
auto loop2nd = GetRange( 0, m, MR ); |
61 |
|
auto pack2nd = GetRange( 0, m, PACK_MR ); |
62 |
|
|
63 |
|
for ( int j = loop3rd.beg(), jp = pack3rd.beg(); |
64 |
|
j < loop3rd.end(); |
65 |
|
j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop |
66 |
|
{ |
67 |
|
struct aux_s<TA, TB, TC, TV> aux; |
68 |
|
aux.pc = pc; |
69 |
|
aux.b_next = packB; |
70 |
|
aux.do_packC = 0; |
71 |
|
aux.jb = std::min( n - j, NR ); |
72 |
|
|
73 |
|
for ( int i = loop2nd.beg(), ip = pack2nd.beg(); |
74 |
|
i < loop2nd.end(); |
75 |
|
i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop |
76 |
|
{ |
77 |
|
aux.ib = std::min( m - i, MR ); |
78 |
|
if ( aux.ib != MR ) |
79 |
|
{ |
80 |
|
aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k; |
81 |
|
} |
82 |
|
|
83 |
|
if ( aux.jb == NR && aux.ib == MR ) |
84 |
|
{ |
85 |
|
semiringkernel |
86 |
|
( |
87 |
|
k, |
88 |
|
&packA[ ip * k ], |
89 |
|
&packB[ jp * k ], |
90 |
|
&C[ j * ldc + i ], 1, ldc, |
91 |
|
&aux |
92 |
|
); |
93 |
|
} |
94 |
|
else // corner case |
95 |
|
{ |
96 |
|
// TODO: this should be initC. |
97 |
|
TV ctmp[ MR * NR ] = { (TV)0.0 }; |
98 |
|
semiringkernel |
99 |
|
( |
100 |
|
k, |
101 |
|
&packA[ ip * k ], |
102 |
|
&packB[ jp * k ], |
103 |
|
ctmp, 1, MR, |
104 |
|
&aux |
105 |
|
); |
106 |
|
if ( pc ) |
107 |
|
{ |
108 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
109 |
|
{ |
110 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
111 |
|
{ |
112 |
|
C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ]; |
113 |
|
} |
114 |
|
} |
115 |
|
} |
116 |
|
else |
117 |
|
{ |
118 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
119 |
|
{ |
120 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
121 |
|
{ |
122 |
|
C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ]; |
123 |
|
} |
124 |
|
} |
125 |
|
} |
126 |
|
} |
127 |
|
} // end 2nd loop |
128 |
|
} // end 3rd loop |
129 |
|
} // end rank_k_macro_kernel |
130 |
|
|
131 |
|
/** |
132 |
|
* |
133 |
|
*/ |
134 |
|
template< |
135 |
|
int KC, |
136 |
|
int MR, |
137 |
|
int NR, |
138 |
|
int PACK_MR, |
139 |
|
int PACK_NR, |
140 |
|
typename MICROKERNEL, |
141 |
|
typename TA, typename TB, typename TC, typename TV> |
142 |
|
void fused_macro_kernel |
143 |
|
( |
144 |
|
Worker &thread, |
145 |
|
int ic, int jc, int pc, |
146 |
|
int m, int n, int k, |
147 |
|
TA *packA, |
148 |
|
TB *packB, |
149 |
|
TV *C, int ldc, |
150 |
|
MICROKERNEL microkernel |
151 |
|
) |
152 |
|
{ |
153 |
|
thread_communicator &ic_comm = *thread.ic_comm; |
154 |
|
|
155 |
|
auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() ); |
156 |
|
auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() ); |
157 |
|
auto loop2nd = GetRange( 0, m, MR ); |
158 |
|
auto pack2nd = GetRange( 0, m, PACK_MR ); |
159 |
|
|
160 |
|
for ( int j = loop3rd.beg(), jp = pack3rd.beg(); |
161 |
|
j < loop3rd.end(); |
162 |
|
j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop |
163 |
|
{ |
164 |
|
struct aux_s<TA, TB, TC, TV> aux; |
165 |
|
aux.pc = pc; |
166 |
|
aux.b_next = packB; |
167 |
|
aux.do_packC = 0; |
168 |
|
aux.jb = std::min( n - j, NR ); |
169 |
|
|
170 |
|
for ( int i = loop2nd.beg(), ip = pack2nd.beg(); |
171 |
|
i < loop2nd.end(); |
172 |
|
i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop |
173 |
|
{ |
174 |
|
aux.ib = std::min( m - i, MR ); |
175 |
|
if ( aux.ib != MR ) |
176 |
|
{ |
177 |
|
aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k; |
178 |
|
} |
179 |
|
|
180 |
|
if ( aux.jb == NR && aux.ib == MR ) |
181 |
|
{ |
182 |
|
microkernel |
183 |
|
( |
184 |
|
k, |
185 |
|
&packA[ ip * k ], |
186 |
|
&packB[ jp * k ], |
187 |
|
&C[ j * ldc + i ], 1, ldc, |
188 |
|
&aux |
189 |
|
); |
190 |
|
} |
191 |
|
else // corner case |
192 |
|
{ |
193 |
|
TV ctmp[ MR * NR ] = { (TV)0.0 }; |
194 |
|
microkernel |
195 |
|
( |
196 |
|
k, |
197 |
|
&packA[ ip * k ], |
198 |
|
&packB[ jp * k ], |
199 |
|
ctmp, 1, MR, |
200 |
|
&aux |
201 |
|
); |
202 |
|
|
203 |
|
if ( pc ) |
204 |
|
{ |
205 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
206 |
|
{ |
207 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
208 |
|
{ |
209 |
|
C[ ( j + jj ) * ldc + i + ii ] += ctmp[ jj * MR + ii ]; |
210 |
|
} |
211 |
|
} |
212 |
|
} |
213 |
|
else |
214 |
|
{ |
215 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
216 |
|
{ |
217 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
218 |
|
{ |
219 |
|
C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ]; |
220 |
|
} |
221 |
|
} |
222 |
|
} |
223 |
|
} |
224 |
|
} // end 2nd loop |
225 |
|
} // end 3rd loop |
226 |
|
}; // end fused_macro_kernel |
227 |
|
|
228 |
|
|
229 |
|
|
230 |
|
/* |
231 |
|
* |
232 |
|
*/ |
233 |
|
template< |
234 |
|
int MC, int NC, int KC, int MR, int NR, |
235 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
236 |
|
bool USE_STRASSEN, |
237 |
|
typename SEMIRINGKERNEL, typename MICROKERNEL, |
238 |
|
typename TA, typename TB, typename TC, typename TV> |
239 |
|
void conv2d_internal |
240 |
|
( |
241 |
|
Worker &thread, |
242 |
|
int w0, int h0, int d0, int s, int p, |
243 |
|
TB *B, |
244 |
|
int w1, int h1, int d1, |
245 |
|
TA *A, |
246 |
|
TC *C, |
247 |
|
SEMIRINGKERNEL semiringkernel, |
248 |
|
MICROKERNEL microkernel, |
249 |
|
int nc, int pack_nc, |
250 |
|
TA *packA, |
251 |
|
TB *packB |
252 |
|
) |
253 |
|
{ |
254 |
|
packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC |
255 |
|
+ ( thread.ic_id ) * PACK_MC * KC; |
256 |
|
packB += ( thread.jc_id ) * pack_nc * KC; |
257 |
|
|
258 |
|
|
259 |
|
// Now compute parameters such that I can transform the problem into GEMM. |
260 |
|
int m = d1; |
261 |
|
int nx = ( w0 - w1 + 2 * p ) / s + 1; |
262 |
|
int ny = ( h0 - h1 + 2 * p ) / s + 1; |
263 |
|
int n = nx * ny; |
264 |
|
int k = w1 * h1 * d0; |
265 |
|
|
266 |
|
//auto loop6th = GetRange( HMLP_SCHEDULE_HEFT, 0, n, nc, thread.jc_id, thread.jc_nt ); |
267 |
|
auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt ); |
268 |
|
auto loop5th = GetRange( 0, k, KC ); |
269 |
|
auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt ); |
270 |
|
|
271 |
|
//printf( "tid %d beg %d end %d inc %d\n", thread.jc_id, loop6th.beg(), loop6th.end(), loop6th.inc() ); |
272 |
|
|
273 |
|
//double my_beg = omp_get_wtime(); |
274 |
|
/* |
275 |
|
* @CHENHAN: loop over your filters. |
276 |
|
*/ |
277 |
|
for ( int jc = loop6th.beg(); |
278 |
|
jc < loop6th.end(); |
279 |
|
jc += loop6th.inc() ) // beg 6th loop |
280 |
|
{ |
281 |
|
auto &jc_comm = *thread.jc_comm; |
282 |
|
auto jb = std::min( n - jc, nc ); |
283 |
|
|
284 |
|
/* |
285 |
|
* @CHENHAN: loop over your window size ( w1 * h1 * d0 ). |
286 |
|
*/ |
287 |
|
for ( int pc = loop5th.beg(); |
288 |
|
pc < loop5th.end(); |
289 |
|
pc += loop5th.inc() ) |
290 |
|
{ |
291 |
|
auto &pc_comm = *thread.pc_comm; |
292 |
|
auto pb = std::min( k - pc, KC ); |
293 |
|
auto is_the_last_pc_iteration = ( pc + KC >= k ); |
294 |
|
|
295 |
|
/* |
296 |
|
* @CHENHAN: pack image into packB. |
297 |
|
*/ |
298 |
|
auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() ); |
299 |
|
auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() ); |
300 |
|
|
301 |
|
for ( int j = looppkB.beg(), jp = packpkB.beg(); |
302 |
|
j < looppkB.end(); |
303 |
|
j += looppkB.inc(), jp += packpkB.inc() ) |
304 |
|
{ |
305 |
|
auto x0 = ( ( jc + j ) % nx ) * s - p; // top-left |
306 |
|
auto y0 = ( ( jc + j ) / nx ) * s - p; // top-left |
307 |
|
|
308 |
|
#ifdef DEBUG_CONV2D |
309 |
|
printf( "x0 %4d y0 %4d\n", x0, y0 ); |
310 |
|
#endif |
311 |
|
|
312 |
|
pack2Dimg<PACK_NR> // packB |
313 |
|
( |
314 |
|
std::min( jb - j, NR ), pb, |
315 |
|
&packB[ jp * pb ], |
316 |
|
x0, y0, pc, |
317 |
|
B, |
318 |
|
w0, h0, d0, s, p, |
319 |
|
w1, h1 |
320 |
|
); |
321 |
|
} |
322 |
|
pc_comm.Barrier(); |
323 |
|
|
324 |
|
|
325 |
|
#ifdef DEBUG_CONV2D |
326 |
|
for ( int i = 0; i < pb; i ++ ) |
327 |
|
{ |
328 |
|
for ( int jj = 0; jj < jb; jj += NR ) |
329 |
|
{ |
330 |
|
for ( int j = 0; j < NR; j ++ ) |
331 |
|
{ |
332 |
|
printf( "%5.2lf ", packB[ jj * pb + i * NR + j ] ); |
333 |
|
} |
334 |
|
printf( " " ); |
335 |
|
} |
336 |
|
printf( "\n" ); |
337 |
|
} |
338 |
|
printf( "\n" ); |
339 |
|
#endif |
340 |
|
|
341 |
|
|
342 |
|
for ( int ic = loop4th.beg(); |
343 |
|
ic < loop4th.end(); |
344 |
|
ic += loop4th.inc() ) // beg 4th loop |
345 |
|
{ |
346 |
|
auto &ic_comm = *thread.ic_comm; |
347 |
|
auto ib = std::min( m - ic, MC ); |
348 |
|
|
349 |
|
auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt ); |
350 |
|
auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt ); |
351 |
|
|
352 |
|
/* |
353 |
|
* @CHENHAN: assume filters were already packed format. |
354 |
|
*/ |
355 |
|
for ( int i = looppkA.beg(), ip = packpkA.beg(); |
356 |
|
i < looppkA.end(); |
357 |
|
i += looppkA.inc(), ip += packpkA.inc() ) |
358 |
|
{ |
359 |
|
pack2D<true, PACK_MR> // packA (transA) |
360 |
|
( |
361 |
|
std::min( ib - i, MR ), pb, |
362 |
|
&A[ ( ic + i ) * k + pc ], k, &packA[ ip * pb ] |
363 |
|
); |
364 |
|
} |
365 |
|
|
366 |
|
if ( is_the_last_pc_iteration ) // fused_macro_kernel |
367 |
|
{ |
368 |
|
fused_macro_kernel |
369 |
|
<KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV> |
370 |
|
( |
371 |
|
thread, |
372 |
|
ic, jc, pc, |
373 |
|
ib, jb, pb, |
374 |
|
packA, |
375 |
|
packB, |
376 |
|
C + jc * m + ic, m, |
377 |
|
microkernel |
378 |
|
); |
379 |
|
} |
380 |
|
else // semiring rank-k update |
381 |
|
{ |
382 |
|
rank_k_macro_kernel |
383 |
|
<KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV> |
384 |
|
( |
385 |
|
thread, |
386 |
|
ic, jc, pc, |
387 |
|
ib, jb, pb, |
388 |
|
packA, |
389 |
|
packB, |
390 |
|
C + jc * m + ic, m, |
391 |
|
semiringkernel |
392 |
|
); |
393 |
|
} |
394 |
|
ic_comm.Barrier(); // sync all jr_id!! |
395 |
|
} // end 4th loop |
396 |
|
pc_comm.Barrier(); |
397 |
|
} // end 5th loop |
398 |
|
} // end 6th loop |
399 |
|
//double my_time = omp_get_wtime() - my_beg; |
400 |
|
//double my_flop = ( ( loop6th.end() - loop6th.beg() ) / 1e+9 ) * 2 * m * k; |
401 |
|
////printf( "tid %d GFLOPS %5.2lf\n", thread.jc_id, my_flop / my_time ); |
402 |
|
//printf( "tid %d GFLOPS %5.2lf\n", thread.jc_id, my_time ); |
403 |
|
}; // end cnn_internal |
404 |
|
|
405 |
|
|
406 |
|
|
407 |
|
|
408 |
|
|
409 |
|
/** |
410 |
|
* @CHENHAN: |
411 |
|
* |
412 |
|
 * These templates (the same as in gkmx.hpp) define a general matrix-matrix
 * multiplication. You will be using this existing code to write a
 * convolution operation.
 *
 * (First) you should define what parameters you need. For convolution, your
 * input A will be an image (tensor), B the filters (tensor), and C the output,
 * which again should be a tensor. Since tensors need more attributes to
 * describe, you will need to think about what you need instead of m, n, k,
 * lda, ldb, ldc. (Note: the implementation below uses the opposite naming:
 * B is the image and A holds the filters.)
420 |
|
* |
421 |
|
* (Second) you need to restructure the loop to loop over each convolution |
422 |
|
* window. The window size (width*length) is the k dimension of your GEMM. |
423 |
|
* Notice for each loop in the original GEMM operation you may need more than |
424 |
|
* one loop in the convolution expression. |
425 |
|
* |
426 |
|
 * In the implementation below, the jc loop (6th) loops over blocks of NC
 * convolution windows (output pixels), the pc loop (5th) over blocks of KC
 * elements within one window, and the ic loop (4th) over blocks of MC
 * filters.
429 |
|
* |
430 |
|
* You probably don't need to change anything about the macro kernels we |
431 |
|
* define here (3rd, 2nd loops), since in 4th loop you already transformed the |
432 |
|
* problem into a GEMM operation. |
433 |
|
* |
434 |
|
* (Third) finally you need to write two packing routines and one unpacking |
435 |
|
* routine. Think about how to pack your image into packA and how to pack your |
436 |
|
* filters into packB. Finally, you need to reshape your C back to the |
437 |
|
* original tensor shape. |
438 |
|
* |
439 |
|
* (Fourth) write a reference function cnn_ref and a test function |
440 |
|
* /hmlp/test/test_cnn.cpp to compare your results. |
441 |
|
* |
442 |
|
* Good luck and have fun! |
443 |
|
* |
444 |
|
* |
445 |
|
*/ |
446 |
|
template< |
447 |
|
int MC, int NC, int KC, int MR, int NR, |
448 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
449 |
|
bool USE_STRASSEN, |
450 |
|
typename SEMIRINGKERNEL, typename MICROKERNEL, |
451 |
|
typename TA, typename TB, typename TC, typename TV> |
452 |
|
void conv2d |
453 |
|
( |
454 |
|
int w0, int h0, int d0, int s, int p, |
455 |
|
TA *B, |
456 |
|
int w1, int h1, int d1, |
457 |
|
TB *A, |
458 |
|
TC *C, |
459 |
|
SEMIRINGKERNEL semiringkernel, |
460 |
|
MICROKERNEL microkernel |
461 |
|
) |
462 |
|
{ |
463 |
|
int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1; |
464 |
|
int nc = NC, pack_nc = PACK_NC; |
465 |
|
char *str; |
466 |
|
|
467 |
|
int m = d1; |
468 |
|
int nx = ( w0 - w1 + 2 * p ) / s + 1; |
469 |
|
int ny = ( h0 - h1 + 2 * p ) / s + 1; |
470 |
|
int n = nx * ny; |
471 |
|
int k = w1 * h1 * d0; |
472 |
|
|
473 |
|
|
474 |
|
//printf( "m %4d n %4d k %4d\n", m, n, k ); |
475 |
|
|
476 |
|
TA *packA_buff = NULL; |
477 |
|
TB *packB_buff = NULL; |
478 |
|
|
479 |
|
// Early return if possible |
480 |
|
|
481 |
|
// Check the environment variable. |
482 |
|
if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 ) |
483 |
|
{ |
484 |
|
jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" ); |
485 |
|
ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" ); |
486 |
|
jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" ); |
487 |
|
} |
488 |
|
|
489 |
|
|
490 |
|
if ( jc_nt > 1 ) |
491 |
|
{ |
492 |
|
nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR; |
493 |
|
//if ( nc > NC ) nc = NC; |
494 |
|
pack_nc = ( nc / NR ) * PACK_NR; |
495 |
|
} |
496 |
|
|
497 |
|
// allocate packing memory |
498 |
|
packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) ); |
499 |
|
packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt, sizeof(TB) ); |
500 |
|
|
501 |
|
//#pragma omp parallel for |
502 |
|
//for ( int i = 0; i < KC * ( PACK_MC + 1 ) * jc_nt * ic_nt; i ++ ) packA_buff[ i ] = 1.0; |
503 |
|
|
504 |
|
|
505 |
|
// allocate tree communicator |
506 |
|
thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt ); |
507 |
|
|
508 |
|
|
509 |
|
#pragma omp parallel num_threads( my_comm.GetNumThreads() ) |
510 |
|
{ |
511 |
|
Worker thread( &my_comm ); |
512 |
|
|
513 |
|
if ( USE_STRASSEN ) |
514 |
|
{ |
515 |
|
printf( "cnn: strassen algorithms haven't been implemented." ); |
516 |
|
exit( 1 ); |
517 |
|
} |
518 |
|
|
519 |
|
conv2d_internal |
520 |
|
<MC, NC, KC, MR, NR, |
521 |
|
PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
522 |
|
USE_STRASSEN, |
523 |
|
SEMIRINGKERNEL, MICROKERNEL, |
524 |
|
TA, TB, TC, TB> |
525 |
|
( |
526 |
|
thread, |
527 |
|
w0, h0, d0, s, p, |
528 |
|
B, |
529 |
|
w1, h1, d1, |
530 |
|
A, |
531 |
|
C, |
532 |
|
semiringkernel, microkernel, |
533 |
|
nc, pack_nc, |
534 |
|
packA_buff, |
535 |
|
packB_buff |
536 |
|
); |
537 |
|
} // end omp |
538 |
|
|
539 |
|
#ifdef DEBUG_CONV2D |
540 |
|
for ( int j = 0; j < ny; j ++ ) |
541 |
|
{ |
542 |
|
for ( int i = 0; i < nx; i ++ ) |
543 |
|
{ |
544 |
|
printf( "%5.2lf ", C[ j * nx + i ] ); |
545 |
|
} |
546 |
|
printf( "\n" ); |
547 |
|
} |
548 |
|
#endif |
549 |
|
|
550 |
|
}; // end cnn |
551 |
|
|
552 |
|
|
553 |
|
//template< |
554 |
|
// int MC, int NC, int KC, int MR, int NR, |
555 |
|
// int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
556 |
|
// bool USE_STRASSEN, |
557 |
|
// typename SEMIRINGKERNEL, typename MICROKERNEL, |
558 |
|
// typename TA, typename TB, typename TC, typename TV> |
559 |
|
//void conv2d |
560 |
|
//( |
561 |
|
// int w0, int h0, int d0, |
562 |
|
// TA *B, |
563 |
|
// int w1, int h1, int d1, |
564 |
|
// TB *A, |
565 |
|
// TC *C, |
566 |
|
// SEMIRINGKERNEL semiringkernel, |
567 |
|
// MICROKERNEL microkernel |
568 |
|
//) |
569 |
|
//{ |
570 |
|
// // Deciding s and p given the output size is also (w0, h0). |
571 |
|
// // w0 = ( w0 - w1 + 2 * p ) / s + 1 |
572 |
|
// // h0 = ( h0 - h1 + 2 * p ) / s + 1 |
573 |
|
// // if s = 1, then p = ( w1 - 1 ) / 2 |
574 |
|
// // p = ( h1 - 1 ) / 2 |
575 |
|
// // that is w1 and h1 must be odd. |
576 |
|
// |
577 |
|
// assert( w1 == h1 ); |
578 |
|
// |
579 |
|
// conv2d |
580 |
|
// <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
581 |
|
// USE_STRASSEN, |
582 |
|
// SEMIRINGKERNEL, MICROKERNEL, |
583 |
|
// TA, TB, TC, TV> |
584 |
|
// ( |
585 |
|
// w0, h0, d0, 1, ( w1 - 1 ) / 2, |
586 |
|
// B, |
587 |
|
// w1, h1, d1, |
588 |
|
// A, |
589 |
|
// C, |
590 |
|
// semiringkernel, |
591 |
|
// microkernel |
592 |
|
// ); |
593 |
|
//}; |
594 |
|
|
595 |
|
/**
 * @brief Batched 2-D convolution.
 *
 * Applies the single-image conv2d to each of batchSize images, in parallel
 * over the batch, using the same filters A for every image.
 * - Image b starts at B + b * ( w0 * h0 * d0 ).
 * - Its output starts at C + b * ( nx * ny * d1 ), where
 *   nx = ( w0 - w1 + 2p ) / s + 1 and ny = ( h0 - h1 + 2p ) / s + 1.
 *
 * @param w0,h0,d0   input image width, height and depth.
 * @param s,p        stride and padding.
 * @param batchSize  number of images in B (and outputs in C).
 * @param B          batchSize images stored back to back.
 * @param w1,h1,d1   filter width, height and number of filters.
 * @param A          filters, shared by all images.
 * @param C          batchSize outputs stored back to back.
 */
template<
  int MC, int NC, int KC, int MR, int NR,
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_STRASSEN,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void conv2d
(
  int w0, int h0, int d0, int s, int p, int batchSize,
  TB *B,
  int w1, int h1, int d1,
  TA *A,
  TC *C,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel
)
{
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
  int ny = ( h0 - h1 + 2 * p ) / s + 1;

  // NOTE(review): only square filters are accepted here, although the
  // single-image overload computes nx and ny from w1 and h1 independently.
  // Kept for compatibility.
  assert( w1 == h1 );

  // Element strides between consecutive images / outputs. They are computed
  // in std::ptrdiff_t so that b * stride does not overflow int for large
  // batches.
  const std::ptrdiff_t image_stride  = static_cast<std::ptrdiff_t>( w0 ) * h0 * d0;
  const std::ptrdiff_t output_stride = static_cast<std::ptrdiff_t>( nx ) * ny * d1;

  #pragma omp parallel for
  for ( int b = 0; b < batchSize; b ++ )
  {
    conv2d
    <MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE,
    USE_STRASSEN,
    SEMIRINGKERNEL, MICROKERNEL,
    TA, TB, TC, TV>
    (
      w0, h0, d0, s, p,
      B + b * image_stride,
      w1, h1, d1,
      A,
      C + b * output_stride,
      semiringkernel,
      microkernel
    );
  }
} // end conv2d (batched)
644 |
|
|
645 |
|
|
646 |
|
/** |
647 |
|
* @CHENHAN: write a reference function using GEMM. The signiture of xgemm can |
648 |
|
* be found in hmlp_blas_lapack.h. |
649 |
|
*/ |
650 |
|
template<typename T> |
651 |
|
void conv2d_ref |
652 |
|
( |
653 |
|
int w0, int h0, int d0, int s, int p, |
654 |
|
T *B, |
655 |
|
int w1, int h1, int d1, |
656 |
|
T *A, |
657 |
|
T *C |
658 |
|
) |
659 |
|
{ |
660 |
|
int m = d1; |
661 |
|
int nx = ( w0 - w1 + 2 * p ) / s + 1; |
662 |
|
int ny = ( h0 - h1 + 2 * p ) / s + 1; |
663 |
|
int n = nx * ny; |
664 |
|
int k = w1 * h1 * d0; |
665 |
|
|
666 |
|
T *packA = A; |
667 |
|
T *packB = hmlp_malloc<16, T>( k, n, sizeof(T) ); |
668 |
|
|
669 |
|
double beg = omp_get_wtime(); |
670 |
|
im2col<T> |
671 |
|
( |
672 |
|
n, k, |
673 |
|
packB, B, |
674 |
|
w0, h0, d0, s, p, |
675 |
|
w1, h1 |
676 |
|
); |
677 |
|
double im2col_t = omp_get_wtime() - beg; |
678 |
|
printf( "im2col( B ) %3.1Es\n", im2col_t ); fflush( stdout ); |
679 |
|
|
680 |
|
#ifdef DEBUG_CONV2D |
681 |
|
printf( "packB\n" ); |
682 |
|
for ( int p = 0; p < k; p ++ ) |
683 |
|
{ |
684 |
|
for ( int j = 0; j < n; j ++ ) |
685 |
|
{ |
686 |
|
printf( "%5.2lf ", packB[ j * k + p ] ); |
687 |
|
} |
688 |
|
printf( "\n" ); |
689 |
|
} |
690 |
|
#endif |
691 |
|
|
692 |
|
|
693 |
|
#ifdef USE_BLAS |
694 |
|
xgemm |
695 |
|
( |
696 |
|
"T", "N", |
697 |
|
m, n, k, |
698 |
|
1.0, packA, k, |
699 |
|
packB, k, |
700 |
|
0.0, C, m |
701 |
|
); |
702 |
|
#else |
703 |
|
#pragma omp parallel for |
704 |
|
for ( int j = 0; j < n; j ++ ) |
705 |
|
{ |
706 |
|
for ( int i = 0; i < m; i ++ ) |
707 |
|
{ |
708 |
|
C[ j * m + i ] = 0.0; |
709 |
|
for ( int p = 0; p < k; p ++ ) |
710 |
|
{ |
711 |
|
C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ]; |
712 |
|
} |
713 |
|
} |
714 |
|
} |
715 |
|
#endif |
716 |
|
}; // end void conv2d_ref |
717 |
|
|
718 |
|
/**
 * @brief Batched reference convolution.
 *
 * Runs the single-image conv2d_ref on each of batchSize images, in parallel
 * over the batch, using the same filters A for every image.
 * - Image b starts at B + b * ( w0 * h0 * d0 ).
 * - Its output starts at C + b * ( nx * ny * d1 ).
 *
 * @param w0,h0,d0   input image width, height and depth.
 * @param s,p        stride and padding.
 * @param batchSize  number of images in B (and outputs in C).
 * @param B          batchSize images stored back to back.
 * @param w1,h1,d1   filter width, height and number of filters.
 * @param A          filters, shared by all images.
 * @param C          batchSize outputs stored back to back.
 */
template<typename T>
void conv2d_ref
(
  int w0, int h0, int d0, int s, int p, int batchSize,
  T *B,
  int w1, int h1, int d1,
  T *A,
  T *C
)
{
  int nx = ( w0 - w1 + 2 * p ) / s + 1;
  int ny = ( h0 - h1 + 2 * p ) / s + 1;

  // Element strides between consecutive images / outputs. They are computed
  // in std::ptrdiff_t so that b * stride does not overflow int for large
  // batches.
  const std::ptrdiff_t image_stride  = static_cast<std::ptrdiff_t>( w0 ) * h0 * d0;
  const std::ptrdiff_t output_stride = static_cast<std::ptrdiff_t>( nx ) * ny * d1;

  #pragma omp parallel for
  for ( int b = 0; b < batchSize; b ++ )
  {
    conv2d_ref<T>
    (
      w0, h0, d0, s, p,
      B + b * image_stride,
      w1, h1, d1,
      A,
      C + b * output_stride
    );
  }
} // end void conv2d_ref (batched)
744 |
|
|
745 |
|
}; // end namespace conv2d |
746 |
|
}; // end namespace hmlp |
747 |
|
|
748 |
|
#endif // define GKMX_HPP |