1 |
|
/** |
2 |
|
* HMLP (High-Performance Machine Learning Primitives) |
3 |
|
* |
4 |
|
* Copyright (C) 2014-2017, The University of Texas at Austin |
5 |
|
* |
6 |
|
* This program is free software: you can redistribute it and/or modify |
7 |
|
* it under the terms of the GNU General Public License as published by |
8 |
|
* the Free Software Foundation, either version 3 of the License, or |
9 |
|
* (at your option) any later version. |
10 |
|
* |
11 |
|
* This program is distributed in the hope that it will be useful, |
12 |
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of |
13 |
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
14 |
|
* GNU General Public License for more details. |
15 |
|
* |
16 |
|
* You should have received a copy of the GNU General Public License |
17 |
|
* along with this program. If not, see the LICENSE file. |
18 |
|
* |
19 |
|
**/ |
20 |
|
|
21 |
|
|
22 |
|
|
23 |
|
#ifndef GKMX_HPP |
24 |
|
#define GKMX_HPP |
25 |
|
|
26 |
|
#include <assert.h>
#include <typeinfo>
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <type_traits>

#include <hmlp.h>
#include <hmlp_internal.hpp>
#include <hmlp_base.hpp>

/** for USE_STRASSEN */
#include <primitives/strassen.hpp>

/** reference microkernels */
#include <semiring_mrxnr.hpp>
#include <fused_mrxnr.hpp>
40 |
|
|
41 |
|
//#define GKMX_CONFIG \ |
42 |
|
|
43 |
|
|
44 |
|
namespace hmlp |
45 |
|
{ |
46 |
|
namespace gkmx |
47 |
|
{ |
48 |
|
|
49 |
|
/** |
50 |
|
* @brief Macro kernel contains the 3rd and 2nd loops. Depending on the |
51 |
|
* configuration of the communicator, the 3rd loop may be parallelized. |
52 |
|
* b_next is the prefetch pointer. |
53 |
|
*/ |
54 |
|
template< |
55 |
|
int KC, int MR, int NR, int PACK_MR, int PACK_NR, |
56 |
|
typename SEMIRINGKERNEL, |
57 |
|
typename TA, typename TB, typename TC, typename TV> |
58 |
|
void rank_k_macro_kernel |
59 |
|
( |
60 |
|
Worker &thread, |
61 |
|
int ic, int jc, int pc, |
62 |
|
int m, int n, int k, |
63 |
|
TA *packA, |
64 |
|
TB *packB, |
65 |
|
TV *V, int ldv, |
66 |
|
SEMIRINGKERNEL semiringkernel |
67 |
|
) |
68 |
|
{ |
69 |
|
thread_communicator &ic_comm = *thread.ic_comm; |
70 |
|
|
71 |
|
auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() ); |
72 |
|
auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() ); |
73 |
|
auto loop2nd = GetRange( 0, m, MR ); |
74 |
|
auto pack2nd = GetRange( 0, m, PACK_MR ); |
75 |
|
|
76 |
|
for ( int j = loop3rd.beg(), jp = pack3rd.beg(); |
77 |
|
j < loop3rd.end(); |
78 |
|
j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop |
79 |
|
{ |
80 |
|
struct aux_s<TA, TB, TC, TV> aux; |
81 |
|
aux.pc = pc; |
82 |
|
aux.b_next = packB; |
83 |
|
aux.do_packC = 0; |
84 |
|
aux.jb = std::min( n - j, NR ); |
85 |
|
|
86 |
|
for ( int i = loop2nd.beg(), ip = pack2nd.beg(); |
87 |
|
i < loop2nd.end(); |
88 |
|
i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop |
89 |
|
{ |
90 |
|
aux.ib = std::min( m - i, MR ); |
91 |
|
if ( i + MR >= m ) |
92 |
|
{ |
93 |
|
aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k; |
94 |
|
} |
95 |
|
|
96 |
|
if ( aux.jb == NR && aux.ib == MR ) |
97 |
|
{ |
98 |
|
semiringkernel |
99 |
|
( |
100 |
|
k, |
101 |
|
&packA[ ip * k ], |
102 |
|
&packB[ jp * k ], |
103 |
|
&V[ j * ldv + i ], 1, ldv, |
104 |
|
&aux |
105 |
|
); |
106 |
|
} |
107 |
|
else // corner case |
108 |
|
{ |
109 |
|
TV vtmp[ MR * NR ]; |
110 |
|
|
111 |
|
if ( pc ) // initilize ctmp |
112 |
|
{ |
113 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
114 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
115 |
|
vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ]; |
116 |
|
} |
117 |
|
|
118 |
|
semiringkernel |
119 |
|
( |
120 |
|
k, |
121 |
|
&packA[ ip * k ], |
122 |
|
&packB[ jp * k ], |
123 |
|
vtmp, 1, MR, |
124 |
|
&aux |
125 |
|
); |
126 |
|
|
127 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
128 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
129 |
|
V[ ( j + jj ) * ldv + i + ii ] = vtmp[ jj * MR + ii ]; |
130 |
|
} |
131 |
|
} // end 2nd loop |
132 |
|
} // end 3rd loop |
133 |
|
} // end rank_k_macro_kernel |
134 |
|
|
135 |
|
|
136 |
|
/**
 * @brief fused_macro_kernel contains the 3rd and 2nd loops for the last KC
 *        slice of k and calls the fused micro-kernel on every MR x NR tile.
 *
 *        Unlike rank_k_macro_kernel, which updates V (type TV), this kernel
 *        writes the output C (type TC). The corner-case scratch tile ctmp is
 *        also of type TC. The current V tile is exposed to the micro-kernel
 *        through aux.V / aux.ldv, together with the global tile position
 *        (aux.i, aux.j), the tile extent (aux.ib, aux.jb) and the batch id
 *        (aux.b).
 *
 *        Column panels are distributed over the threads of the ic
 *        communicator using thread.jr_id.
 */
template<
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
  bool REUSE_C,
  typename FUSEDKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void fused_macro_kernel
(
  Worker &thread,
  int ic, int jc, int pc,
  int m, int n, int k,
  TA *packA,
  TB *packB,
  TC *C, int ldc,
  TV *V, int ldv,
  int batchId,
  FUSEDKERNEL fusedkernel
)
{
  thread_communicator &ic_comm = *thread.ic_comm;

  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto loop2nd = GetRange( 0, m, MR );
  auto pack2nd = GetRange( 0, m, PACK_MR );

  for ( int j = loop3rd.beg(), jp = pack3rd.beg();
            j < loop3rd.end();
            j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop
  {
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc = pc;
    aux.b_next = packB;
    aux.do_packC = 0;

    for ( int i = loop2nd.beg(), ip = pack2nd.beg();
              i < loop2nd.end();
              i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop
    {
      // These auxiliary fields are used to access data in the closure of
      // opkernel and opreduce: global row/column offset of this tile and
      // the batch id.
      aux.i = ic + i;
      aux.j = jc + j;
      aux.b = batchId;

      // Valid extent of this tile (smaller than MR x NR at the edges).
      aux.ib = std::min( m - i, MR );
      aux.jb = std::min( n - j, NR );

      // V tile seen by the micro-kernel; may be redirected to vtmp below.
      aux.V = V + j * ldv + i;
      aux.ldv = ldv;

      // On the last row panel, advance the B prefetch pointer. The stride
      // appears to match the pack3rd step of this thread (TODO confirm).
      if ( i + MR >= m )
      {
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }

      if ( aux.jb == NR && aux.ib == MR )
      {
        fusedkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          &C[ j * ldc + i ], 1, ldc,
          //&C[ ( j / NR ) * ldc + i ], ldc, // for conv_relu_pool
          &aux
        );
      }
      else // corner case: stage through full MR x NR scratch tiles
      {
        TC ctmp[ MR * NR ];
        TV vtmp[ MR * NR ];

        // When pc == 0 ctmp is left uninitialized here; presumably the
        // micro-kernel does not read it in that case (NOTE(review): confirm).
        if ( pc ) // initialize ctmp (or vtmp)
        {
          if ( REUSE_C )
          {
            // Load the valid ib x jb part of C into ctmp.
            for ( auto jj = 0; jj < aux.jb; jj ++ )
              for ( auto ii = 0; ii < aux.ib; ii ++ )
                ctmp[ jj * MR + ii ] = C[ ( j + jj ) * ldc + i + ii ];
          }
          else
          {
            // Copy the valid part of V into vtmp and point aux.V at it, so
            // the micro-kernel works on a full MR x NR tile with
            // leading dimension MR.
            for ( auto jj = 0; jj < aux.jb; jj ++ )
              for ( auto ii = 0; ii < aux.ib; ii ++ )
                vtmp[ jj * MR + ii ] = V[ ( j + jj ) * ldv + i + ii ];
            aux.V = vtmp;
            aux.ldv = MR;
          }
        }

        fusedkernel
        (
          k,
          &packA[ ip * k ],
          &packB[ jp * k ],
          ctmp, 1, MR,
          &aux
        );

        // Write back only the valid ib x jb part of the tile.
        for ( auto jj = 0; jj < aux.jb; jj ++ )
          for ( auto ii = 0; ii < aux.ib; ii ++ )
            C[ ( j + jj ) * ldc + i + ii ] = ctmp[ jj * MR + ii ];

      }
    } // end 2nd loop
  } // end 3rd loop
}; // end fused_macro_kernel
249 |
|
|
250 |
|
|
251 |
|
|
252 |
|
|
253 |
|
|
254 |
|
/**
 * @brief This function contains the loop body of the 6th (jc) to 4th (ic)
 *        loops, including all packing routines. It is executed by all
 *        threads of the root communicator. Each thread uses its ids in the
 *        jc/pc/ic/jr communicators to select its share of the work.
 *
 *        The 5th loop walks k from k_stra to k in KC-sized slices. The caller
 *        (gkmx) covers [0, k_stra) with strassen_internal when USE_STRASSEN
 *        is set. For every slice, the threads cooperatively pack B and then A.
 *        Non-final slices are accumulated into V by rank_k_macro_kernel. The
 *        final slice is handled by fused_macro_kernel, which writes C.
 *
 * @param nc, pack_nc   Column block size of the 6th loop and its packed size.
 * @param packA, packB  Bases of the shared packing buffers; each thread
 *                      offsets into its own slot below.
 *
 *  NOTE: jc_comm is bound but not used in this body.
 */
template<
  int MC,
  int NC,
  int KC,
  int MR,
  int NR,
  int PACK_MC,
  int PACK_NC,
  int PACK_MR,
  int PACK_NR,
  int ALIGN_SIZE,
  bool USE_STRASSEN,
  bool REUSE_C,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void gkmx_internal
(
  Worker &thread,
  hmlpOperation_t transA, hmlpOperation_t transB,
  int m, int n, int k, int k_stra,
  TA *A, int lda,
  TB *B, int ldb,
  TC *C, int ldc,
  TV *V, int ldv,
  int batchId,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel,
  int nc, int pack_nc,
  TA *packA,
  TB *packB
)
{
  // Each (jc_id, ic_id) pair owns one PACK_MC x KC slot of packA (shared by
  // its jr threads); all threads with the same jc_id share one
  // pack_nc x KC slot of packB.
  packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
         + ( thread.ic_id ) * PACK_MC * KC;
  packB += ( thread.jc_id ) * pack_nc * KC;

  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
  auto loop5th = GetRange( k_stra, k, KC );
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );

  for ( int jc = loop6th.beg();
        jc < loop6th.end();
        jc += loop6th.inc() ) // beg 6th loop
  {
    auto &jc_comm = *thread.jc_comm;
    auto jb = std::min( n - jc, nc );

    for ( int pc = loop5th.beg();
          pc < loop5th.end();
          pc += loop5th.inc() ) // beg 5th loop
    {
      auto &pc_comm = *thread.pc_comm;
      auto pb = std::min( k - pc, KC );
      auto is_the_last_pc_iteration = ( pc + KC >= k );
      // Packing of B is split over the threads of pc_comm.
      auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() );
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );

      for ( int j = looppkB.beg(), jp = packpkB.beg();
            j < looppkB.end();
            j += looppkB.inc(), jp += packpkB.inc() )
      {
        if ( transB == HMLP_OP_N )
        {
          pack2D<true, PACK_NR> // packB
          (
            std::min( jb - j, NR ), pb,
            &B[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ]
          );
        }
        else
        {
          pack2D<false, PACK_NR> // packB (transB)
          (
            std::min( jb - j, NR ), pb,
            &B[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ]
          );
        }
      }
      // Every share of packB must be ready before any thread reads it.
      pc_comm.Barrier();

      for ( int ic = loop4th.beg();
            ic < loop4th.end();
            ic += loop4th.inc() ) // beg 4th loop
      {
        auto &ic_comm = *thread.ic_comm;
        auto ib = std::min( m - ic, MC );
        // Packing of A is split over the jr threads.
        auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt );
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );

        for ( int i = looppkA.beg(), ip = packpkA.beg();
              i < looppkA.end();
              i += looppkA.inc(), ip += packpkA.inc() )
        {
          if ( transA == HMLP_OP_N )
          {
            pack2D<false, PACK_MR> // packA
            (
              std::min( ib - i, MR ), pb,
              &A[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ]
            );
          }
          else
          {
            pack2D<true, PACK_MR> // packA (transA)
            (
              std::min( ib - i, MR ), pb,
              &A[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ]
            );
          }
        }
        // Every share of packA must be ready before the macro kernel.
        ic_comm.Barrier();

        if ( is_the_last_pc_iteration ) // fused_macro_kernel
        {
          fused_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, REUSE_C, MICROKERNEL, TA, TB, TC, TV>
          (
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            C + jc * ldc + ic, ldc,
            V + jc * ldv + ic, ldv, // if REUSE_C, then V = C.
            batchId,
            microkernel
          );
        }
        else // semiring rank-k update
        {
          rank_k_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
          (
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            //C + jc * ldc + ic, ldc,
            V + jc * ldv + ic, ldv,
            semiringkernel
          );
        }
        // Sync all jr_id before the next ic iteration repacks packA.
        ic_comm.Barrier();
      } // end 4th loop
      // Sync before the next pc iteration repacks packB.
      pc_comm.Barrier();
    } // end 5th loop
  } // end 6th loop
} // end gkmx_internal
410 |
|
|
411 |
|
|
412 |
|
|
413 |
|
|
414 |
|
|
415 |
|
/** |
416 |
|
* @breif This is the main routine of gkmx. All packing buffers are |
417 |
|
* managed here. The communicator and the parallel section |
418 |
|
* start here. |
419 |
|
* |
420 |
|
*/ |
421 |
|
template< |
422 |
|
int MC, |
423 |
|
int NC, |
424 |
|
int KC, |
425 |
|
int MR, |
426 |
|
int NR, |
427 |
|
int PACK_MC, |
428 |
|
int PACK_NC, |
429 |
|
int PACK_MR, |
430 |
|
int PACK_NR, |
431 |
|
int ALIGN_SIZE, |
432 |
|
bool USE_STRASSEN = false, |
433 |
|
bool REUSE_C, |
434 |
|
typename SEMIRINGKERNEL, typename MICROKERNEL, |
435 |
|
typename TA, typename TB, typename TC, typename TV = TC> |
436 |
|
void gkmx |
437 |
|
( |
438 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
439 |
|
int m, int n, int k, |
440 |
|
TA *A, int lda, |
441 |
|
TB *B, int ldb, |
442 |
|
TC *C, int ldc, |
443 |
|
int batchId, |
444 |
|
SEMIRINGKERNEL semiringkernel, |
445 |
|
MICROKERNEL microkernel |
446 |
|
) |
447 |
|
{ |
448 |
|
int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1; |
449 |
|
int k_stra = 0; |
450 |
|
int ldv = 0; |
451 |
|
int nc = NC, pack_nc = PACK_NC; |
452 |
|
char *str; |
453 |
|
|
454 |
|
TA *packA_buff = NULL; |
455 |
|
TB *packB_buff = NULL; |
456 |
|
TV *V = NULL; |
457 |
|
|
458 |
|
// Early return if possible |
459 |
|
if ( m == 0 || n == 0 || k == 0 ) return; |
460 |
|
|
461 |
|
// type checking (currently assume TC == TV) |
462 |
|
if ( typeid(TC) != typeid(TV) && k > KC ) |
463 |
|
{ |
464 |
|
printf( "gkmx: currently k(%d) must be smaller than %d when TC != TV\n", k, KC ); |
465 |
|
exit( 1 ); |
466 |
|
} |
467 |
|
|
468 |
|
if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 ) |
469 |
|
{ |
470 |
|
// Check the environment variable. |
471 |
|
jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" ); |
472 |
|
ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" ); |
473 |
|
jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" ); |
474 |
|
} |
475 |
|
|
476 |
|
if ( jc_nt > 1 ) |
477 |
|
{ |
478 |
|
nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR; |
479 |
|
pack_nc = ( nc / NR ) * PACK_NR; |
480 |
|
} |
481 |
|
|
482 |
|
// allocate packing memory |
483 |
|
packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC * ( PACK_MC + 1 ) * jc_nt * ic_nt ); |
484 |
|
packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC * ( pack_nc + 1 ) * jc_nt ); |
485 |
|
|
486 |
|
|
487 |
|
// allocate V if k > KC |
488 |
|
if ( k > KC && !std::is_same<TC, TV>::value && !REUSE_C ) |
489 |
|
{ |
490 |
|
V = hmlp_malloc<ALIGN_SIZE, TV>( m * n ); |
491 |
|
ldv = m; |
492 |
|
} |
493 |
|
else // TODO: do not free V in this case. |
494 |
|
{ |
495 |
|
V = reinterpret_cast<TV*>( C ); |
496 |
|
ldv = ldc; |
497 |
|
} |
498 |
|
|
499 |
|
// allocate tree communicator |
500 |
|
thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt ); |
501 |
|
|
502 |
|
|
503 |
|
if ( USE_STRASSEN ) |
504 |
|
{ |
505 |
|
assert( typeid(TA) == typeid(TB) ); |
506 |
|
assert( typeid(TC) == typeid(TV) ); |
507 |
|
k_stra = k - k % KC; |
508 |
|
|
509 |
|
if ( k_stra == k ) k_stra -= KC; |
510 |
|
|
511 |
|
if ( k_stra ) |
512 |
|
{ |
513 |
|
#pragma omp parallel for |
514 |
|
for ( int i = 0; i < n * ldv; i ++ ) V[ i ] = 0.0; |
515 |
|
} |
516 |
|
} |
517 |
|
|
518 |
|
|
519 |
|
#pragma omp parallel num_threads( my_comm.GetNumThreads() ) |
520 |
|
{ |
521 |
|
Worker thread( &my_comm ); |
522 |
|
|
523 |
|
if ( USE_STRASSEN ) |
524 |
|
{ |
525 |
|
strassen::strassen_internal |
526 |
|
<MC, NC, KC, MR, NR, |
527 |
|
PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
528 |
|
USE_STRASSEN, |
529 |
|
SEMIRINGKERNEL, SEMIRINGKERNEL, |
530 |
|
TA, TB, TC, TV> |
531 |
|
( |
532 |
|
thread, |
533 |
|
transA, transB, |
534 |
|
m, n, k_stra, |
535 |
|
A, lda, |
536 |
|
B, ldb, |
537 |
|
V, ldv, |
538 |
|
semiringkernel, semiringkernel, |
539 |
|
nc, pack_nc, |
540 |
|
packA_buff, |
541 |
|
packB_buff |
542 |
|
); |
543 |
|
} |
544 |
|
|
545 |
|
gkmx_internal |
546 |
|
<MC, NC, KC, MR, NR, |
547 |
|
PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
548 |
|
USE_STRASSEN, REUSE_C, |
549 |
|
SEMIRINGKERNEL, MICROKERNEL, |
550 |
|
TA, TB, TC, TV> |
551 |
|
( |
552 |
|
thread, |
553 |
|
transA, transB, |
554 |
|
m, n, k, k_stra, |
555 |
|
A, lda, |
556 |
|
B, ldb, |
557 |
|
C, ldc, |
558 |
|
V, ldv, |
559 |
|
batchId, |
560 |
|
semiringkernel, microkernel, |
561 |
|
nc, pack_nc, |
562 |
|
packA_buff, |
563 |
|
packB_buff |
564 |
|
); |
565 |
|
} // end omp parallel |
566 |
|
|
567 |
|
hmlp_free( packA_buff ); |
568 |
|
hmlp_free( packB_buff ); |
569 |
|
//hmlp_free( V ); |
570 |
|
}; // end gkmx |
571 |
|
|
572 |
|
|
573 |
|
|
574 |
|
|
575 |
|
|
576 |
|
/** |
577 |
|
* @beief |
578 |
|
*/ |
579 |
|
template< |
580 |
|
int MC = 104, |
581 |
|
int NC = 1024, |
582 |
|
int KC = 256, |
583 |
|
int MR = 8, |
584 |
|
int NR = 4, |
585 |
|
int PACK_MC = 104, |
586 |
|
int PACK_NC = 1024, |
587 |
|
int PACK_MR = 8, |
588 |
|
int PACK_NR = 4, |
589 |
|
int ALIGN_SIZE = 32, |
590 |
|
bool USE_STRASSEN = false, |
591 |
|
bool REUSE_C = false, |
592 |
|
typename OPKERNEL, typename OP1, typename OP2, |
593 |
|
typename TA, typename TB, typename TC, typename TV> |
594 |
|
void gkmm |
595 |
|
( |
596 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
597 |
|
int m, int n, int k, |
598 |
|
TA *A, int lda, |
599 |
|
TB *B, int ldb, |
600 |
|
TC *C, int ldc, |
601 |
|
int batchId, |
602 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV |
603 |
|
) |
604 |
|
{ |
605 |
|
semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV> semiringkernel; |
606 |
|
gkmm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TA, TB, TC, TV> gkmmkernel; |
607 |
|
|
608 |
|
semiringkernel.op1 = op1; |
609 |
|
semiringkernel.op2 = op2; |
610 |
|
semiringkernel.initV = initV; |
611 |
|
|
612 |
|
gkmmkernel.op1 = op1; |
613 |
|
gkmmkernel.op2 = op2; |
614 |
|
gkmmkernel.opkernel = opkernel; |
615 |
|
gkmmkernel.initV = initV; |
616 |
|
|
617 |
|
gkmx |
618 |
|
<MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
619 |
|
USE_STRASSEN, REUSE_C, |
620 |
|
semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV>, |
621 |
|
gkmm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TA, TB, TC, TV>, |
622 |
|
TA, TB, TC, TV> |
623 |
|
( |
624 |
|
transA, transB, |
625 |
|
m, n, k, |
626 |
|
A, lda, |
627 |
|
B, ldb, |
628 |
|
C, ldc, |
629 |
|
batchId, |
630 |
|
semiringkernel, gkmmkernel |
631 |
|
); |
632 |
|
}; |
633 |
|
|
634 |
|
|
635 |
|
/** |
636 |
|
* @brief batched interface with array of arrays |
637 |
|
* |
638 |
|
* TODO: the problem is how to manage thread here? Do I want to use omp |
639 |
|
* nested? or there is a better way to deal with this. |
640 |
|
* |
641 |
|
*/ |
642 |
|
template< |
643 |
|
int MC, int NC, int KC, int MR, int NR, |
644 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
645 |
|
bool USE_STRASSEN, bool REUSE_C, |
646 |
|
typename OPKERNEL, typename OP1, typename OP2, |
647 |
|
typename TA, typename TB, typename TC, typename TV> |
648 |
|
void gkmm |
649 |
|
( |
650 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
651 |
|
int m, int n, int k, |
652 |
|
TA *Aarray[], int lda, |
653 |
|
TB *Barray[], int ldb, |
654 |
|
TC *Carray[], int ldc, |
655 |
|
int batchSize, |
656 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV |
657 |
|
) |
658 |
|
{ |
659 |
|
#pragma omp parallel for |
660 |
|
for ( auto b = 0; b < batchSize; b ++ ) |
661 |
|
{ |
662 |
|
gkmm |
663 |
|
<MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
664 |
|
USE_STRASSEN, |
665 |
|
OPKERNEL, OP1, OP2, |
666 |
|
TA, TB, TC, TV> |
667 |
|
( |
668 |
|
transA, transB, |
669 |
|
m, n, k, |
670 |
|
Aarray[ b ], lda, |
671 |
|
Barray[ b ], ldb, |
672 |
|
Carray[ b ], ldc, |
673 |
|
b, |
674 |
|
opkernel, op1, op2, initV |
675 |
|
); |
676 |
|
} |
677 |
|
}; // end gkmm |
678 |
|
|
679 |
|
|
680 |
|
/** |
681 |
|
* @brief batched interface with strides |
682 |
|
* |
683 |
|
* TODO: the problem is how to manage thread here? Do I want to use omp |
684 |
|
* nested? or there is a better way to deal with this. |
685 |
|
* |
686 |
|
*/ |
687 |
|
template< |
688 |
|
int MC, |
689 |
|
int NC, |
690 |
|
int KC, int MR, int NR, |
691 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
692 |
|
bool USE_STRASSEN, bool REUSE_C, |
693 |
|
typename OPKERNEL, typename OP1, typename OP2, |
694 |
|
typename TA, typename TB, typename TC, typename TV> |
695 |
|
void gkmm |
696 |
|
( |
697 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
698 |
|
int m, int n, int k, |
699 |
|
TA *Aarray, int lda, int loa, |
700 |
|
TB *Barray, int ldb, int lob, |
701 |
|
TC *Carray, int ldc, int loc, |
702 |
|
int batchSize, |
703 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV |
704 |
|
) |
705 |
|
{ |
706 |
|
#pragma omp parallel for |
707 |
|
for ( auto b = 0; b < batchSize; b ++ ) |
708 |
|
{ |
709 |
|
gkmm |
710 |
|
<MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
711 |
|
USE_STRASSEN, REUSE_C, |
712 |
|
OPKERNEL, OP1, OP2, |
713 |
|
TA, TB, TC, TV> |
714 |
|
( |
715 |
|
transA, transB, |
716 |
|
m, n, k, |
717 |
|
Aarray + b * loa, lda, |
718 |
|
Barray + b * lob, ldb, |
719 |
|
Carray + b * loc, ldc, |
720 |
|
b, |
721 |
|
opkernel, op1, op2, initV |
722 |
|
); |
723 |
|
} |
724 |
|
}; // end gkmm |
725 |
|
|
726 |
|
|
727 |
|
|
728 |
|
|
729 |
|
|
730 |
|
|
731 |
|
|
732 |
|
|
733 |
|
|
734 |
|
|
735 |
|
|
736 |
|
|
737 |
|
|
738 |
|
|
739 |
|
|
740 |
|
|
741 |
|
|
742 |
|
|
743 |
|
|
744 |
|
/** |
745 |
|
* @beief Implement GKRM with GKMX template. Notice that OPREDUCE |
746 |
|
* is handled inside fusedkernel. Updating microkernel has |
747 |
|
* to be atomic if jc_nt or jr_nt is not 1. We may be atomic |
748 |
|
* update. |
749 |
|
* |
750 |
|
*/ |
751 |
|
template< |
752 |
|
int MC = 104, |
753 |
|
int NC = 1024, |
754 |
|
int KC = 256, |
755 |
|
int MR = 8, |
756 |
|
int NR = 4, |
757 |
|
int PACK_MC = 104, |
758 |
|
int PACK_NC = 1024, |
759 |
|
int PACK_MR = 8, |
760 |
|
int PACK_NR = 4, |
761 |
|
int ALIGN_SIZE = 32, |
762 |
|
bool USE_STRASSEN = false, |
763 |
|
typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE, |
764 |
|
typename TA, typename TB, typename TC, typename TV = TC> |
765 |
|
void gkrm |
766 |
|
( |
767 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
768 |
|
int m, int n, int k, |
769 |
|
TA *A, int lda, |
770 |
|
TB *B, int ldb, |
771 |
|
TC *C, int ldc, |
772 |
|
int batchId, |
773 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV, |
774 |
|
OPREDUCE opreduce, TC initC |
775 |
|
) |
776 |
|
{ |
777 |
|
semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV> semiringkernel; |
778 |
|
gkrm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, OPREDUCE, TA, TB, TC, TV> gkrmkernel; |
779 |
|
|
780 |
|
semiringkernel.op1 = op1; |
781 |
|
semiringkernel.op2 = op2; |
782 |
|
semiringkernel.initV = initV; |
783 |
|
|
784 |
|
gkrmkernel.op1 = op1; |
785 |
|
gkrmkernel.op2 = op2; |
786 |
|
gkrmkernel.opkernel = opkernel; |
787 |
|
gkrmkernel.initV = initV; |
788 |
|
gkrmkernel.opreduce = opreduce; |
789 |
|
gkrmkernel.initC = initC; |
790 |
|
|
791 |
|
gkmx |
792 |
|
<MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
793 |
|
USE_STRASSEN, |
794 |
|
semiring_mrxnr<MR, NR, OP1, OP2, TA, TB, TC, TV>, |
795 |
|
gkmm_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TA, TB, TC, TV>, |
796 |
|
TA, TB, TC, TV> |
797 |
|
( |
798 |
|
transA, transB, |
799 |
|
m, n, k, |
800 |
|
A, lda, |
801 |
|
B, ldb, |
802 |
|
C, 0, // TODO: is there a better way to do this? |
803 |
|
batchId, |
804 |
|
semiringkernel, gkrmkernel |
805 |
|
); |
806 |
|
}; // end gkrm |
807 |
|
|
808 |
|
|
809 |
|
|
810 |
|
|
811 |
|
/** |
812 |
|
* @breif This is a simple triple loop reference. |
813 |
|
*/ |
814 |
|
template< |
815 |
|
typename OPKERNEL, typename OP1, typename OP2, |
816 |
|
typename TA, typename TB, typename TC, typename TV = TC> |
817 |
|
void gkmm_ref |
818 |
|
( |
819 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
820 |
|
int m, int n, int k, |
821 |
|
TA *A, int lda, |
822 |
|
TB *B, int ldb, |
823 |
|
TC *C, int ldc, |
824 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV |
825 |
|
) |
826 |
|
{ |
827 |
|
for ( int i = 0; i < m; i ++ ) |
828 |
|
{ |
829 |
|
for ( int j = 0; j < n; j ++ ) |
830 |
|
{ |
831 |
|
auto v = initV; |
832 |
|
for ( int p = 0; p < k; p ++ ) |
833 |
|
{ |
834 |
|
TA a; |
835 |
|
TB b; |
836 |
|
if ( transA == HMLP_OP_N ) a = A[ p * lda + i ]; |
837 |
|
else a = A[ i * lda + p ]; |
838 |
|
if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ]; |
839 |
|
else b = B[ p * ldb + j ]; |
840 |
|
v = op1( v, op2( a, b ) ); |
841 |
|
} |
842 |
|
C[ j * ldc + i ] = opkernel( v ); |
843 |
|
} |
844 |
|
} |
845 |
|
}; // end gkmm_ref |
846 |
|
|
847 |
|
|
848 |
|
/** |
849 |
|
* @breif This is a simple triple loop reference. |
850 |
|
* |
851 |
|
* TODO: ldc is strange here, assuming that C is a vector. |
852 |
|
*/ |
853 |
|
template< |
854 |
|
typename OPKERNEL, typename OP1, typename OP2, typename OPREDUCE, |
855 |
|
typename TA, typename TB, typename TC, typename TV = TC> |
856 |
|
void gkrm_ref |
857 |
|
( |
858 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
859 |
|
int m, int n, int k, |
860 |
|
TA *A, int lda, |
861 |
|
TB *B, int ldb, |
862 |
|
TC *C, int ldc, |
863 |
|
int batchId, |
864 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV, |
865 |
|
OPREDUCE opreduce, TC initC |
866 |
|
) |
867 |
|
{ |
868 |
|
for ( int i = 0; i < m; i ++ ) |
869 |
|
{ |
870 |
|
auto c = initC; |
871 |
|
for ( int j = 0; j < n; j ++ ) |
872 |
|
{ |
873 |
|
auto v = initV; |
874 |
|
for ( int p = 0; p < k; p ++ ) |
875 |
|
{ |
876 |
|
TA a; |
877 |
|
TB b; |
878 |
|
if ( transA == HMLP_OP_N ) a = A[ p * lda + i ]; |
879 |
|
else a = A[ i * lda + p ]; |
880 |
|
if ( transB == HMLP_OP_N ) b = B[ j * ldb + p ]; |
881 |
|
else b = B[ p * ldb + j ]; |
882 |
|
v = op1( v, op2( a, b ) ); |
883 |
|
} |
884 |
|
c = opreduce( c, opkernel( v ) ); |
885 |
|
} |
886 |
|
C[ i ] = c; |
887 |
|
} |
888 |
|
}; // end gkrm_ref |
889 |
|
|
890 |
|
|
891 |
|
}; // end namespace gkmx |
892 |
|
}; // end namespace hmlp |
893 |
|
|
894 |
|
#endif // define GKMX_HPP |