1 |
|
/** |
2 |
|
* HMLP (High-Performance Machine Learning Primitives) |
3 |
|
* |
4 |
|
* Copyright (C) 2014-2017, The University of Texas at Austin |
5 |
|
* |
6 |
|
* This program is free software: you can redistribute it and/or modify |
7 |
|
* it under the terms of the GNU General Public License as published by |
8 |
|
* the Free Software Foundation, either version 3 of the License, or |
9 |
|
* (at your option) any later version. |
10 |
|
* |
11 |
|
* This program is distributed in the hope that it will be useful, |
12 |
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of |
13 |
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
14 |
|
* GNU General Public License for more details. |
15 |
|
* |
16 |
|
* You should have received a copy of the GNU General Public License |
17 |
|
* along with this program. If not, see the LICENSE file. |
18 |
|
* |
19 |
|
**/ |
20 |
|
|
21 |
|
|
22 |
|
#ifndef STRASSEN_HPP |
23 |
|
#define STRASSEN_HPP |
24 |
|
|
25 |
|
/**
 *  STRAPRIM: call straprim() (the overload without index maps) for one
 *  pair of operand sums.
 *
 *  The ten macro arguments are forwarded as the A0/A1/gamma, B0/B1/delta
 *  and C0/C1/alpha0/alpha1 parameters of straprim(). Every other name in
 *  the expansion (the blocking constants MC..ALIGN_SIZE, USE_STRASSEN, the
 *  kernel types, TA/TB/TC/TV, thread, transA, transB, md, nd, kd, lda, ldb,
 *  ldc, stra_semiringkernel, stra_microkernel, nc, pack_nc, packA_buff and
 *  packB_buff) must already be visible where the macro is expanded.
 *
 *  The last explicit template argument is straprim()'s TV, i.e. the type of
 *  C0, C1, alpha0 and alpha1. It used to be spelled TB, which only
 *  type-checks when TB and TV are the same type.
 *
 *  The expansion ends with its own ';', exactly as before, so existing
 *  call sites with or without a trailing ';' keep working.
 */
#define STRAPRIM( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \
    straprim \
    <MC, NC, KC, MR, NR, \
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \
    USE_STRASSEN, \
    STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \
    TA, TB, TC, TV> \
    ( \
      thread, \
      transA, transB, \
      md, nd, kd, \
      A0, A1, lda, gamma, \
      B0, B1, ldb, delta, \
      C0, C1, ldc, alpha0, alpha1, \
      stra_semiringkernel, stra_microkernel, \
      nc, pack_nc, \
      packA_buff, \
      packB_buff \
    );
44 |
|
|
45 |
|
/**
 *  STRAPRIM_MAP: call the straprim() overload that takes the index maps
 *  amap and bmap, for one pair of operand sums.
 *
 *  The ten macro arguments are forwarded as the A0/A1/gamma, B0/B1/delta
 *  and C0/C1/alpha0/alpha1 parameters of straprim(). Every other name in
 *  the expansion (the blocking constants MC..ALIGN_SIZE, USE_STRASSEN, the
 *  kernel types, TA/TB/TC/TV, thread, transA, transB, md, nd, kd, lda, ldb,
 *  ldc, amap, bmap, stra_semiringkernel, stra_microkernel, nc, pack_nc,
 *  packA_buff and packB_buff) must already be visible where the macro is
 *  expanded.
 *
 *  The last explicit template argument is straprim()'s TV, i.e. the type of
 *  C0, C1, alpha0 and alpha1. It used to be spelled TB, which only
 *  type-checks when TB and TV are the same type.
 *
 *  The expansion ends with its own ';', exactly as before, so existing
 *  call sites with or without a trailing ';' keep working.
 */
#define STRAPRIM_MAP( A0,A1,gamma,B0,B1,delta,C0,C1,alpha0,alpha1 ) \
    straprim \
    <MC, NC, KC, MR, NR, \
    PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, \
    USE_STRASSEN, \
    STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, \
    TA, TB, TC, TV> \
    ( \
      thread, \
      transA, transB, \
      md, nd, kd, \
      A0, A1, lda, gamma, amap, \
      B0, B1, ldb, delta, bmap, \
      C0, C1, ldc, alpha0, alpha1, \
      stra_semiringkernel, stra_microkernel, \
      nc, pack_nc, \
      packA_buff, \
      packB_buff \
    );
64 |
|
|
65 |
|
#include <hmlp.h> |
66 |
|
#include <hmlp_internal.hpp> |
67 |
|
#include <hmlp_base.hpp> |
68 |
|
|
69 |
|
namespace hmlp |
70 |
|
{ |
71 |
|
namespace strassen |
72 |
|
{ |
73 |
|
|
74 |
|
//#define min( i, j ) ( (i)<(j) ? (i): (j) ) |
75 |
|
|
76 |
|
/** |
77 |
|
* |
78 |
|
*/ |
79 |
|
template< |
80 |
|
int KC, int MR, int NR, int PACK_MR, int PACK_NR, |
81 |
|
typename SEMIRINGKERNEL, |
82 |
|
typename TA, typename TB, typename TC, typename TV> |
83 |
|
void rank_k_macro_kernel |
84 |
|
( |
85 |
|
Worker &thread, |
86 |
|
int ic, int jc, int pc, |
87 |
|
int m, int n, int k, |
88 |
|
TA *packA, |
89 |
|
TB *packB, |
90 |
|
TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1, |
91 |
|
SEMIRINGKERNEL semiringkernel |
92 |
|
) |
93 |
|
{ |
94 |
|
thread_communicator &ic_comm = *thread.ic_comm; |
95 |
|
|
96 |
|
auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() ); |
97 |
|
auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() ); |
98 |
|
auto loop2nd = GetRange( 0, m, MR ); |
99 |
|
auto pack2nd = GetRange( 0, m, PACK_MR ); |
100 |
|
|
101 |
|
for ( int j = loop3rd.beg(), jp = pack3rd.beg(); |
102 |
|
j < loop3rd.end(); |
103 |
|
j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop |
104 |
|
{ |
105 |
|
struct aux_s<TA, TB, TC, TV> aux; |
106 |
|
aux.pc = pc; |
107 |
|
aux.b_next = packB; |
108 |
|
aux.do_packC = 0; |
109 |
|
aux.jb = std::min( n - j, NR ); |
110 |
|
|
111 |
|
for ( int i = loop2nd.beg(), ip = pack2nd.beg(); |
112 |
|
i < loop2nd.end(); |
113 |
|
i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop |
114 |
|
{ |
115 |
|
aux.ib = std::min( m - i, MR ); |
116 |
|
if ( aux.ib != MR ) |
117 |
|
{ |
118 |
|
aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k; |
119 |
|
} |
120 |
|
|
121 |
|
if ( aux.jb == NR && aux.ib == MR ) |
122 |
|
{ |
123 |
|
|
124 |
|
if ( alpha1 == 0 || C1 == NULL ) { |
125 |
|
TV *c_list[1], alpha_list[1]; |
126 |
|
c_list[0] = &C0[ j * ldc + i ]; |
127 |
|
alpha_list[0] = alpha0; |
128 |
|
|
129 |
|
semiringkernel |
130 |
|
( |
131 |
|
k, |
132 |
|
&packA[ ip * k ], |
133 |
|
&packB[ jp * k ], |
134 |
|
1, c_list, ldc, alpha_list, |
135 |
|
&aux |
136 |
|
); |
137 |
|
|
138 |
|
} else { |
139 |
|
|
140 |
|
TV *c_list[2], alpha_list[2]; |
141 |
|
c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ]; |
142 |
|
alpha_list[0] = alpha0; alpha_list[1] = alpha1; |
143 |
|
semiringkernel |
144 |
|
( |
145 |
|
k, |
146 |
|
&packA[ ip * k ], |
147 |
|
&packB[ jp * k ], |
148 |
|
2, c_list, ldc, alpha_list, |
149 |
|
&aux |
150 |
|
); |
151 |
|
|
152 |
|
} |
153 |
|
|
154 |
|
//semiringkernel |
155 |
|
//( |
156 |
|
// k, |
157 |
|
// &packA[ ip * k ], |
158 |
|
// &packB[ jp * k ], |
159 |
|
// &C0[ j * ldc + i ], &C1[ j * ldc + i ], ldc, alpha0, alpha1, |
160 |
|
// &aux |
161 |
|
//); |
162 |
|
|
163 |
|
|
164 |
|
} |
165 |
|
else // corner case |
166 |
|
{ |
167 |
|
|
168 |
|
//printf( "Enter corner case!\n" ); |
169 |
|
// TODO: this should be initC. |
170 |
|
TV ctmp[ MR * NR ] = { (TV)0.0 }; |
171 |
|
|
172 |
|
TV *c_list[1], alpha_list[1]; |
173 |
|
c_list[0] = ctmp; |
174 |
|
alpha_list[0] = 1; |
175 |
|
|
176 |
|
semiringkernel |
177 |
|
( |
178 |
|
k, |
179 |
|
&packA[ ip * k ], |
180 |
|
&packB[ jp * k ], |
181 |
|
//ctmp, MR, |
182 |
|
1, c_list, MR, alpha_list, |
183 |
|
&aux |
184 |
|
); |
185 |
|
|
186 |
|
|
187 |
|
////rank_k_int_d8x4 rankk_semiringkernel; |
188 |
|
////rankk_semiringkernel |
189 |
|
//semiringkernel |
190 |
|
//( |
191 |
|
// k, |
192 |
|
// &packA[ ip * k ], |
193 |
|
// &packB[ jp * k ], |
194 |
|
// //ctmp, MR, |
195 |
|
// ctmp, NULL, MR, 1, 0, |
196 |
|
// &aux |
197 |
|
//); |
198 |
|
//if ( pc ) |
199 |
|
{ |
200 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
201 |
|
{ |
202 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
203 |
|
{ |
204 |
|
C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ]; |
205 |
|
|
206 |
|
if ( alpha1 != 0 && C1 != NULL ) { |
207 |
|
C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ]; |
208 |
|
} |
209 |
|
} |
210 |
|
} |
211 |
|
} |
212 |
|
//else |
213 |
|
//{ |
214 |
|
// for ( auto jj = 0; jj < aux.jb; jj ++ ) |
215 |
|
// { |
216 |
|
// for ( auto ii = 0; ii < aux.ib; ii ++ ) |
217 |
|
// { |
218 |
|
// C0[ ( j + jj ) * ldc + i + ii ] = alpha0 * ctmp[ jj * MR + ii ]; |
219 |
|
|
220 |
|
// if ( alpha1 != 0 && C1 != NULL ) { |
221 |
|
// C1[ ( j + jj ) * ldc + i + ii ] = alpha1 * ctmp[ jj * MR + ii ]; |
222 |
|
// } |
223 |
|
// } |
224 |
|
// } |
225 |
|
//} |
226 |
|
} |
227 |
|
} // end 2nd loop |
228 |
|
} // end 3rd loop |
229 |
|
} // end rank_k_macro_kernel |
230 |
|
|
231 |
|
/** |
232 |
|
* |
233 |
|
*/ |
234 |
|
//template<int KC, int MR, int NR, int PACK_MR, int PACK_NR, |
235 |
|
// typename MICROKERNEL, |
236 |
|
// typename TA, typename TB, typename TC, typename TV> |
237 |
|
//void fused_macro_kernel |
238 |
|
//( |
239 |
|
// Worker &thread, |
240 |
|
// int ic, int jc, int pc, |
241 |
|
// int m, int n, int k, |
242 |
|
// TA *packA, |
243 |
|
// TB *packB, |
244 |
|
// TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1, |
245 |
|
// MICROKERNEL microkernel |
246 |
|
//) |
247 |
|
//{ |
248 |
|
// thread_communicator &ic_comm = *thread.ic_comm; |
249 |
|
// |
250 |
|
// auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() ); |
251 |
|
// auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() ); |
252 |
|
// auto loop2nd = GetRange( 0, m, MR ); |
253 |
|
// auto pack2nd = GetRange( 0, m, PACK_MR ); |
254 |
|
// |
255 |
|
// for ( int j = loop3rd.beg(), jp = pack3rd.beg(); |
256 |
|
// j < loop3rd.end(); |
257 |
|
// j += loop3rd.inc(), jp += pack3rd.inc() ) // beg 3rd loop |
258 |
|
// { |
259 |
|
// struct aux_s<TA, TB, TC, TV> aux; |
260 |
|
// aux.pc = pc; |
261 |
|
// aux.b_next = packB; |
262 |
|
// aux.do_packC = 0; |
263 |
|
// aux.jb = std::min( n - j, NR ); |
264 |
|
// |
265 |
|
// for ( int i = loop2nd.beg(), ip = pack2nd.beg(); |
266 |
|
// i < loop2nd.end(); |
267 |
|
// i += loop2nd.inc(), ip += pack2nd.inc() ) // beg 2nd loop |
268 |
|
// { |
269 |
|
// aux.ib = std::min( m - i, MR ); |
270 |
|
// if ( aux.ib != MR ) |
271 |
|
// { |
272 |
|
// aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k; |
273 |
|
// } |
274 |
|
// |
275 |
|
// if ( aux.jb == NR && aux.ib == MR ) |
276 |
|
// { |
277 |
|
// |
278 |
|
// if ( alpha1 == 0 || C1 == NULL ) { |
279 |
|
// |
280 |
|
// double *c_list[1], alpha_list[1]; |
281 |
|
// c_list[0] = &C0[ j * ldc + i ]; |
282 |
|
// alpha_list[0] = alpha0; |
283 |
|
// |
284 |
|
// microkernel |
285 |
|
// ( |
286 |
|
// k, |
287 |
|
// &packA[ ip * k ], |
288 |
|
// &packB[ jp * k ], |
289 |
|
// 1, c_list, ldc, alpha_list, |
290 |
|
// &aux |
291 |
|
// ); |
292 |
|
// } else { |
293 |
|
// |
294 |
|
// double *c_list[2], alpha_list[2]; |
295 |
|
// c_list[0] = &C0[ j * ldc + i ]; c_list[1] = &C1[ j * ldc + i ]; |
296 |
|
// alpha_list[0] = alpha0; alpha_list[1] = alpha1; |
297 |
|
// |
298 |
|
// microkernel |
299 |
|
// ( |
300 |
|
// k, |
301 |
|
// &packA[ ip * k ], |
302 |
|
// &packB[ jp * k ], |
303 |
|
// 2, c_list, ldc, alpha_list, |
304 |
|
// &aux |
305 |
|
// ); |
306 |
|
// |
307 |
|
// } |
308 |
|
// |
309 |
|
// |
310 |
|
// //microkernel |
311 |
|
// //( |
312 |
|
// // k, |
313 |
|
// // &packA[ ip * k ], |
314 |
|
// // &packB[ jp * k ], |
315 |
|
// // &C0[ j * ldc + i ], &C1[ j * ldc + i ], ldc, alpha0, alpha1, |
316 |
|
// // &aux |
317 |
|
// //); |
318 |
|
// } |
319 |
|
// else // corner case |
320 |
|
// { |
321 |
|
// //printf( "Enter corner case!\n" ); |
322 |
|
// // TODO: this should be initC. |
323 |
|
// TV ctmp[ MR * NR ] = { (TV)0.0 }; |
324 |
|
// |
325 |
|
// double *c_list[1], alpha_list[1]; |
326 |
|
// c_list[0] = ctmp; |
327 |
|
// alpha_list[0] = 1; |
328 |
|
// |
329 |
|
// microkernel |
330 |
|
// ( |
331 |
|
// k, |
332 |
|
// &packA[ ip * k ], |
333 |
|
// &packB[ jp * k ], |
334 |
|
// //ctmp, MR, |
335 |
|
// 1, c_list, MR, alpha_list, |
336 |
|
// &aux |
337 |
|
// ); |
338 |
|
// |
339 |
|
// ////rank_k_int_d8x4 rankk_microkernel; |
340 |
|
// ////rankk_microkernel |
341 |
|
// //microkernel |
342 |
|
// //( |
343 |
|
// // k, |
344 |
|
// // &packA[ ip * k ], |
345 |
|
// // &packB[ jp * k ], |
346 |
|
// // //ctmp, MR, |
347 |
|
// // ctmp, NULL, MR, 1, 0, |
348 |
|
// // &aux |
349 |
|
// //); |
350 |
|
// |
351 |
|
// //if ( pc ) |
352 |
|
// { |
353 |
|
// for ( auto jj = 0; jj < aux.jb; jj ++ ) |
354 |
|
// { |
355 |
|
// for ( auto ii = 0; ii < aux.ib; ii ++ ) |
356 |
|
// { |
357 |
|
// C0[ ( j + jj ) * ldc + i + ii ] += alpha0 * ctmp[ jj * MR + ii ]; |
358 |
|
// |
359 |
|
// if ( alpha1 != 0 && C1 != NULL ) { |
360 |
|
// C1[ ( j + jj ) * ldc + i + ii ] += alpha1 * ctmp[ jj * MR + ii ]; |
361 |
|
// } |
362 |
|
// } |
363 |
|
// } |
364 |
|
// } |
365 |
|
// //else |
366 |
|
// //{ |
367 |
|
// // for ( auto jj = 0; jj < aux.jb; jj ++ ) |
368 |
|
// // { |
369 |
|
// // for ( auto ii = 0; ii < aux.ib; ii ++ ) |
370 |
|
// // { |
371 |
|
// // C0[ ( j + jj ) * ldc + i + ii ] = alpha0 * ctmp[ jj * MR + ii ]; |
372 |
|
// |
373 |
|
// // if ( alpha1 != 0 && C1 != NULL ) { |
374 |
|
// // C1[ ( j + jj ) * ldc + i + ii ] = alpha1 * ctmp[ jj * MR + ii ]; |
375 |
|
// // } |
376 |
|
// // } |
377 |
|
// // } |
378 |
|
// //} |
379 |
|
// } |
380 |
|
// } // end 2nd loop |
381 |
|
// } // end 3rd loop |
382 |
|
//} // end fused_macro_kernel |
383 |
|
|
384 |
|
|
385 |
|
/* |
386 |
|
* |
387 |
|
*/ |
388 |
|
template< |
389 |
|
int MC, int NC, int KC, int MR, int NR, |
390 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
391 |
|
bool USE_STRASSEN, |
392 |
|
typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL, |
393 |
|
typename TA, typename TB, typename TC, typename TV> |
394 |
|
void straprim |
395 |
|
( |
396 |
|
Worker &thread, |
397 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
398 |
|
int m, int n, int k, |
399 |
|
TA *A0, TA *A1, int lda, TA gamma, |
400 |
|
TB *B0, TB *B1, int ldb, TB delta, |
401 |
|
TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1, |
402 |
|
STRA_SEMIRINGKERNEL stra_semiringkernel, |
403 |
|
STRA_MICROKERNEL stra_microkernel, |
404 |
|
int nc, int pack_nc, |
405 |
|
TA *packA, |
406 |
|
TB *packB |
407 |
|
) |
408 |
|
{ |
409 |
|
//printf( "m: %d, n: %d, k: %d\n", m, n, k ); |
410 |
|
|
411 |
|
packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC |
412 |
|
+ ( thread.ic_id ) * PACK_MC * KC; |
413 |
|
packB += ( thread.jc_id ) * pack_nc * KC; |
414 |
|
|
415 |
|
auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt ); |
416 |
|
auto loop5th = GetRange( 0, k, KC ); |
417 |
|
auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt ); |
418 |
|
|
419 |
|
for ( int jc = loop6th.beg(); |
420 |
|
jc < loop6th.end(); |
421 |
|
jc += loop6th.inc() ) // beg 6th loop |
422 |
|
{ |
423 |
|
auto &jc_comm = *thread.jc_comm; |
424 |
|
auto jb = std::min( n - jc, nc ); |
425 |
|
|
426 |
|
for ( int pc = loop5th.beg(); |
427 |
|
pc < loop5th.end(); |
428 |
|
pc += loop5th.inc() ) |
429 |
|
{ |
430 |
|
auto &pc_comm = *thread.pc_comm; |
431 |
|
auto pb = std::min( k - pc, KC ); |
432 |
|
auto is_the_last_pc_iteration = ( pc + KC >= k ); |
433 |
|
auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() ); |
434 |
|
auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() ); |
435 |
|
|
436 |
|
for ( int j = looppkB.beg(), jp = packpkB.beg(); |
437 |
|
j < looppkB.end(); |
438 |
|
j += looppkB.inc(), jp += packpkB.inc() ) |
439 |
|
{ |
440 |
|
|
441 |
|
//printf( "before packB\n" ); |
442 |
|
if ( transB == HMLP_OP_N ) |
443 |
|
{ |
444 |
|
|
445 |
|
if ( delta == 0 || B1 == NULL ) { |
446 |
|
pack2D<true, PACK_NR> // packB |
447 |
|
( |
448 |
|
std::min( jb - j, NR ), pb, |
449 |
|
&B0[ ( jc + j ) * ldb + pc ], ldb, &packB[ jp * pb ] |
450 |
|
); |
451 |
|
} else { |
452 |
|
|
453 |
|
pack2D<true, PACK_NR> // packB |
454 |
|
( |
455 |
|
std::min( jb - j, NR ), pb, |
456 |
|
&B0[ ( jc + j ) * ldb + pc ], &B1[ ( jc + j ) * ldb + pc ], ldb, delta, &packB[ jp * pb ] |
457 |
|
); |
458 |
|
|
459 |
|
} |
460 |
|
|
461 |
|
} |
462 |
|
else |
463 |
|
{ |
464 |
|
if ( delta == 0 || B1 == NULL ) { |
465 |
|
pack2D<false, PACK_NR> // packB (transB) |
466 |
|
( |
467 |
|
std::min( jb - j, NR ), pb, |
468 |
|
&B0[ pc * ldb + ( jc + j ) ], ldb, &packB[ jp * pb ] |
469 |
|
); |
470 |
|
} else { |
471 |
|
|
472 |
|
//printf( "before pack2D\n" ); |
473 |
|
//printf( "B1[%d]=%lf\n", pc * ldb + ( jc + j ), B1[ pc * ldb + ( jc + j ) ] ); |
474 |
|
|
475 |
|
pack2D<false, PACK_NR> // packB (transB) |
476 |
|
( |
477 |
|
std::min( jb - j, NR ), pb, |
478 |
|
&B0[ pc * ldb + ( jc + j ) ], &B1[ pc * ldb + ( jc + j ) ], ldb, delta, &packB[ jp * pb ] |
479 |
|
); |
480 |
|
//printf( "after pack2D\n" ); |
481 |
|
|
482 |
|
} |
483 |
|
|
484 |
|
} |
485 |
|
//printf( "After packB\n" ); |
486 |
|
} |
487 |
|
pc_comm.Barrier(); |
488 |
|
|
489 |
|
//printf( "packB:\n" ); |
490 |
|
//hmlp_printmatrix( 4, 1, packB, PACK_NR ); |
491 |
|
|
492 |
|
|
493 |
|
|
494 |
|
for ( int ic = loop4th.beg(); |
495 |
|
ic < loop4th.end(); |
496 |
|
ic += loop4th.inc() ) // beg 4th loop |
497 |
|
{ |
498 |
|
auto &ic_comm = *thread.ic_comm; |
499 |
|
auto ib = std::min( m - ic, MC ); |
500 |
|
auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt ); |
501 |
|
auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt ); |
502 |
|
|
503 |
|
for ( int i = looppkA.beg(), ip = packpkA.beg(); |
504 |
|
i < looppkA.end(); |
505 |
|
i += looppkA.inc(), ip += packpkA.inc() ) |
506 |
|
{ |
507 |
|
|
508 |
|
//printf( "Before packA\n" ); |
509 |
|
|
510 |
|
if ( transA == HMLP_OP_N ) |
511 |
|
{ |
512 |
|
|
513 |
|
if ( gamma == 0 || A1 == NULL ) { |
514 |
|
pack2D<false, PACK_MR> // packA |
515 |
|
( |
516 |
|
std::min( ib - i, MR ), pb, |
517 |
|
&A0[ pc * lda + ( ic + i ) ], lda, &packA[ ip * pb ] |
518 |
|
); |
519 |
|
} else { |
520 |
|
|
521 |
|
//printf( "flag1\n" ); |
522 |
|
pack2D<false, PACK_MR> // packA |
523 |
|
( |
524 |
|
std::min( ib - i, MR ), pb, |
525 |
|
&A0[ pc * lda + ( ic + i ) ], &A1[ pc * lda + ( ic + i ) ], lda, gamma, &packA[ ip * pb ] |
526 |
|
); |
527 |
|
//printf( "flag2\n" ); |
528 |
|
} |
529 |
|
|
530 |
|
} |
531 |
|
else |
532 |
|
{ |
533 |
|
|
534 |
|
if ( gamma == 0 || A1 == NULL ) { |
535 |
|
pack2D<true, PACK_MR> // packA (transA) |
536 |
|
( |
537 |
|
std::min( ib - i, MR ), pb, |
538 |
|
&A0[ ( ic + i ) * lda + pc ], lda, &packA[ ip * pb ] |
539 |
|
); |
540 |
|
} else { |
541 |
|
pack2D<true, PACK_MR> // packA (transA) |
542 |
|
( |
543 |
|
std::min( ib - i, MR ), pb, |
544 |
|
&A0[ ( ic + i ) * lda + pc ], &A1[ ( ic + i ) * lda + pc ], lda, gamma, &packA[ ip * pb ] |
545 |
|
); |
546 |
|
} |
547 |
|
|
548 |
|
} |
549 |
|
|
550 |
|
//printf( "After packA\n" ); |
551 |
|
} |
552 |
|
ic_comm.Barrier(); |
553 |
|
|
554 |
|
// if ( is_the_last_pc_iteration ) // fused_macro_kernel |
555 |
|
// { |
556 |
|
// if ( alpha1 == 0 || C1 == NULL ) { |
557 |
|
// |
558 |
|
// //hmlp::gkmx::fused_macro_kernel |
559 |
|
// //<KC, MR, NR, PACK_MR, PACK_NR, RANK_MICROKERNEL, TA, TB, TC, TV> |
560 |
|
// //( |
561 |
|
// // thread, |
562 |
|
// // ic, jc, pc, |
563 |
|
// // ib, jb, pb, |
564 |
|
// // packA, |
565 |
|
// // packB, |
566 |
|
// // C0 + jc * ldc + ic, ldc, |
567 |
|
// // rank_microkernel |
568 |
|
// //); |
569 |
|
// |
570 |
|
// //printf( "before fused macro kernel\n" ); |
571 |
|
// fused_macro_kernel |
572 |
|
// <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV> |
573 |
|
// ( |
574 |
|
// thread, |
575 |
|
// ic, jc, pc, |
576 |
|
// ib, jb, pb, |
577 |
|
// packA, |
578 |
|
// packB, |
579 |
|
// C0 + jc * ldc + ic, |
580 |
|
// NULL, ldc, alpha0, 0, |
581 |
|
// stra_microkernel |
582 |
|
// ); |
583 |
|
// //printf( "after fused macro kernel\n" ); |
584 |
|
// |
585 |
|
// } else { |
586 |
|
// fused_macro_kernel |
587 |
|
// <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV> |
588 |
|
// ( |
589 |
|
// thread, |
590 |
|
// ic, jc, pc, |
591 |
|
// ib, jb, pb, |
592 |
|
// packA, |
593 |
|
// packB, |
594 |
|
// C0 + jc * ldc + ic, |
595 |
|
// C1 + jc * ldc + ic, ldc, alpha0, alpha1, |
596 |
|
// stra_microkernel |
597 |
|
// ); |
598 |
|
// } |
599 |
|
// |
600 |
|
// } |
601 |
|
// else // semiring rank-k update |
602 |
|
// { |
603 |
|
|
604 |
|
if ( alpha1 == 0 || C1 == NULL ) |
605 |
|
{ |
606 |
|
//hmlp::gkmx::rank_k_macro_kernel |
607 |
|
//<KC, MR, NR, PACK_MR, PACK_NR, RANK_SEMIRINGKERNEL, TA, TB, TC, TV> |
608 |
|
//( |
609 |
|
// thread, |
610 |
|
// ic, jc, pc, |
611 |
|
// ib, jb, pb, |
612 |
|
// packA, |
613 |
|
// packB, |
614 |
|
// C0 + jc * ldc + ic, ldc, |
615 |
|
// rank_semiringkernel |
616 |
|
//); |
617 |
|
|
618 |
|
rank_k_macro_kernel |
619 |
|
//strassen_macro_kernel |
620 |
|
<KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV> |
621 |
|
( |
622 |
|
thread, |
623 |
|
ic, jc, pc, |
624 |
|
ib, jb, pb, |
625 |
|
packA, |
626 |
|
packB, |
627 |
|
C0 + jc * ldc + ic, |
628 |
|
NULL, ldc, alpha0, 0, |
629 |
|
stra_semiringkernel |
630 |
|
); |
631 |
|
|
632 |
|
} |
633 |
|
else |
634 |
|
{ |
635 |
|
|
636 |
|
rank_k_macro_kernel |
637 |
|
//strassen_macro_kernel |
638 |
|
<KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV> |
639 |
|
( |
640 |
|
thread, |
641 |
|
ic, jc, pc, |
642 |
|
ib, jb, pb, |
643 |
|
packA, |
644 |
|
packB, |
645 |
|
C0 + jc * ldc + ic, |
646 |
|
C1 + jc * ldc + ic, ldc, alpha0, alpha1, |
647 |
|
stra_semiringkernel |
648 |
|
); |
649 |
|
|
650 |
|
} |
651 |
|
|
652 |
|
// } |
653 |
|
ic_comm.Barrier(); // sync all jr_id!! |
654 |
|
} // end 4th loop |
655 |
|
pc_comm.Barrier(); |
656 |
|
} // end 5th loop |
657 |
|
} // end 6th loop |
658 |
|
} // end strassen_internal |
659 |
|
|
660 |
|
|
661 |
|
|
662 |
|
|
663 |
|
/* |
664 |
|
* |
665 |
|
*/ |
666 |
|
template< |
667 |
|
int MC, int NC, int KC, int MR, int NR, |
668 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
669 |
|
bool USE_STRASSEN, |
670 |
|
typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL, |
671 |
|
typename TA, typename TB, typename TC, typename TV> |
672 |
|
void straprim |
673 |
|
( |
674 |
|
Worker &thread, |
675 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
676 |
|
int m, int n, int k, |
677 |
|
TA *A0, TA *A1, int lda, TA gamma, int *amap, |
678 |
|
TB *B0, TB *B1, int ldb, TB delta, int *bmap, |
679 |
|
TV *C0, TV *C1, int ldc, TV alpha0, TV alpha1, |
680 |
|
STRA_SEMIRINGKERNEL stra_semiringkernel, |
681 |
|
STRA_MICROKERNEL stra_microkernel, |
682 |
|
int nc, int pack_nc, |
683 |
|
TA *packA, |
684 |
|
TB *packB |
685 |
|
) |
686 |
|
{ |
687 |
|
//printf( "m: %d, n: %d, k: %d\n", m, n, k ); |
688 |
|
|
689 |
|
packA += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC |
690 |
|
+ ( thread.ic_id ) * PACK_MC * KC; |
691 |
|
packB += ( thread.jc_id ) * pack_nc * KC; |
692 |
|
|
693 |
|
auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt ); |
694 |
|
auto loop5th = GetRange( 0, k, KC ); |
695 |
|
auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt ); |
696 |
|
|
697 |
|
for ( int jc = loop6th.beg(); |
698 |
|
jc < loop6th.end(); |
699 |
|
jc += loop6th.inc() ) // beg 6th loop |
700 |
|
{ |
701 |
|
auto &jc_comm = *thread.jc_comm; |
702 |
|
auto jb = std::min( n - jc, nc ); |
703 |
|
|
704 |
|
for ( int pc = loop5th.beg(); |
705 |
|
pc < loop5th.end(); |
706 |
|
pc += loop5th.inc() ) |
707 |
|
{ |
708 |
|
auto &pc_comm = *thread.pc_comm; |
709 |
|
auto pb = std::min( k - pc, KC ); |
710 |
|
auto is_the_last_pc_iteration = ( pc + KC >= k ); |
711 |
|
auto looppkB = GetRange( 0, jb, NR, thread.ic_jr, pc_comm.GetNumThreads() ); |
712 |
|
auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() ); |
713 |
|
|
714 |
|
for ( int j = looppkB.beg(), jp = packpkB.beg(); |
715 |
|
j < looppkB.end(); |
716 |
|
j += looppkB.inc(), jp += packpkB.inc() ) |
717 |
|
{ |
718 |
|
|
719 |
|
//printf( "before packB\n" ); |
720 |
|
if ( transB == HMLP_OP_N ) |
721 |
|
{ |
722 |
|
|
723 |
|
if ( delta == 0 || B1 == NULL ) { |
724 |
|
// ldb == k |
725 |
|
pack2D<true, PACK_NR> // packB |
726 |
|
( |
727 |
|
std::min( jb - j, NR ), pb, |
728 |
|
&B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ] |
729 |
|
); |
730 |
|
} else { |
731 |
|
pack2D<true, PACK_NR> // packB |
732 |
|
( |
733 |
|
std::min( jb - j, NR ), pb, |
734 |
|
&B0[ pc ], &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ] |
735 |
|
); |
736 |
|
} |
737 |
|
|
738 |
|
} |
739 |
|
else |
740 |
|
{ |
741 |
|
if ( delta == 0 || B1 == NULL ) { |
742 |
|
pack2D<false, PACK_NR> // packB (transB) |
743 |
|
( |
744 |
|
std::min( jb - j, NR ), pb, |
745 |
|
&B0[ pc ], ldb, &bmap[ jc + j ], &packB[ jp * pb ] |
746 |
|
); |
747 |
|
} else { |
748 |
|
pack2D<false, PACK_NR> // packB (transB) |
749 |
|
( |
750 |
|
std::min( jb - j, NR ), pb, |
751 |
|
&B0[ pc ], &B1[ pc ], ldb, delta, &bmap[ jc + j ], &packB[ jp * pb ] |
752 |
|
); |
753 |
|
|
754 |
|
|
755 |
|
} |
756 |
|
|
757 |
|
} |
758 |
|
|
759 |
|
} |
760 |
|
pc_comm.Barrier(); |
761 |
|
|
762 |
|
for ( int ic = loop4th.beg(); |
763 |
|
ic < loop4th.end(); |
764 |
|
ic += loop4th.inc() ) // beg 4th loop |
765 |
|
{ |
766 |
|
auto &ic_comm = *thread.ic_comm; |
767 |
|
auto ib = std::min( m - ic, MC ); |
768 |
|
auto looppkA = GetRange( 0, ib, MR, thread.jr_id, thread.jr_nt ); |
769 |
|
auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt ); |
770 |
|
|
771 |
|
for ( int i = looppkA.beg(), ip = packpkA.beg(); |
772 |
|
i < looppkA.end(); |
773 |
|
i += looppkA.inc(), ip += packpkA.inc() ) |
774 |
|
{ |
775 |
|
|
776 |
|
//assert( lda == k ); |
777 |
|
//For transpose cases, lda should be equal to k. |
778 |
|
|
779 |
|
if ( transA == HMLP_OP_N ) |
780 |
|
{ |
781 |
|
|
782 |
|
if ( gamma == 0 || A1 == NULL ) { |
783 |
|
pack2D<false, PACK_MR> // packA |
784 |
|
( |
785 |
|
std::min( ib - i, MR ), pb, |
786 |
|
&A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ] |
787 |
|
); |
788 |
|
} else { |
789 |
|
pack2D<false, PACK_MR> // packA |
790 |
|
( |
791 |
|
std::min( ib - i, MR ), pb, |
792 |
|
&A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ] |
793 |
|
); |
794 |
|
} |
795 |
|
|
796 |
|
} |
797 |
|
else |
798 |
|
{ |
799 |
|
|
800 |
|
if ( gamma == 0 || A1 == NULL ) { |
801 |
|
pack2D<true, PACK_MR> // packA (transA) |
802 |
|
( |
803 |
|
std::min( ib - i, MR ), pb, |
804 |
|
&A0[ pc ], lda, &amap[ ic + i ], &packA[ ip * pb ] |
805 |
|
); |
806 |
|
} else { |
807 |
|
pack2D<true, PACK_MR> // packA (transA) |
808 |
|
( |
809 |
|
std::min( ib - i, MR ), pb, |
810 |
|
&A0[ pc ], &A1[ pc ], lda, gamma, &amap[ ic + i ], &packA[ ip * pb ] |
811 |
|
); |
812 |
|
|
813 |
|
} |
814 |
|
|
815 |
|
} |
816 |
|
|
817 |
|
} |
818 |
|
ic_comm.Barrier(); |
819 |
|
|
820 |
|
// if ( is_the_last_pc_iteration ) // fused_macro_kernel |
821 |
|
// { |
822 |
|
// if ( alpha1 == 0 || C1 == NULL ) { |
823 |
|
// |
824 |
|
// //hmlp::gkmx::fused_macro_kernel |
825 |
|
// //<KC, MR, NR, PACK_MR, PACK_NR, RANK_MICROKERNEL, TA, TB, TC, TV> |
826 |
|
// //( |
827 |
|
// // thread, |
828 |
|
// // ic, jc, pc, |
829 |
|
// // ib, jb, pb, |
830 |
|
// // packA, |
831 |
|
// // packB, |
832 |
|
// // C0 + jc * ldc + ic, ldc, |
833 |
|
// // rank_microkernel |
834 |
|
// //); |
835 |
|
// |
836 |
|
// //printf( "before fused macro kernel\n" ); |
837 |
|
// fused_macro_kernel |
838 |
|
// <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV> |
839 |
|
// ( |
840 |
|
// thread, |
841 |
|
// ic, jc, pc, |
842 |
|
// ib, jb, pb, |
843 |
|
// packA, |
844 |
|
// packB, |
845 |
|
// C0 + jc * ldc + ic, |
846 |
|
// NULL, ldc, alpha0, 0, |
847 |
|
// stra_microkernel |
848 |
|
// ); |
849 |
|
// //printf( "after fused macro kernel\n" ); |
850 |
|
// |
851 |
|
// } else { |
852 |
|
// fused_macro_kernel |
853 |
|
// <KC, MR, NR, PACK_MR, PACK_NR, STRA_MICROKERNEL, TA, TB, TC, TV> |
854 |
|
// ( |
855 |
|
// thread, |
856 |
|
// ic, jc, pc, |
857 |
|
// ib, jb, pb, |
858 |
|
// packA, |
859 |
|
// packB, |
860 |
|
// C0 + jc * ldc + ic, |
861 |
|
// C1 + jc * ldc + ic, ldc, alpha0, alpha1, |
862 |
|
// stra_microkernel |
863 |
|
// ); |
864 |
|
// } |
865 |
|
// |
866 |
|
// } |
867 |
|
// else // semiring rank-k update |
868 |
|
// { |
869 |
|
|
870 |
|
if ( alpha1 == 0 || C1 == NULL ) |
871 |
|
{ |
872 |
|
//hmlp::gkmx::rank_k_macro_kernel |
873 |
|
//<KC, MR, NR, PACK_MR, PACK_NR, RANK_SEMIRINGKERNEL, TA, TB, TC, TV> |
874 |
|
//( |
875 |
|
// thread, |
876 |
|
// ic, jc, pc, |
877 |
|
// ib, jb, pb, |
878 |
|
// packA, |
879 |
|
// packB, |
880 |
|
// C0 + jc * ldc + ic, ldc, |
881 |
|
// rank_semiringkernel |
882 |
|
//); |
883 |
|
|
884 |
|
rank_k_macro_kernel |
885 |
|
//strassen_macro_kernel |
886 |
|
<KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV> |
887 |
|
( |
888 |
|
thread, |
889 |
|
ic, jc, pc, |
890 |
|
ib, jb, pb, |
891 |
|
packA, |
892 |
|
packB, |
893 |
|
C0 + jc * ldc + ic, |
894 |
|
NULL, ldc, alpha0, 0, |
895 |
|
stra_semiringkernel |
896 |
|
); |
897 |
|
|
898 |
|
} |
899 |
|
else |
900 |
|
{ |
901 |
|
|
902 |
|
rank_k_macro_kernel |
903 |
|
//strassen_macro_kernel |
904 |
|
<KC, MR, NR, PACK_MR, PACK_NR, STRA_SEMIRINGKERNEL, TA, TB, TC, TV> |
905 |
|
( |
906 |
|
thread, |
907 |
|
ic, jc, pc, |
908 |
|
ib, jb, pb, |
909 |
|
packA, |
910 |
|
packB, |
911 |
|
C0 + jc * ldc + ic, |
912 |
|
C1 + jc * ldc + ic, ldc, alpha0, alpha1, |
913 |
|
stra_semiringkernel |
914 |
|
); |
915 |
|
|
916 |
|
} |
917 |
|
|
918 |
|
// } |
919 |
|
ic_comm.Barrier(); // sync all jr_id!! |
920 |
|
} // end 4th loop |
921 |
|
pc_comm.Barrier(); |
922 |
|
} // end 5th loop |
923 |
|
} // end 6th loop |
924 |
|
} // end strassen_internal |
925 |
|
|
926 |
|
template<typename TA, typename TB, typename TV> |
927 |
|
void hmlp_dynamic_peeling |
928 |
|
( |
929 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
930 |
|
int m, int n, int k, |
931 |
|
TA *A, int lda, |
932 |
|
TB *B, int ldb, |
933 |
|
TV *C, int ldc, |
934 |
|
int dim1, int dim2, int dim3 |
935 |
|
) |
936 |
|
{ |
937 |
|
//printf( "Enter dynamic peeling\n" ); |
938 |
|
int mr = m % dim1; |
939 |
|
int kr = k % dim2; |
940 |
|
int nr = n % dim3; |
941 |
|
int ms = m - mr; |
942 |
|
int ns = n - nr; |
943 |
|
int ks = k - kr; |
944 |
|
TA *A_extra; |
945 |
|
TB *B_extra; |
946 |
|
TV *C_extra; |
947 |
|
|
948 |
|
char transA_val, transB_val; |
949 |
|
char *char_transA = &transA_val, *char_transB = &transB_val; |
950 |
|
|
951 |
|
|
952 |
|
//printf( "flag d1\n" ); |
953 |
|
|
954 |
|
// Adjust part handled by fast matrix multiplication. |
955 |
|
// Add far column of A outer product bottom row B |
956 |
|
if ( kr > 0 ) { |
957 |
|
// In Strassen, this looks like C([1, 2], [1, 2]) += A([1, 2], 3) * B(3, [1, 2]) |
958 |
|
|
959 |
|
//printf( "flag d2\n" ); |
960 |
|
|
961 |
|
if ( transA == HMLP_OP_N ) { |
962 |
|
A_extra = &A[ 0 + ks * lda ];//ms * kr |
963 |
|
*char_transA = 'N'; |
964 |
|
} else { |
965 |
|
A_extra = &A[ 0 * lda + ks ];//ms * kr |
966 |
|
*char_transA = 'T'; |
967 |
|
} |
968 |
|
|
969 |
|
//printf( "flag d3\n" ); |
970 |
|
if ( transB == HMLP_OP_N ) { |
971 |
|
B_extra = &B[ ks + 0 * ldb ];//kr * ns |
972 |
|
*char_transB = 'N'; |
973 |
|
} else { |
974 |
|
B_extra = &B[ ks * ldb + 0 ];//kr * ns |
975 |
|
*char_transB = 'T'; |
976 |
|
} |
977 |
|
|
978 |
|
//printf( "flag d4\n" ); |
979 |
|
C_extra = &C[ 0 + 0 * ldc ];//ms * ns |
980 |
|
if ( ms > 0 && ns > 0 ) |
981 |
|
{ |
982 |
|
//bl_dgemm( ms, ns, kr, A_extra, lda, B_extra, ldb, C_extra, ldc ); |
983 |
|
xgemm( char_transA, char_transB, ms, ns, kr, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc ); |
984 |
|
} |
985 |
|
} |
986 |
|
|
987 |
|
//printf( "flag d5\n" ); |
988 |
|
|
989 |
|
// Adjust for far right columns of C |
990 |
|
if ( nr > 0 ) { |
991 |
|
// In Strassen, this looks like C(:, 3) = A * B(:, 3) |
992 |
|
|
993 |
|
if ( transA == HMLP_OP_N ) { |
994 |
|
*char_transA = 'N'; |
995 |
|
} else { |
996 |
|
*char_transA = 'T'; |
997 |
|
} |
998 |
|
//printf( "flag d6\n" ); |
999 |
|
|
1000 |
|
if ( transB == HMLP_OP_N ) { |
1001 |
|
B_extra = &B[ 0 + ns * ldb ];//k * nr |
1002 |
|
*char_transB = 'N'; |
1003 |
|
} else { |
1004 |
|
B_extra = &B[ 0 * ldb + ns ];//k * nr |
1005 |
|
*char_transB = 'T'; |
1006 |
|
} |
1007 |
|
|
1008 |
|
|
1009 |
|
//printf( "flag d7\n" ); |
1010 |
|
|
1011 |
|
C_extra = &C[ 0 + ns * ldc ];//m * nr |
1012 |
|
//bl_dgemm( m, nr, k, A, lda, B_extra, ldb, C_extra, ldc ); |
1013 |
|
xgemm( char_transA, char_transB, m, nr, k, 1.0, A, lda, B_extra, ldb, 1.0, C_extra, ldc ); |
1014 |
|
|
1015 |
|
} |
1016 |
|
|
1017 |
|
//printf( "flag d8\n" ); |
1018 |
|
|
1019 |
|
// Adjust for bottom rows of C |
1020 |
|
if ( mr > 0 ) { |
1021 |
|
// In Strassen, this looks like C(3, [1, 2]) = A(3, :) * B(:, [1, 2]) |
1022 |
|
|
1023 |
|
|
1024 |
|
//printf( "flag d8.1\n" ); |
1025 |
|
if ( transA == HMLP_OP_N ) { |
1026 |
|
|
1027 |
|
//printf( "flag d8.15\n" ); |
1028 |
|
A_extra = &A[ ms + 0 * lda ];// mr * k |
1029 |
|
//printf( "flag d8.16\n" ); |
1030 |
|
*char_transA = 'N'; |
1031 |
|
|
1032 |
|
//printf( "flag d8.2\n" ); |
1033 |
|
} else { |
1034 |
|
A_extra = &A[ ms * lda + 0 ];// mr * k |
1035 |
|
*char_transA = 'T'; |
1036 |
|
//printf( "flag d8.3\n" ); |
1037 |
|
} |
1038 |
|
|
1039 |
|
//printf( "flag d8.4\n" ); |
1040 |
|
|
1041 |
|
if ( transB == HMLP_OP_N ) { |
1042 |
|
B_extra = &B[ 0 + 0 * ldb ];// k * ns |
1043 |
|
*char_transB = 'N'; |
1044 |
|
|
1045 |
|
//printf( "flag d8.5\n" ); |
1046 |
|
|
1047 |
|
} else { |
1048 |
|
B_extra = &B[ 0 * ldb + 0 ];// k * ns |
1049 |
|
*char_transB = 'T'; |
1050 |
|
|
1051 |
|
//printf( "flag d8.6\n" ); |
1052 |
|
} |
1053 |
|
|
1054 |
|
//printf( "flag d9\n" ); |
1055 |
|
|
1056 |
|
TV *C_extra = &C[ ms + 0 * ldc ];// mr * ns |
1057 |
|
if ( ns > 0 ) |
1058 |
|
{ |
1059 |
|
//bl_dgemm( mr, ns, k, A_extra, lda, B_extra, ldb, C_extra, ldc ); |
1060 |
|
xgemm( char_transA, char_transB, mr, ns, k, 1.0, A_extra, lda, B_extra, ldb, 1.0, C_extra, ldc ); |
1061 |
|
} |
1062 |
|
} |
1063 |
|
//printf( "Leave dynamic peeling\n" ); |
1064 |
|
} |
1065 |
|
|
1066 |
|
template< |
1067 |
|
int MC, int NC, int KC, int MR, int NR, |
1068 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
1069 |
|
bool USE_STRASSEN, |
1070 |
|
typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL, |
1071 |
|
typename TA, typename TB, typename TC, typename TV> |
1072 |
|
void strassen_internal |
1073 |
|
( |
1074 |
|
Worker &thread, |
1075 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
1076 |
|
int m, int n, int k, |
1077 |
|
TA *A, int lda, int *amap, |
1078 |
|
TB *B, int ldb, int *bmap, |
1079 |
|
TV *C, int ldc, |
1080 |
|
STRA_SEMIRINGKERNEL stra_semiringkernel, |
1081 |
|
STRA_MICROKERNEL stra_microkernel, |
1082 |
|
int nc, int pack_nc, |
1083 |
|
TA *packA_buff, |
1084 |
|
TB *packB_buff |
1085 |
|
) |
1086 |
|
{ |
1087 |
|
|
1088 |
|
int ms, ks, ns; |
1089 |
|
int md, kd, nd; |
1090 |
|
int mr, kr, nr; |
1091 |
|
|
1092 |
|
mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 ); |
1093 |
|
md = m - mr, kd = k - kr, nd = n - nr; |
1094 |
|
|
1095 |
|
// Partition code. |
1096 |
|
ms=md, ks=kd, ns=nd; |
1097 |
|
TA *A00, *A01, *A10, *A11; |
1098 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 0, &A00 ); |
1099 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 1, &A01 ); |
1100 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 0, &A10 ); |
1101 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 1, &A11 ); |
1102 |
|
|
1103 |
|
TB *B00, *B01, *B10, *B11; |
1104 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 0, &B00 ); |
1105 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 1, &B01 ); |
1106 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 0, &B10 ); |
1107 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 1, &B11 ); |
1108 |
|
|
1109 |
|
TV *C00, *C01, *C10, *C11; |
1110 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 0, &C00 ); |
1111 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 1, &C01 ); |
1112 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 0, &C10 ); |
1113 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 1, &C11 ); |
1114 |
|
|
1115 |
|
md = md / 2, kd = kd / 2, nd = nd / 2; |
1116 |
|
|
1117 |
|
// M1: C00 = 1*C00+1*(A00+A11)(B00+B11); C11 = 1*C11+1*(A00+A11)(B00+B11) |
1118 |
|
STRAPRIM_MAP( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 ); |
1119 |
|
// M2: C10 = 1*C10+1*(A10+A11)B00; C11 = 1*C11-1*(A10+A11)B00 |
1120 |
|
STRAPRIM_MAP( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 ) |
1121 |
|
// M3: C01 = 1*C01+1*A00(B01-B11); C11 = 1*C11+1*A00(B01-B11) |
1122 |
|
STRAPRIM_MAP( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 ) |
1123 |
|
// M4: C00 = 1*C00+1*A11(B10-B00); C10 = 1*C10+1*A11(B10-B00) |
1124 |
|
STRAPRIM_MAP( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 ) |
1125 |
|
// M5: C00 = 1*C00-1*(A00+A01)B11; C01 = 1*C01+1*(A00+A01)B11 |
1126 |
|
STRAPRIM_MAP( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 ) |
1127 |
|
// M6: C11 = 1*C11+(A10-A00)(B00+B01) |
1128 |
|
STRAPRIM_MAP( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 ) |
1129 |
|
// M7: C00 = 1*C00+(A01-A11)(B10+B11) |
1130 |
|
STRAPRIM_MAP( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 ) |
1131 |
|
|
1132 |
|
if ( omp_get_thread_num() == 0 ) { //Chief thread |
1133 |
|
hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 ); |
1134 |
|
} |
1135 |
|
|
1136 |
|
} |
1137 |
|
|
1138 |
|
template< |
1139 |
|
int MC, int NC, int KC, int MR, int NR, |
1140 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
1141 |
|
bool USE_STRASSEN, |
1142 |
|
typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL, |
1143 |
|
typename TA, typename TB, typename TC, typename TV> |
1144 |
|
void strassen_internal |
1145 |
|
( |
1146 |
|
Worker &thread, |
1147 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
1148 |
|
int m, int n, int k, |
1149 |
|
TA *A, int lda, |
1150 |
|
TB *B, int ldb, |
1151 |
|
TV *C, int ldc, |
1152 |
|
STRA_SEMIRINGKERNEL stra_semiringkernel, |
1153 |
|
STRA_MICROKERNEL stra_microkernel, |
1154 |
|
int nc, int pack_nc, |
1155 |
|
TA *packA_buff, |
1156 |
|
TB *packB_buff |
1157 |
|
) |
1158 |
|
{ |
1159 |
|
|
1160 |
|
int ms, ks, ns; |
1161 |
|
int md, kd, nd; |
1162 |
|
int mr, kr, nr; |
1163 |
|
|
1164 |
|
mr = m % ( 2 ), kr = k % ( 2 ), nr = n % ( 2 ); |
1165 |
|
md = m - mr, kd = k - kr, nd = n - nr; |
1166 |
|
|
1167 |
|
// Partition code. |
1168 |
|
ms=md, ks=kd, ns=nd; |
1169 |
|
TA *A00, *A01, *A10, *A11; |
1170 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 0, &A00 ); |
1171 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 0, 1, &A01 ); |
1172 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 0, &A10 ); |
1173 |
|
hmlp_acquire_mpart( transA, ms, ks, A, lda, 2, 2, 1, 1, &A11 ); |
1174 |
|
|
1175 |
|
TB *B00, *B01, *B10, *B11; |
1176 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 0, &B00 ); |
1177 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 0, 1, &B01 ); |
1178 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 0, &B10 ); |
1179 |
|
hmlp_acquire_mpart( transB, ks, ns, B, ldb, 2, 2, 1, 1, &B11 ); |
1180 |
|
|
1181 |
|
TV *C00, *C01, *C10, *C11; |
1182 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 0, &C00 ); |
1183 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 0, 1, &C01 ); |
1184 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 0, &C10 ); |
1185 |
|
hmlp_acquire_mpart( HMLP_OP_N, ms, ns, C, ldc, 2, 2, 1, 1, &C11 ); |
1186 |
|
|
1187 |
|
md = md / 2, kd = kd / 2, nd = nd / 2; |
1188 |
|
|
1189 |
|
// M1: C00 = 1*C00+1*(A00+A11)(B00+B11); C11 = 1*C11+1*(A00+A11)(B00+B11) |
1190 |
|
STRAPRIM( A00, A11, 1, B00, B11, 1, C00, C11, 1, 1 ); |
1191 |
|
|
1192 |
|
//printf( "A00:\n" ); |
1193 |
|
//hmlp_printmatrix( md, kd, A00, m ); |
1194 |
|
//printf( "A11:\n" ); |
1195 |
|
//hmlp_printmatrix( md, kd, A11, m ); |
1196 |
|
//printf( "B00:\n" ); |
1197 |
|
//hmlp_printmatrix( kd, nd, B00, k ); |
1198 |
|
//printf( "B11:\n" ); |
1199 |
|
//hmlp_printmatrix( kd, nd, B11, k ); |
1200 |
|
//printf( "C00:\n" ); |
1201 |
|
//hmlp_printmatrix( md, nd, C00, m ); |
1202 |
|
//printf( "C01:\n" ); |
1203 |
|
//hmlp_printmatrix( md, nd, C11, m ); |
1204 |
|
|
1205 |
|
// M2: C10 = 1*C10+1*(A10+A11)B00; C11 = 1*C11-1*(A10+A11)B00 |
1206 |
|
STRAPRIM( A10, A11, 1, B00, NULL, 0, C10, C11, 1, -1 ) |
1207 |
|
|
1208 |
|
// M3: C01 = 1*C01+1*A00(B01-B11); C11 = 1*C11+1*A00(B01-B11) |
1209 |
|
STRAPRIM( A00, NULL, 0, B01, B11, -1, C01, C11, 1, 1 ) |
1210 |
|
// M4: C00 = 1*C00+1*A11(B10-B00); C10 = 1*C10+1*A11(B10-B00) |
1211 |
|
STRAPRIM( A11, NULL, 0, B10, B00, -1, C00, C10, 1, 1 ) |
1212 |
|
// M5: C00 = 1*C00-1*(A00+A01)B11; C01 = 1*C01+1*(A00+A01)B11 |
1213 |
|
STRAPRIM( A00, A01, 1, B11, NULL, 0, C00, C01, -1, 1 ) |
1214 |
|
// M6: C11 = 1*C11+(A10-A00)(B00+B01) |
1215 |
|
STRAPRIM( A10, A00, -1, B00, B01, 1, C11, NULL, 1, 0 ) |
1216 |
|
// M7: C00 = 1*C00+(A01-A11)(B10+B11) |
1217 |
|
STRAPRIM( A01, A11, -1, B10, B11, 1, C00, NULL, 1, 0 ) |
1218 |
|
|
1219 |
|
//printf( "C00:" ); |
1220 |
|
//hmlp_printmatrix( md, nd, C00, m ); |
1221 |
|
|
1222 |
|
//printf( "before dynamic peeling\n" ); |
1223 |
|
|
1224 |
|
if ( omp_get_thread_num() == 0 ) { //Chief thread |
1225 |
|
hmlp_dynamic_peeling( transA, transB, m, n, k, A, lda, B, ldb, C, ldc, 2, 2, 2 ); |
1226 |
|
} |
1227 |
|
|
1228 |
|
} |
1229 |
|
|
1230 |
|
|
1231 |
|
/** |
1232 |
|
* |
1233 |
|
* |
1234 |
|
*/ |
1235 |
|
template< |
1236 |
|
int MC, int NC, int KC, int MR, int NR, |
1237 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
1238 |
|
bool USE_STRASSEN, |
1239 |
|
typename STRA_SEMIRINGKERNEL, typename STRA_MICROKERNEL, |
1240 |
|
typename TA, typename TB, typename TC, typename TV> |
1241 |
|
void strassen |
1242 |
|
( |
1243 |
|
hmlpOperation_t transA, hmlpOperation_t transB, |
1244 |
|
int m, int n, int k, |
1245 |
|
TA *A, int lda, |
1246 |
|
TB *B, int ldb, |
1247 |
|
TV *C, int ldc, |
1248 |
|
STRA_SEMIRINGKERNEL stra_semiringkernel, |
1249 |
|
STRA_MICROKERNEL stra_microkernel |
1250 |
|
) |
1251 |
|
{ |
1252 |
|
int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1; |
1253 |
|
int nc = NC, pack_nc = PACK_NC; |
1254 |
|
char *str; |
1255 |
|
|
1256 |
|
TA *packA_buff = NULL; |
1257 |
|
TB *packB_buff = NULL; |
1258 |
|
|
1259 |
|
// Early return if possible |
1260 |
|
if ( m == 0 || n == 0 || k == 0 ) return; |
1261 |
|
|
1262 |
|
// Check the environment variable. |
1263 |
|
jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" ); |
1264 |
|
ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" ); |
1265 |
|
jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" ); |
1266 |
|
|
1267 |
|
|
1268 |
|
if ( jc_nt > 1 ) |
1269 |
|
{ |
1270 |
|
nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR; |
1271 |
|
pack_nc = ( nc / NR ) * PACK_NR; |
1272 |
|
} |
1273 |
|
|
1274 |
|
// allocate packing memory |
1275 |
|
packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) ); |
1276 |
|
packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt, sizeof(TB) ); |
1277 |
|
|
1278 |
|
// allocate tree communicator |
1279 |
|
thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt ); |
1280 |
|
|
1281 |
|
#pragma omp parallel num_threads( my_comm.GetNumThreads() ) |
1282 |
|
{ |
1283 |
|
Worker thread( &my_comm ); |
1284 |
|
|
1285 |
|
strassen_internal |
1286 |
|
<MC, NC, KC, MR, NR, |
1287 |
|
PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
1288 |
|
USE_STRASSEN, |
1289 |
|
STRA_SEMIRINGKERNEL, STRA_MICROKERNEL, |
1290 |
|
TA, TB, TC, TB> |
1291 |
|
( |
1292 |
|
thread, |
1293 |
|
transA, transB, |
1294 |
|
m, n, k, |
1295 |
|
A, lda, |
1296 |
|
B, ldb, |
1297 |
|
C, ldc, |
1298 |
|
stra_semiringkernel, stra_microkernel, |
1299 |
|
nc, pack_nc, |
1300 |
|
packA_buff, |
1301 |
|
packB_buff |
1302 |
|
); |
1303 |
|
|
1304 |
|
} |
1305 |
|
// end omp |
1306 |
|
} // end strassen |
1307 |
|
|
1308 |
|
|
1309 |
|
}; // end namespace strassen |
1310 |
|
}; // end namespace hmlp |
1311 |
|
|
1312 |
|
#endif // define STRASSEN_HPP |
1313 |
|
|
1314 |
|
|