/**
 *  HMLP (High-Performance Machine Learning Primitives)
 *
 *  Copyright (C) 2014-2017, The University of Texas at Austin
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see the LICENSE file.
 **/
|
#ifndef GNBX_HPP |
24 |
|
#define GNBX_HPP |
25 |
|
|
26 |
|
#include <assert.h> |
27 |
|
#include <typeinfo> |
28 |
|
#include <algorithm> |
29 |
|
|
30 |
|
#include <hmlp.h> |
31 |
|
#include <hmlp_internal.hpp> |
32 |
|
#include <hmlp_base.hpp> |
33 |
|
|
34 |
|
/** for USE_STRASSEN */ |
35 |
|
//#include <primitives/strassen.hpp> |
36 |
|
|
37 |
|
/** reference microkernels */ |
38 |
|
#include <packing.hpp> |
39 |
|
#include <semiring_mrxnr.hpp> |
40 |
|
#include <fused_mrxnr.hpp> |
41 |
|
|
42 |
|
using namespace std; |
43 |
|
|
44 |
|
|
45 |
|
namespace hmlp |
46 |
|
{ |
47 |
|
namespace gnbx |
48 |
|
{ |
49 |
|
|
50 |
|
/** |
51 |
|
* @brief Macro kernel contains the 3rd and 2nd loops. Depending on the |
52 |
|
* configuration of the communicator, the 3rd loop may be parallelized. |
53 |
|
* b_next is the prefetch pointer. |
54 |
|
*/ |
55 |
|
template<int KC, typename SEMIRINGKERNEL, typename TA, typename TB, typename TV> |
56 |
|
void rank_k_macro_kernel |
57 |
|
( |
58 |
|
Worker &Comm4th, |
59 |
|
int ic, int jc, int pc, |
60 |
|
int m, int n, int k, |
61 |
|
TA *packA, |
62 |
|
TB *packB, |
63 |
|
TV *V, int rs_v, int cs_v, |
64 |
|
SEMIRINGKERNEL semiringkernel |
65 |
|
) |
66 |
|
{ |
67 |
|
/** Get all block sizes */ |
68 |
|
const static int MR = SEMIRINGKERNEL::mr; |
69 |
|
const static int NR = SEMIRINGKERNEL::nr; |
70 |
|
const static int PACK_MR = SEMIRINGKERNEL::pack_mr; |
71 |
|
const static int PACK_NR = SEMIRINGKERNEL::pack_nr; |
72 |
|
|
73 |
|
/** Get ic loop communicator */ |
74 |
|
thread_communicator &ic_comm = *Comm4th.comm; |
75 |
|
|
76 |
|
/** Compute loop ranges for each thread */ |
77 |
|
auto Loop3rd = Comm4th.DistributeOver1DGangs( 0, n, NR ); |
78 |
|
auto Pack3rd = Comm4th.DistributeOver1DGangs( 0, n, PACK_NR ); |
79 |
|
auto Loop2nd = Comm4th.DistributeOver1DThreads( 0, m, MR ); |
80 |
|
auto Pack2nd = Comm4th.DistributeOver1DThreads( 0, m, PACK_MR ); |
81 |
|
|
82 |
|
/** Loop 3rd (jr loop) */ |
83 |
|
for ( int j = get<0>( Loop3rd ), jp = get<0>( Pack3rd ); |
84 |
|
j < get<1>( Loop3rd ); |
85 |
|
j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) ) |
86 |
|
{ |
87 |
|
struct aux_s<TA, TB, TV, TV> aux; |
88 |
|
aux.pc = pc; |
89 |
|
aux.b_next = packB; |
90 |
|
aux.do_packC = 0; |
91 |
|
aux.jb = std::min( n - j, NR ); |
92 |
|
|
93 |
|
/** Loop 2nd (ir loop) */ |
94 |
|
for ( int i = get<0>( Loop2nd ), ip = get<0>( Pack2nd ); |
95 |
|
i < get<1>( Loop2nd ); |
96 |
|
i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) ) |
97 |
|
{ |
98 |
|
aux.ib = std::min( m - i, MR ); |
99 |
|
if ( i + MR >= m ) |
100 |
|
{ |
101 |
|
aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k; |
102 |
|
} |
103 |
|
|
104 |
|
if ( aux.jb == NR && aux.ib == MR ) |
105 |
|
{ |
106 |
|
semiringkernel |
107 |
|
( |
108 |
|
k, |
109 |
|
&packA[ ip * k ], |
110 |
|
&packB[ jp * k ], |
111 |
|
&V[ i * rs_v + j * cs_v ], rs_v, cs_v, |
112 |
|
&aux |
113 |
|
); |
114 |
|
} |
115 |
|
else // corner case |
116 |
|
{ |
117 |
|
TV vtmp[ MR * NR ]; |
118 |
|
|
119 |
|
if ( pc ) // initilize ctmp |
120 |
|
{ |
121 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
122 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
123 |
|
vtmp[ jj * MR + ii ] = |
124 |
|
V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ]; |
125 |
|
} |
126 |
|
|
127 |
|
semiringkernel |
128 |
|
( |
129 |
|
k, |
130 |
|
&packA[ ip * k ], |
131 |
|
&packB[ jp * k ], |
132 |
|
vtmp, 1, MR, |
133 |
|
&aux |
134 |
|
); |
135 |
|
|
136 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
137 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
138 |
|
V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ] = vtmp[ jj * MR + ii ]; |
139 |
|
} |
140 |
|
} // end 2nd loop |
141 |
|
} // end 3rd loop |
142 |
|
}; // end rank_k_macro_kernel |
143 |
|
|
144 |
|
|
145 |
|
|
146 |
|
|
147 |
|
|
148 |
|
/** |
149 |
|
* @brief fused_macro_kernel contains the 3rd, 2nd loops and the fused micro |
150 |
|
* kernel. Notice that here C has type TC, which is differnet from the |
151 |
|
* one in rank_k_macro_kernel. ctmp used in the conner case is also |
152 |
|
* type TC. |
153 |
|
*/ |
154 |
|
template<int KC, typename FUSEDKERNEL, typename TA, typename TB, typename TC, typename TV> |
155 |
|
void fused_macro_kernel |
156 |
|
( |
157 |
|
Worker &Comm4th, |
158 |
|
int m, int n, |
159 |
|
int ic, int jc, int pc, |
160 |
|
int mc, int nc, int kc, |
161 |
|
TA *packA, |
162 |
|
TB *packB, |
163 |
|
TC *C, |
164 |
|
TV *V, int rs_v, int cs_v, |
165 |
|
int batchId, |
166 |
|
FUSEDKERNEL fusedkernel |
167 |
|
) |
168 |
|
{ |
169 |
|
/** Get all block sizes */ |
170 |
|
const static int MR = FUSEDKERNEL::mr; |
171 |
|
const static int NR = FUSEDKERNEL::nr; |
172 |
|
const static int PACK_MR = FUSEDKERNEL::pack_mr; |
173 |
|
const static int PACK_NR = FUSEDKERNEL::pack_nr; |
174 |
|
|
175 |
|
/** Get ic loop communicator */ |
176 |
|
thread_communicator &ic_comm = *Comm4th.comm; |
177 |
|
|
178 |
|
/** Compute loop ranges for each thread */ |
179 |
|
auto Loop3rd = Comm4th.DistributeOver1DGangs( 0, nc, NR ); |
180 |
|
auto Pack3rd = Comm4th.DistributeOver1DGangs( 0, nc, PACK_NR ); |
181 |
|
auto Loop2nd = Comm4th.DistributeOver1DThreads( 0, mc, MR ); |
182 |
|
auto Pack2nd = Comm4th.DistributeOver1DThreads( 0, mc, PACK_MR ); |
183 |
|
|
184 |
|
/** Loop 3rd (jr loop) */ |
185 |
|
for ( int j = get<0>( Loop3rd ), jp = get<0>( Pack3rd ); |
186 |
|
j < get<1>( Loop3rd ); |
187 |
|
j += get<2>( Loop3rd ), jp += get<2>( Pack3rd ) ) |
188 |
|
{ |
189 |
|
struct aux_s<TA, TB, TC, TV> aux; |
190 |
|
aux.pc = pc; |
191 |
|
aux.b_next = packB; |
192 |
|
aux.do_packC = 0; |
193 |
|
|
194 |
|
/** Loop 2nd (ir loop) */ |
195 |
|
for ( int i = get<0>( Loop2nd ), ip = get<0>( Pack2nd ); |
196 |
|
i < get<1>( Loop2nd ); |
197 |
|
i += get<2>( Loop2nd ), ip += get<2>( Pack2nd ) ) |
198 |
|
{ |
199 |
|
/** |
200 |
|
* These auxiluary infos are used to access data in the closure of |
201 |
|
* opkernel and opreduce. |
202 |
|
*/ |
203 |
|
aux.m = m; |
204 |
|
aux.n = n; |
205 |
|
aux.i = ic + i; |
206 |
|
aux.j = jc + j; |
207 |
|
aux.b = batchId; |
208 |
|
|
209 |
|
/** |
210 |
|
* Encapsulate edge case information. |
211 |
|
*/ |
212 |
|
aux.ib = std::min( mc - i, MR ); |
213 |
|
aux.jb = std::min( nc - j, NR ); |
214 |
|
|
215 |
|
/** |
216 |
|
* Prepare the intermediate semiring rank-k update |
217 |
|
*/ |
218 |
|
aux.V = V + i * rs_v + j * cs_v; |
219 |
|
aux.ldv = cs_v; |
220 |
|
|
221 |
|
if ( i + MR >= mc ) |
222 |
|
{ |
223 |
|
aux.b_next += ic_comm.GetNumThreads() * PACK_NR * kc; |
224 |
|
} |
225 |
|
|
226 |
|
if ( aux.jb == NR && aux.ib == MR ) |
227 |
|
{ |
228 |
|
fusedkernel |
229 |
|
( |
230 |
|
kc, |
231 |
|
&packA[ ip * kc ], |
232 |
|
&packB[ jp * kc ], |
233 |
|
C, |
234 |
|
&V[ i * rs_v + j * cs_v ], rs_v, cs_v, |
235 |
|
&aux |
236 |
|
); |
237 |
|
} |
238 |
|
else |
239 |
|
{ |
240 |
|
TV vtmp[ MR * NR ]; |
241 |
|
if ( pc ) // initilize ctmp |
242 |
|
{ |
243 |
|
for ( auto jj = 0; jj < aux.jb; jj ++ ) |
244 |
|
for ( auto ii = 0; ii < aux.ib; ii ++ ) |
245 |
|
vtmp[ jj * MR + ii ] = |
246 |
|
V[ ( j + jj ) * cs_v + ( i + ii ) * rs_v ]; |
247 |
|
aux.V = vtmp; |
248 |
|
aux.ldv = MR; |
249 |
|
} |
250 |
|
fusedkernel |
251 |
|
( |
252 |
|
kc, |
253 |
|
&packA[ ip * kc ], |
254 |
|
&packB[ jp * kc ], |
255 |
|
C, |
256 |
|
vtmp, 1, MR, |
257 |
|
&aux |
258 |
|
); |
259 |
|
} |
260 |
|
} |
261 |
|
} |
262 |
|
|
263 |
|
}; /** end fused_macro_kernel() */ |
264 |
|
|
265 |
|
|
266 |
|
|
267 |
|
|
268 |
|
/** |
269 |
|
* @breif This function contains the loop body of the 6th to 4th loops, |
270 |
|
* including all packing and unpacking routines. Notice that this |
271 |
|
* function is executed by all threads in the root communicator. |
272 |
|
* To access each thread in different level of communicators, use |
273 |
|
* their ids. |
274 |
|
*/ |
275 |
|
template< |
276 |
|
int MC, int NC, int KC, |
277 |
|
typename TPACKA, typename TPACKB, typename TV, |
278 |
|
typename TA, typename TB, typename TC, |
279 |
|
typename SEMIRINGKERNEL, typename MICROKERNEL> |
280 |
|
void gnbx_internal |
281 |
|
( |
282 |
|
Worker &thread, |
283 |
|
int batchId, int m, int n, int k, int k_stra, |
284 |
|
TA& A, |
285 |
|
TB& B, |
286 |
|
TC& C, |
287 |
|
TV* V, int rs_v, int cs_v, |
288 |
|
SEMIRINGKERNEL semiringkernel, |
289 |
|
MICROKERNEL microkernel |
290 |
|
) |
291 |
|
{ |
292 |
|
/** Get all block sizes */ |
293 |
|
const static int MR = SEMIRINGKERNEL::mr; |
294 |
|
const static int NR = SEMIRINGKERNEL::nr; |
295 |
|
const static int PACK_MR = SEMIRINGKERNEL::pack_mr; |
296 |
|
const static int PACK_NR = SEMIRINGKERNEL::pack_nr; |
297 |
|
const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size; |
298 |
|
const static int PACK_MC = ( MC / MR ) * PACK_MR; |
299 |
|
const static int PACK_NC = ( NC / NR ) * PACK_NR; |
300 |
|
|
301 |
|
/** Create subcommunicators for each loop */ |
302 |
|
auto CommGLB = thread.Split(); |
303 |
|
auto Comm6th = CommGLB.Split(); |
304 |
|
auto Comm5th = Comm6th.Split(); |
305 |
|
auto Comm4th = Comm5th.Split(); |
306 |
|
|
307 |
|
|
308 |
|
/** Adjuest nc and pack_nc if the 6th loop is parallelized */ |
309 |
|
int nc = CommGLB.BalanceOver1DGangs( n, NC, NR ); |
310 |
|
int pack_nc = ( nc / NR ) * PACK_NR; |
311 |
|
|
312 |
|
|
313 |
|
|
314 |
|
//printf( "CommGLB %s tid %d gid %d ngangs %d\n", CommGLB.comm->name.data(), CommGLB.tid, CommGLB.gid, CommGLB.comm->GetNumGroups() ); |
315 |
|
//printf( "Comm6th %s tid %d gid %d ngangs %d\n", Comm6th.comm->name.data(), Comm6th.tid, Comm6th.gid, Comm6th.comm->GetNumGroups() ); |
316 |
|
//printf( "Comm5th %s tid %d gid %d ngangs %d\n", Comm5th.comm->name.data(), Comm5th.tid, Comm5th.gid, Comm5th.comm->GetNumGroups() ); |
317 |
|
//printf( "Comm4th %s tid %d gid %d ngangs %d\n", Comm4th.comm->name.data(), Comm4th.tid, Comm4th.gid, Comm4th.comm->GetNumGroups() ); |
318 |
|
//fflush( stdout ); |
319 |
|
|
320 |
|
/** |
321 |
|
* Allocate packing buffers: |
322 |
|
* |
323 |
|
* packA is shared over Comm4th |
324 |
|
* packB is shared over Comm5th |
325 |
|
*/ |
326 |
|
auto *packA = Comm4th.AllocateSharedMemory<ALIGN_SIZE, TPACKA>( KC * ( PACK_MC + 1 ) ); |
327 |
|
auto *packB = Comm5th.AllocateSharedMemory<ALIGN_SIZE, TPACKB>( KC * ( pack_nc + 1 ) ); |
328 |
|
|
329 |
|
/** Compute loop ranges for each thread */ |
330 |
|
auto Loop6th = CommGLB.DistributeOver1DGangs( 0, n, nc ); |
331 |
|
auto Loop5th = Comm6th.DistributeOver1DGangs( k_stra, k, KC ); |
332 |
|
auto Loop4th = Comm5th.DistributeOver1DGangs( 0, m, MC ); |
333 |
|
|
334 |
|
/** Comm6th is used inside the 6th loop (i.e. jc loop) */ |
335 |
|
for ( int jc = get<0>( Loop6th ); |
336 |
|
jc < get<1>( Loop6th ); |
337 |
|
jc += get<2>( Loop6th ) ) |
338 |
|
{ |
339 |
|
auto jb = std::min( n - jc, nc ); |
340 |
|
|
341 |
|
|
342 |
|
/** Comm5th is used inside the 6th loop (i.e. pc loop) */ |
343 |
|
for ( int pc = get<0>( Loop5th ); |
344 |
|
pc < get<1>( Loop5th ); |
345 |
|
pc += get<2>( Loop5th ) ) |
346 |
|
{ |
347 |
|
auto pb = std::min( k - pc, KC ); |
348 |
|
auto is_the_last_pc_iteration = ( pc + KC >= k ); |
349 |
|
auto LooppkB = Comm5th.DistributeOver1DThreads( 0, jb, NR ); |
350 |
|
auto PackpkB = Comm5th.DistributeOver1DThreads( 0, jb, PACK_NR ); |
351 |
|
|
352 |
|
for ( int j = get<0>( LooppkB ), jp = get<0>( PackpkB ); |
353 |
|
j < get<1>( LooppkB ); |
354 |
|
j += get<2>( LooppkB ), jp += get<2>( PackpkB ) ) |
355 |
|
{ |
356 |
|
/** packB and typecast from TB to TPACKB */ |
357 |
|
B.Pack( |
358 |
|
k, pc, pb, |
359 |
|
n, jc + j, std::min( jb - j, NR ), |
360 |
|
&packB[ jp * pb ] ); |
361 |
|
} |
362 |
|
Comm5th.Barrier(); |
363 |
|
|
364 |
|
|
365 |
|
/** Comm4th is used inside the 6th loop (i.e. pc loop) */ |
366 |
|
for ( int ic = get<0>( Loop4th ); |
367 |
|
ic < get<1>( Loop4th ); |
368 |
|
ic += get<2>( Loop4th ) ) |
369 |
|
{ |
370 |
|
auto &ic_comm = *thread.ic_comm; |
371 |
|
auto ib = std::min( m - ic, MC ); |
372 |
|
auto LooppkA = Comm4th.DistributeOver1DThreads( 0, ib, MR ); |
373 |
|
auto PackpkA = Comm4th.DistributeOver1DThreads( 0, ib, PACK_MR ); |
374 |
|
|
375 |
|
for ( int i = get<0>( LooppkA ), ip = get<0>( PackpkA ); |
376 |
|
i < get<1>( LooppkA ); |
377 |
|
i += get<2>( LooppkA ), ip += get<2>( PackpkA ) ) |
378 |
|
{ |
379 |
|
/** packA and typecast from TA to TPACKA */ |
380 |
|
A.Pack( |
381 |
|
m, ic + i, std::min( ib - i, MR ), |
382 |
|
k, pc, pb, |
383 |
|
&packA[ ip * pb ] ); |
384 |
|
} |
385 |
|
Comm4th.Barrier(); |
386 |
|
|
387 |
|
if ( is_the_last_pc_iteration ) // fused_macro_kernel |
388 |
|
{ |
389 |
|
fused_macro_kernel<KC> |
390 |
|
( |
391 |
|
Comm4th, |
392 |
|
m, n, |
393 |
|
ic, jc, pc, |
394 |
|
ib, jb, pb, |
395 |
|
packA, |
396 |
|
packB, |
397 |
|
&C, |
398 |
|
V + ic * rs_v + jc * cs_v, rs_v, cs_v, |
399 |
|
batchId, |
400 |
|
microkernel |
401 |
|
); |
402 |
|
|
403 |
|
} |
404 |
|
else // semiring rank-k update |
405 |
|
{ |
406 |
|
rank_k_macro_kernel<KC> |
407 |
|
( |
408 |
|
Comm4th, |
409 |
|
ic, jc, pc, |
410 |
|
ib, jb, pb, |
411 |
|
packA, |
412 |
|
packB, |
413 |
|
V + ic * rs_v + jc * cs_v, rs_v, cs_v, |
414 |
|
semiringkernel |
415 |
|
); |
416 |
|
} |
417 |
|
Comm4th.Barrier(); |
418 |
|
} // end 4th loop |
419 |
|
Comm5th.Barrier(); |
420 |
|
} // end 5th loop |
421 |
|
Comm6th.Barrier(); |
422 |
|
} // end 6th loop |
423 |
|
CommGLB.Barrier(); |
424 |
|
|
425 |
|
/** Free packing buffer */ |
426 |
|
Comm4th.FreeSharedMemory( packA ); |
427 |
|
Comm5th.FreeSharedMemory( packB ); |
428 |
|
|
429 |
|
}; /** end gnbx_internal() */ |
430 |
|
|
431 |
|
|
432 |
|
|
433 |
|
|
434 |
|
|
435 |
|
/** |
436 |
|
* @breif This is the main routine of gkmx. All packing buffers are |
437 |
|
* managed here. The communicator and the parallel section |
438 |
|
* start here. |
439 |
|
* |
440 |
|
*/ |
441 |
|
template< |
442 |
|
int MC, int NC, int KC, |
443 |
|
typename TPACKA, typename TPACKB, typename TV, |
444 |
|
typename TA, typename TB, typename TC, |
445 |
|
typename SEMIRINGKERNEL, typename MICROKERNEL> |
446 |
|
void gnbx |
447 |
|
( |
448 |
|
int batchId, int m, int n, int k, |
449 |
|
TA& A, |
450 |
|
TB& B, |
451 |
|
TC& C, |
452 |
|
SEMIRINGKERNEL semiringkernel, |
453 |
|
MICROKERNEL microkernel |
454 |
|
) |
455 |
|
{ |
456 |
|
const static int MR = SEMIRINGKERNEL::mr; |
457 |
|
const static int NR = SEMIRINGKERNEL::nr; |
458 |
|
const static int PACK_MR = SEMIRINGKERNEL::pack_mr; |
459 |
|
const static int PACK_NR = SEMIRINGKERNEL::pack_nr; |
460 |
|
const static int ALIGN_SIZE = SEMIRINGKERNEL::align_size; |
461 |
|
const static int PACK_MC = ( MC / MR ) * PACK_MR; |
462 |
|
const static int PACK_NC = ( NC / NR ) * PACK_NR; |
463 |
|
const static bool USE_STRASSEN = false; |
464 |
|
|
465 |
|
/** Early return if possible */ |
466 |
|
if ( m == 0 || n == 0 || k == 0 ) return; |
467 |
|
|
468 |
|
|
469 |
|
TV *V = NULL; |
470 |
|
int rs_v = 0; |
471 |
|
int cs_v = 0; |
472 |
|
|
473 |
|
|
474 |
|
if ( k > KC && !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value ) |
475 |
|
{ |
476 |
|
//printf( "here m %d n %d\n", m, n ); |
477 |
|
V = hmlp_malloc<ALIGN_SIZE, TV>( m * n ); |
478 |
|
rs_v = 1; |
479 |
|
cs_v = m; |
480 |
|
} |
481 |
|
else |
482 |
|
{ |
483 |
|
/** Directly use C for intermediate semiring rank-k update */ |
484 |
|
V = reinterpret_cast<TV*>( C.X ); |
485 |
|
rs_v = C.rs; |
486 |
|
cs_v = C.cs; |
487 |
|
} |
488 |
|
|
489 |
|
|
490 |
|
int k_stra = 0; |
491 |
|
if ( USE_STRASSEN ) |
492 |
|
{ |
493 |
|
assert( typeid(TPACKA) == typeid(TPACKB) ); |
494 |
|
assert( typeid(TC) == typeid(TV) ); |
495 |
|
k_stra = k - k % KC; |
496 |
|
|
497 |
|
if ( k_stra == k ) k_stra -= KC; |
498 |
|
} |
499 |
|
|
500 |
|
int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1; |
501 |
|
if ( omp_get_num_threads() == 1 && omp_get_max_threads() > 1 ) |
502 |
|
{ |
503 |
|
/** Check the environment variable. */ |
504 |
|
jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" ); |
505 |
|
ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" ); |
506 |
|
jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" ); |
507 |
|
} |
508 |
|
|
509 |
|
/** allocate tree communicator */ |
510 |
|
thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt ); |
511 |
|
|
512 |
|
#pragma omp parallel num_threads( my_comm.GetNumThreads() ) |
513 |
|
{ |
514 |
|
Worker thread( &my_comm ); |
515 |
|
|
516 |
|
/** TODO: */ |
517 |
|
thread.InitWithCommunicator( &my_comm, omp_get_thread_num(), 0 ); |
518 |
|
|
519 |
|
//if ( USE_STRASSEN ) |
520 |
|
//{ |
521 |
|
// strassen::strassen_internal |
522 |
|
// <MC, NC, KC, MR, NR, |
523 |
|
// PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
524 |
|
// USE_STRASSEN, |
525 |
|
// SEMIRINGKERNEL, SEMIRINGKERNEL, |
526 |
|
// TA, TPACKA, TB, TPACKB, TC, TV> |
527 |
|
// ( |
528 |
|
// thread, |
529 |
|
// m, n, k_stra, |
530 |
|
// A, packakernel, |
531 |
|
// B, packbkernel, |
532 |
|
// V, ldv, |
533 |
|
// semiringkernel, semiringkernel, |
534 |
|
// nc, pack_nc, |
535 |
|
// packA_buff, |
536 |
|
// packB_buff |
537 |
|
// ); |
538 |
|
//} |
539 |
|
|
540 |
|
gnbx_internal<MC, NC, KC, TPACKA, TPACKB> |
541 |
|
( |
542 |
|
thread, |
543 |
|
batchId, m, n, k, k_stra, |
544 |
|
A, |
545 |
|
B, |
546 |
|
C, |
547 |
|
V, rs_v, cs_v, |
548 |
|
semiringkernel, microkernel |
549 |
|
); |
550 |
|
} // end omp parallel |
551 |
|
|
552 |
|
if ( k > KC && !is_same<TC, MatrixLike<PACK_MR, TV, TV>>::value ) |
553 |
|
{ |
554 |
|
hmlp_free( V ); |
555 |
|
} |
556 |
|
}; // end gkmx |
557 |
|
|
558 |
|
|
559 |
|
|
560 |
|
|
561 |
|
|
562 |
|
/** |
563 |
|
* @beief |
564 |
|
*/ |
565 |
|
template< |
566 |
|
int MR, int NR, int MC, int NC, int KC, |
567 |
|
typename TPACKA, typename TPACKB, typename TPACKC, typename TV, |
568 |
|
typename TA, typename TB, typename TC, |
569 |
|
typename OPKERNEL, typename OP1, typename OP2> |
570 |
|
void gnbx |
571 |
|
( |
572 |
|
int batchId, int m, int n, int k, |
573 |
|
TA& A, |
574 |
|
TB& B, |
575 |
|
TC& C, |
576 |
|
OPKERNEL opkernel, OP1 op1, OP2 op2, TV initV |
577 |
|
) |
578 |
|
{ |
579 |
|
semiring_mrxnr<MR, NR, OP1, OP2, TPACKA, TPACKB, TV, TV> semiringkernel; |
580 |
|
gnbx_mrxnr<MR, NR, OPKERNEL, OP1, OP2, TPACKA, TPACKB, TC, TPACKC, TV> gkrmkernel; |
581 |
|
|
582 |
|
semiringkernel.op1 = op1; |
583 |
|
semiringkernel.op2 = op2; |
584 |
|
semiringkernel.initV = initV; |
585 |
|
|
586 |
|
gkrmkernel.op1 = op1; |
587 |
|
gkrmkernel.op2 = op2; |
588 |
|
gkrmkernel.opkernel = opkernel; |
589 |
|
gkrmkernel.initV = initV; |
590 |
|
|
591 |
|
gnbx<MC, NC, KC, TPACKA, TPACKB, TV> |
592 |
|
( batchId, m, n, k, A, B, C, semiringkernel, gkrmkernel ); |
593 |
|
|
594 |
|
}; /** end gnbx() */ |
595 |
|
|
596 |
|
}; /** end namespace gnbx */ |
597 |
|
}; /** end namespace hmlp */ |
598 |
|
|
599 |
|
#endif /** define GNBX_HPP */ |