1 |
|
/** |
2 |
|
* HMLP (High-Performance Machine Learning Primitives) |
3 |
|
* |
4 |
|
* Copyright (C) 2014-2017, The University of Texas at Austin |
5 |
|
* |
6 |
|
* This program is free software: you can redistribute it and/or modify |
7 |
|
* it under the terms of the GNU General Public License as published by |
8 |
|
* the Free Software Foundation, either version 3 of the License, or |
9 |
|
* (at your option) any later version. |
10 |
|
* |
11 |
|
* This program is distributed in the hope that it will be useful, |
12 |
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of |
13 |
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
14 |
|
* GNU General Public License for more details. |
15 |
|
* |
16 |
|
* You should have received a copy of the GNU General Public License |
17 |
|
* along with this program. If not, see the LICENSE file. |
18 |
|
* |
19 |
|
**/ |
20 |
|
|
21 |
|
|
22 |
|
|
23 |
|
|
24 |
|
#ifndef GSKS_HXX |
25 |
|
#define GSKS_HXX |
26 |
|
|
27 |
|
#include <math.h> |
28 |
|
#include <vector> |
29 |
|
|
30 |
|
#include <hmlp.h> |
31 |
|
#include <hmlp_internal.hpp> |
32 |
|
#include <hmlp_base.hpp> |
33 |
|
|
34 |
|
#include <KernelMatrix.hpp> |
35 |
|
|
36 |
|
|
37 |
|
namespace hmlp |
38 |
|
{ |
39 |
|
namespace gsks |
40 |
|
{ |
41 |
|
|
42 |
|
#define min( i, j ) ( (i)<(j) ? (i): (j) ) |
43 |
|
#define KS_RHS 1 |
44 |
|
|
45 |
|
/**
 *  @brief Macro kernel for the semiring rank-k update C += A * B over one
 *         cache-resident tile.
 *
 *         The m-by-n tile is swept in NR-wide column slabs (3rd loop) and
 *         MR-tall row slabs (2nd loop); each MR-by-NR micro tile is handed to
 *         @p semiringkernel. The 3rd loop is partitioned round-robin over the
 *         jr threads of the ic communicator via GetRange.
 *
 *  @param thread          Worker descriptor; provides jr_id and the ic
 *                         communicator used to split the 3rd loop.
 *  @param ic, jc, pc      Block offsets of this tile inside the full problem
 *                         (passed through for context; only pc is consumed
 *                         here, via aux.pc).
 *  @param m, n, k         Tile dimensions (m, n) and rank-k depth k (<= KC).
 *  @param packA           Packed A panel, MR-interleaved: micro panel i starts
 *                         at packA[ ip * k ].
 *  @param packB           Packed B panel, NR-interleaved: micro panel j starts
 *                         at packB[ jp * k ].
 *  @param packC           Packed output tile; micro tile (i, j) starts at
 *                         packC[ j * ldc + i * NR ] with row stride 1 and
 *                         column stride MR (the "1, MR" arguments below).
 *  @param ldc             Leading dimension of the packed C tile.
 *  @param semiringkernel  MR-by-NR micro kernel performing the rank-k update.
 */ 
template<
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
  typename SEMIRINGKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void rank_k_macro_kernel
(
  Worker &thread,
  int ic, int jc, int pc,
  int m, int n, int k,
  TA *packA,
  TB *packB,
  TV *packC, int ldc,
  SEMIRINGKERNEL semiringkernel
)
{
  thread_communicator &ic_comm = *thread.ic_comm;

  /** Partition the n dimension over the jr threads sharing this ic tile;
   *  loop* ranges step in logical units (NR/MR), pack* ranges step in the
   *  padded packing units (PACK_NR/PACK_MR) so both indices stay in sync. */
  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto loop2nd = GetRange( 0, m, MR );
  auto pack2nd = GetRange( 0, m, PACK_MR );

  for ( int j  = loop3rd.beg(), jp  = pack3rd.beg(); 
            j  < loop3rd.end();
            j += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
  {
    /** Per-slab auxiliary info consumed by the micro kernel. */
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;   // NOTE(review): presumably the next B micro panel
                            // for software prefetch — confirm in micro kernel.
    aux.do_packC = 1;
    aux.jb       = min( n - j, NR );   // edge handling: partial column slab

    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg(); 
              i  < loop2nd.end(); 
              i += loop2nd.inc(), ip += pack2nd.inc() )   // beg 2nd loop
    {
      aux.ib = min( m - i, MR );       // edge handling: partial row slab
      if ( i + MR >= m ) 
      {
        /** Last row slab of this column: advance the prefetch pointer to the
         *  B panel this jr thread will touch next. */
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }
      semiringkernel
      (
        k,
        &packA[ ip * k ],
        &packB[ jp * k ],
        //&packC[ j * ldc + i * NR ], ldc,
        &packC[ j * ldc + i * NR ], 1, MR,   // packed micro-tile strides
        &aux
      );
    }                                                     // end 2nd loop
  }                                                       // end 3rd loop
}                                                         // end rank_k_macro_kernel
101 |
|
|
102 |
|
/**
 *  @brief Macro kernel for the final (fused) pass: on the last KC panel the
 *         micro kernel not only completes the rank-k update of packC but also
 *         applies the kernel function and accumulates u += K(A, B) * w.
 *
 *         Loop structure is identical to rank_k_macro_kernel: NR column slabs
 *         (3rd loop, split over jr threads) by MR row slabs (2nd loop).
 *
 *  @param kernel          Kernel descriptor (type, bandwidth, ...) forwarded
 *                         to the micro kernel.
 *  @param thread          Worker descriptor; provides jr_id and the ic
 *                         communicator used to split the 3rd loop.
 *  @param ic, jc, pc      Block offsets of this tile inside the full problem.
 *  @param m, n, k         Tile dimensions and depth of this KC panel.
 *  @param packu           Packed output potentials (KS_RHS right-hand sides),
 *                         MR-interleaved like packA.
 *  @param packA, packA2   Packed A panel and its packed squared row norms.
 *  @param packAh          Packed per-point bandwidths for A (variable
 *                         bandwidth kernels); exposed to the micro kernel via
 *                         aux.hi.
 *  @param packB, packB2   Packed B panel and its packed squared row norms.
 *  @param packBh          Packed per-point bandwidths for B; exposed via
 *                         aux.hj.
 *  @param packw           Packed weights (KS_RHS right-hand sides),
 *                         NR-interleaved like packB.
 *  @param packC           Packed intermediate tile holding the semiring
 *                         rank-k result accumulated by earlier KC panels.
 *  @param ldc             Leading dimension of the packed C tile.
 *  @param microkernel     Fused MR-by-NR micro kernel.
 */ 
template<
  int KC, int MR, int NR, int PACK_MR, int PACK_NR,
  typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void fused_macro_kernel
(
  kernel_s<TV, TC> *kernel,
  Worker &thread,
  int ic, int jc, int pc,
  int m, int n, int k,
  TC *packu,
  TA *packA, TA *packA2, TV *packAh,
  TB *packB, TB *packB2, TV *packBh,
  TC *packw,
  TV *packC, int ldc,
  MICROKERNEL microkernel
)
{
  thread_communicator &ic_comm = *thread.ic_comm;

  /** Same jr-thread partitioning as rank_k_macro_kernel: logical (NR/MR)
   *  and packed (PACK_NR/PACK_MR) ranges advance in lock-step. */
  auto loop3rd = GetRange( 0, n, NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto pack3rd = GetRange( 0, n, PACK_NR, thread.jr_id, ic_comm.GetNumThreads() );
  auto loop2nd = GetRange( 0, m, MR );
  auto pack2nd = GetRange( 0, m, PACK_MR );

  for ( int j  = loop3rd.beg(), jp  = pack3rd.beg(); 
            j  < loop3rd.end();
            j += loop3rd.inc(), jp += pack3rd.inc() )     // beg 3rd loop
  {
    struct aux_s<TA, TB, TC, TV> aux;
    aux.pc       = pc;
    aux.b_next   = packB;   // NOTE(review): presumably a prefetch pointer for
                            // the next B micro panel — confirm in micro kernel.
    aux.do_packC = 1;
    aux.jb       = min( n - j, NR );   // edge handling: partial column slab

    for ( int i  = loop2nd.beg(), ip  = pack2nd.beg(); 
              i  < loop2nd.end(); 
              i += loop2nd.inc(), ip += pack2nd.inc() )   // beg 2nd loop
    {
      aux.ib = min( m - i, MR );       // edge handling: partial row slab
      if ( i + MR >= m ) 
      {
        /** Last row slab: advance prefetch pointer to this thread's next B. */
        aux.b_next += ic_comm.GetNumThreads() * PACK_NR * k;
      }
      /** Per-tile bandwidth vectors for variable-bandwidth kernels. */
      aux.hi = packAh + ip;
      aux.hj = packBh + jp;
      microkernel
      (
        kernel,
        k,
        KS_RHS,
        packu  + ip * KS_RHS,
        packA  + ip * k,
        packA2 + ip,
        packB  + jp * k,
        packB2 + jp,
        packw  + jp * KS_RHS,
        packC  + j * ldc + i * NR, MR,   // packed
        &aux
      );
    }                                                     // end 2nd loop
  }                                                       // end 3rd loop
}                                                         // end fused_macro_kernel
169 |
|
|
170 |
|
|
171 |
|
/**
 *  @brief Per-thread body of the blocked GSKS driver (Goto/BLIS-style
 *         6/5/4-loop blocking): computes u[umap] += K( A[amap], B[bmap] ) *
 *         w[wmap] where the kernel matrix K is formed panel-by-panel from the
 *         semiring rank-k update of A^T B.
 *
 *         Loop nest: 6th loop blocks n by nc (split over jc threads), 5th
 *         loop blocks k by KC, 4th loop blocks m by MC (split over ic
 *         threads). Each KC panel of B (and, on the last panel, w / B2 / Bh)
 *         is packed cooperatively, then each ic thread packs its A panel and
 *         calls rank_k_macro_kernel for intermediate panels or
 *         fused_macro_kernel for the last panel, which applies the kernel
 *         function and accumulates into packu. Finally packu is scattered
 *         back into u under an atomic update.
 *
 *         This function runs inside an OpenMP parallel region (see gsks());
 *         all pack* arguments point at shared buffers, and each thread first
 *         offsets them by its communicator ids so threads write disjoint
 *         slices.
 *
 *  @param thread               This thread's Worker (communicator ids).
 *  @param kernel               Kernel descriptor (type, bandwidths, ...).
 *  @param m, n, k              Problem sizes: m targets, n sources, k dims.
 *  @param u, umap              Output potentials and their index map.
 *  @param A, A2, amap          Target coordinates (k-by-?, column-major with
 *                              leading dimension k), squared norms, index map.
 *  @param B, B2, bmap          Source coordinates, squared norms, index map.
 *  @param w, wmap              Weights and their index map.
 *  @param semiringkernel       Micro kernel for intermediate KC panels.
 *  @param microkernel          Fused micro kernel for the last KC panel.
 *  @param nc, pack_nc          Runtime nc blocking and its packed width.
 *  @param packu ... packC      Shared packing buffers allocated by gsks().
 *  @param ldpackc, padn        Leading dimension / padded width of packC.
 */ 
template<
  int MC, int NC, int KC, int MR, int NR, 
  int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE,
  bool USE_L2NORM, bool USE_VAR_BANDWIDTH, bool USE_STRASSEN,
  typename SEMIRINGKERNEL, typename MICROKERNEL,
  typename TA, typename TB, typename TC, typename TV>
void gsks_internal
(
  Worker &thread,
  kernel_s<TV, TC> *kernel,
  int m, int n, int k,
  TC *u, int *umap,
  TA *A, TA *A2, int *amap,
  TB *B, TB *B2, int *bmap,
  TC *w, int *wmap,
  SEMIRINGKERNEL semiringkernel,
  MICROKERNEL microkernel,
  int nc, int pack_nc,
  TC *packu,
  TA *packA, TA *packA2, TA *packAh,
  TB *packB, TB *packB2, TB *packBh,
  TC *packw,
  TV *packC, int ldpackc, int padn
)
{
  /** Offset the shared packing buffers to this thread's private slice:
   *  A-side buffers are per (jc, ic) pair (packu also per jr), B-side
   *  buffers and packC are per jc group. */
  packu  += ( thread.jc_id * thread.ic_nt * thread.jr_nt ) * PACK_MC * KS_RHS
          + ( thread.ic_id * thread.jr_nt + thread.jr_id ) * PACK_MC * KS_RHS;
  packA  += ( thread.jc_id * thread.ic_nt ) * PACK_MC * KC
          + ( thread.ic_id                ) * PACK_MC * KC;
  packA2 += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
  packAh += ( thread.jc_id * thread.ic_nt + thread.ic_id ) * PACK_MC;
  packB  += ( thread.jc_id ) * pack_nc * KC;
  packB2 += ( thread.jc_id ) * pack_nc;
  packBh += ( thread.jc_id ) * pack_nc;
  packw  += ( thread.jc_id ) * pack_nc;
  packC  += ( thread.jc_id ) * ldpackc * padn;

  /** n split over jc threads; k swept sequentially; m split over ic threads. */
  auto loop6th = GetRange( 0, n, nc, thread.jc_id, thread.jc_nt );
  auto loop5th = GetRange( 0, k, KC );
  auto loop4th = GetRange( 0, m, MC, thread.ic_id, thread.ic_nt );

  for ( int jc  = loop6th.beg();
            jc  < loop6th.end();
            jc += loop6th.inc() )                         // beg 6th loop
  {
    auto &jc_comm = *thread.jc_comm;
    auto jb = min( n - jc, nc );

    for ( int pc  = loop5th.beg();
              pc  < loop5th.end();
              pc += loop5th.inc() )
    {
      auto &pc_comm = *thread.pc_comm;
      auto pb = min( k - pc, KC );
      /** The last KC panel triggers the fused (kernel-evaluating) path. */
      auto is_the_last_pc_iteration = ( pc + KC >= k );

      /** Cooperative packing of the B panel: columns split over all threads
       *  of the pc communicator. */
      auto looppkB = GetRange( 0, jb, NR,      thread.ic_jr, pc_comm.GetNumThreads() );
      auto packpkB = GetRange( 0, jb, PACK_NR, thread.ic_jr, pc_comm.GetNumThreads() );

      for ( int j  = looppkB.beg(), jp  = packpkB.beg();
                j  < looppkB.end();
                j += looppkB.inc(), jp += packpkB.inc() )
      {
        /** Gather B( pc:pc+pb, bmap[jc+j ...] ) into NR-interleaved packB. */
        pack2D<true, PACK_NR>                             // packB
        (
          min( jb - j, NR ), pb,
          &B[ pc ], k, &bmap[ jc + j ], &packB[ jp * pb ]
        );

        if ( is_the_last_pc_iteration )
        {
          /** Weights, squared norms and bandwidths are only needed by the
           *  fused pass, so they are packed once, on the last panel. */
          pack2D<true, PACK_NR, true>                     // packw
          (
            min( jb - j, NR ), 1,
            &w[ 0 ], 1, &wmap[ jc + j ], &packw[ jp * 1 ]
          );

          if ( USE_L2NORM )
          {
            pack2D<true, PACK_NR>                         // packB2
            (
              min( jb - j, NR ), 1,
              &B2[ 0 ], 1, &bmap[ jc + j ], &packB2[ jp * 1 ]
            );
          }

          if ( USE_VAR_BANDWIDTH )
          {
            pack2D<true, PACK_NR>                         // packBh
            (
              min( jb - j, NR ), 1,
              kernel->hj, 1, &bmap[ jc + j ], &packBh[ jp * 1 ]
            );
          }
        }
      }
      /** All threads must finish packing B before anyone consumes it. */
      pc_comm.Barrier();

      for ( int ic  = loop4th.beg();
                ic  < loop4th.end();
                ic += loop4th.inc() )                     // beg 4th loop
      {
        auto &ic_comm = *thread.ic_comm;
        auto ib = min( m - ic, MC );

        /** Cooperative packing of this MC block of A: rows split over the
         *  jr threads of the ic communicator. */
        auto looppkA = GetRange( 0, ib, MR,      thread.jr_id, thread.jr_nt );
        auto packpkA = GetRange( 0, ib, PACK_MR, thread.jr_id, thread.jr_nt );

        for ( int i  = looppkA.beg(), ip  = packpkA.beg();
                  i  < looppkA.end();
                  i += looppkA.inc(), ip += packpkA.inc() )
        {
          pack2D<true, PACK_MR>                           // packA
          (
            min( ib - i, MR ), pb,
            &A[ pc ], k, &amap[ ic + i ], &packA[ ip * pb ]
          );

          if ( is_the_last_pc_iteration )
          {
            if ( USE_L2NORM )
            {
              pack2D<true, PACK_MR>                       // packA2
              (
                min( ib - i, MR ), 1,
                &A2[ 0 ], 1, &amap[ ic + i ], &packA2[ ip * 1 ]
              );
            }

            if ( USE_VAR_BANDWIDTH )                      // variable bandwidths
            {
              pack2D<true, PACK_MR>                       // packAh
              (
                min( ib - i, MR ), 1,
                kernel->hi, 1, &amap[ ic + i ], &packAh[ ip * 1 ]
              );
            }
          }
        }

        if ( is_the_last_pc_iteration )                   // Initialize packu to zeros.
        {
          /** packu accumulates K * w inside the fused macro kernel; it is
           *  private to this (jc, ic, jr) thread, so no synchronization. */
          for ( auto i = 0, ip = 0; i < ib; i += MR, ip += PACK_MR )
          {
            for ( auto ir = 0; ir < min( ib - i, MR ); ir ++ )
            {
              packu[ ip + ir ] = 0.0;
            }
          }
        }
        /** A panel (and packu) must be ready before the macro kernels run. */
        ic_comm.Barrier();

        if ( is_the_last_pc_iteration )                   // fused_macro_kernel
        {
          fused_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, MICROKERNEL, TA, TB, TC, TV>
          (
            kernel,
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packu,
            packA, packA2, packAh,
            packB, packB2, packBh,
            packw,
            packC + ic * padn,                            // packed
            ( ( ib - 1 ) / MR + 1 ) * MR,                 // packed ldc
            microkernel
          );
        }
        else                                              // semiring rank-k update
        {
          rank_k_macro_kernel
          <KC, MR, NR, PACK_MR, PACK_NR, SEMIRINGKERNEL, TA, TB, TC, TV>
          (
            thread,
            ic, jc, pc,
            ib, jb, pb,
            packA,
            packB,
            packC + ic * padn,                            // packed
            ( ( ib - 1 ) / MR + 1 ) * MR,                 // packed ldc
            semiringkernel
          );
        }
        ic_comm.Barrier();                                // sync all jr_id!!

        if ( is_the_last_pc_iteration )
        {
          /** Scatter-accumulate packu back into u; different jc groups may
           *  hit the same u entries, hence the atomic update. */
          for ( auto i = 0, ip = 0; i < ib; i += MR, ip += PACK_MR )
          {
            for ( auto ir = 0; ir < min( ib - i, MR ); ir ++ )
            {
              TC *uptr = &( u[ umap[ ic + i + ir ] ] );
              #pragma omp atomic update                   // concurrent write
              *uptr += packu[ ip + ir ];
            }
          }
          ic_comm.Barrier();                              // sync all jr_id!!
        }
      }                                                   // end 4th loop
      pc_comm.Barrier();
    }                                                     // end 5th loop
  }                                                       // end 6th loop
}                                                         // end gsks_internal
382 |
|
|
383 |
|
|
384 |
|
|
385 |
|
|
386 |
|
|
387 |
|
/** |
388 |
|
* |
389 |
|
*/ |
390 |
|
template< |
391 |
|
int MC, int NC, int KC, int MR, int NR, |
392 |
|
int PACK_MC, int PACK_NC, int PACK_MR, int PACK_NR, int ALIGN_SIZE, |
393 |
|
bool USE_L2NORM, bool USE_VAR_BANDWIDTH, bool USE_STRASSEN, |
394 |
|
typename SEMIRINGKERNEL, typename MICROKERNEL, |
395 |
|
typename TA, typename TB, typename TC, typename TV> |
396 |
|
void gsks |
397 |
|
( |
398 |
|
kernel_s<TV, TC> *kernel, |
399 |
|
int m, int n, int k, |
400 |
|
TC *u, int *umap, |
401 |
|
TA *A, TA *A2, int *amap, |
402 |
|
TB *B, TB *B2, int *bmap, |
403 |
|
TC *w, int *wmap, |
404 |
|
SEMIRINGKERNEL semiringkernel, |
405 |
|
MICROKERNEL microkernel |
406 |
|
) |
407 |
|
{ |
408 |
|
int jc_nt = 1, pc_nt = 1, ic_nt = 1, jr_nt = 1; |
409 |
|
int ldpackc = 0, padn = 0, nc = NC, pack_nc = PACK_NC; |
410 |
|
char *str; |
411 |
|
|
412 |
|
TC *packu_buff = NULL; |
413 |
|
TA *packA_buff = NULL, *packA2_buff = NULL, *packAh_buff = NULL; |
414 |
|
TB *packB_buff = NULL, *packB2_buff = NULL, *packBh_buff = NULL; |
415 |
|
TC *packw_buff = NULL; |
416 |
|
TV *packC_buff = NULL; |
417 |
|
|
418 |
|
// Early return if possible |
419 |
|
if ( m == 0 || n == 0 || k == 0 ) return; |
420 |
|
|
421 |
|
// Check the environment variable. |
422 |
|
jc_nt = hmlp_read_nway_from_env( "KS_JC_NT" ); |
423 |
|
ic_nt = hmlp_read_nway_from_env( "KS_IC_NT" ); |
424 |
|
jr_nt = hmlp_read_nway_from_env( "KS_JR_NT" ); |
425 |
|
|
426 |
|
if ( jc_nt > 1 ) |
427 |
|
{ |
428 |
|
nc = ( ( n - 1 ) / ( NR * jc_nt ) + 1 ) * NR; |
429 |
|
pack_nc = ( nc / NR ) * PACK_NR; |
430 |
|
} |
431 |
|
|
432 |
|
// allocate packing memory |
433 |
|
{ |
434 |
|
packA_buff = hmlp_malloc<ALIGN_SIZE, TA>( KC, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) ); |
435 |
|
packB_buff = hmlp_malloc<ALIGN_SIZE, TB>( KC, ( pack_nc + 1 ) * jc_nt, sizeof(TB) ); |
436 |
|
packu_buff = hmlp_malloc<ALIGN_SIZE, TC>( 1, ( PACK_MC + 1 ) * jc_nt * ic_nt * jr_nt, sizeof(TC) ); |
437 |
|
packw_buff = hmlp_malloc<ALIGN_SIZE, TC>( 1, ( pack_nc + 1 ) * jc_nt, sizeof(TC) ); |
438 |
|
} |
439 |
|
|
440 |
|
// allocate extra packing buffer |
441 |
|
if ( USE_L2NORM ) |
442 |
|
{ |
443 |
|
packA2_buff = hmlp_malloc<ALIGN_SIZE, TA>( 1, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) ); |
444 |
|
packB2_buff = hmlp_malloc<ALIGN_SIZE, TB>( 1, ( pack_nc + 1 ) * jc_nt, sizeof(TB) ); |
445 |
|
} |
446 |
|
|
447 |
|
if ( USE_VAR_BANDWIDTH ) |
448 |
|
{ |
449 |
|
packAh_buff = hmlp_malloc<ALIGN_SIZE, TA>( 1, ( PACK_MC + 1 ) * jc_nt * ic_nt, sizeof(TA) ); |
450 |
|
packBh_buff = hmlp_malloc<ALIGN_SIZE, TB>( 1, ( pack_nc + 1 ) * jc_nt, sizeof(TB) ); |
451 |
|
} |
452 |
|
|
453 |
|
// Temporary bufferm <TV> to store the semi-ring rank-k update |
454 |
|
if ( k > KC ) |
455 |
|
{ |
456 |
|
ldpackc = ( ( m - 1 ) / PACK_MR + 1 ) * PACK_MR; |
457 |
|
padn = pack_nc; |
458 |
|
if ( n < nc ) padn = ( ( n - 1 ) / PACK_NR + 1 ) * PACK_NR ; |
459 |
|
packC_buff = hmlp_malloc<ALIGN_SIZE, TV>( ldpackc, padn * jc_nt, sizeof(TV) ); |
460 |
|
} |
461 |
|
|
462 |
|
// allocate tree communicator |
463 |
|
thread_communicator my_comm( jc_nt, pc_nt, ic_nt, jr_nt ); |
464 |
|
|
465 |
|
|
466 |
|
#pragma omp parallel num_threads( my_comm.GetNumThreads() ) |
467 |
|
{ |
468 |
|
Worker thread( &my_comm ); |
469 |
|
|
470 |
|
if ( USE_STRASSEN ) |
471 |
|
{ |
472 |
|
printf( "gsks: strassen algorithms haven't been implemented." ); |
473 |
|
exit( 1 ); |
474 |
|
} |
475 |
|
|
476 |
|
gsks_internal |
477 |
|
<MC, NC, KC, MR, NR, PACK_MC, PACK_NC, PACK_MR, PACK_NR, ALIGN_SIZE, |
478 |
|
USE_L2NORM, USE_VAR_BANDWIDTH, USE_STRASSEN, |
479 |
|
SEMIRINGKERNEL, MICROKERNEL, |
480 |
|
TA, TB, TC, TB> |
481 |
|
( |
482 |
|
thread, |
483 |
|
kernel, |
484 |
|
m, n, k, |
485 |
|
u, umap, |
486 |
|
A, A2, amap, |
487 |
|
B, B2, bmap, |
488 |
|
w, wmap, |
489 |
|
semiringkernel, microkernel, |
490 |
|
nc, pack_nc, |
491 |
|
packu_buff, |
492 |
|
packA_buff, packA2_buff, packAh_buff, |
493 |
|
packB_buff, packB2_buff, packBh_buff, |
494 |
|
packw_buff, |
495 |
|
packC_buff, ldpackc, padn |
496 |
|
); |
497 |
|
|
498 |
|
} /** end omp region */ |
499 |
|
|
500 |
|
hmlp_free( packA_buff ); |
501 |
|
hmlp_free( packB_buff ); |
502 |
|
hmlp_free( packu_buff ); |
503 |
|
hmlp_free( packw_buff ); |
504 |
|
if ( USE_L2NORM ) |
505 |
|
{ |
506 |
|
hmlp_free( packA2_buff ); |
507 |
|
hmlp_free( packB2_buff ); |
508 |
|
} |
509 |
|
} /** end gsks() */ |
510 |
|
|
511 |
|
|
512 |
|
/** |
513 |
|
* |
514 |
|
*/ |
515 |
|
template<typename T> |
516 |
|
void gsks_ref |
517 |
|
( |
518 |
|
//ks_t *kernel, |
519 |
|
kernel_s<T, T> *kernel, |
520 |
|
int m, int n, int k, |
521 |
|
T *u, int *umap, |
522 |
|
T *A, T *A2, int *amap, |
523 |
|
T *B, T *B2, int *bmap, |
524 |
|
T *w, int *wmap |
525 |
|
) |
526 |
|
{ |
527 |
|
int nrhs = KS_RHS; |
528 |
|
T rank_k_scale, fone = 1.0, fzero = 0.0; |
529 |
|
std::vector<T> packA, packB, C, packu, packw; |
530 |
|
|
531 |
|
// Early return if possible |
532 |
|
if ( m == 0 || n == 0 || k == 0 ) return; |
533 |
|
|
534 |
|
packA.resize( k * m ); |
535 |
|
packB.resize( k * n ); |
536 |
|
C.resize( m * n ); |
537 |
|
packu.resize( m ); |
538 |
|
packw.resize( n ); |
539 |
|
|
540 |
|
switch ( kernel->type ) |
541 |
|
{ |
542 |
|
case GAUSSIAN: |
543 |
|
rank_k_scale = -2.0; |
544 |
|
break; |
545 |
|
case GAUSSIAN_VAR_BANDWIDTH: |
546 |
|
rank_k_scale = -2.0; |
547 |
|
break; |
548 |
|
default: |
549 |
|
exit( 1 ); |
550 |
|
} |
551 |
|
|
552 |
|
/* |
553 |
|
* Collect packA and packu |
554 |
|
*/ |
555 |
|
#pragma omp parallel for |
556 |
|
for ( int i = 0; i < m; i ++ ) |
557 |
|
{ |
558 |
|
for ( int p = 0; p < k; p ++ ) |
559 |
|
{ |
560 |
|
packA[ i * k + p ] = A[ amap[ i ] * k + p ]; |
561 |
|
} |
562 |
|
for ( int p = 0; p < KS_RHS; p ++ ) |
563 |
|
{ |
564 |
|
packu[ p * m + i ] = u[ umap[ i ] * KS_RHS + p ]; |
565 |
|
} |
566 |
|
} |
567 |
|
|
568 |
|
/* |
569 |
|
* Collect packB and packw |
570 |
|
*/ |
571 |
|
#pragma omp parallel for |
572 |
|
for ( int j = 0; j < n; j ++ ) |
573 |
|
{ |
574 |
|
for ( int p = 0; p < k; p ++ ) |
575 |
|
{ |
576 |
|
packB[ j * k + p ] = B[ bmap[ j ] * k + p ]; |
577 |
|
} |
578 |
|
for ( int p = 0; p < KS_RHS; p ++ ) |
579 |
|
{ |
580 |
|
packw[ p * n + j ] = w[ wmap[ j ] * KS_RHS + p ]; |
581 |
|
} |
582 |
|
} |
583 |
|
|
584 |
|
/* |
585 |
|
* C = -2.0 * A^T * B (GEMM) |
586 |
|
*/ |
587 |
|
#ifdef USE_BLAS |
588 |
|
xgemm |
589 |
|
( |
590 |
|
"T", "N", |
591 |
|
m, n, k, |
592 |
|
rank_k_scale, packA.data(), k, |
593 |
|
packB.data(), k, |
594 |
|
fzero, C.data(), m |
595 |
|
); |
596 |
|
#else |
597 |
|
#pragma omp parallel for |
598 |
|
for ( int j = 0; j < n; j ++ ) |
599 |
|
{ |
600 |
|
for ( int i = 0; i < m; i ++ ) |
601 |
|
{ |
602 |
|
C[ j * m + i ] = 0.0; |
603 |
|
for ( int p = 0; p < k; p ++ ) |
604 |
|
{ |
605 |
|
C[ j * m + i ] += packA[ i * k + p ] * packB[ j * k + p ]; |
606 |
|
} |
607 |
|
} |
608 |
|
} |
609 |
|
#pragma omp parallel for |
610 |
|
for ( int j = 0; j < n; j ++ ) |
611 |
|
{ |
612 |
|
for ( int i = 0; i < m; i ++ ) |
613 |
|
{ |
614 |
|
C[ j * m + i ] *= rank_k_scale; |
615 |
|
} |
616 |
|
} |
617 |
|
#endif |
618 |
|
|
619 |
|
switch ( kernel->type ) |
620 |
|
{ |
621 |
|
case GAUSSIAN: |
622 |
|
{ |
623 |
|
#pragma omp parallel for |
624 |
|
for ( int j = 0; j < n; j ++ ) |
625 |
|
{ |
626 |
|
for ( int i = 0; i < m; i ++ ) |
627 |
|
{ |
628 |
|
C[ j * m + i ] += A2[ amap[ i ] ]; |
629 |
|
C[ j * m + i ] += B2[ bmap[ j ] ]; |
630 |
|
C[ j * m + i ] *= kernel->scal; |
631 |
|
} |
632 |
|
for ( int i = 0; i < m; i ++ ) |
633 |
|
{ |
634 |
|
C[ j * m + i ] = exp( C[ j * m + i ] ); |
635 |
|
} |
636 |
|
} |
637 |
|
break; |
638 |
|
} |
639 |
|
case GAUSSIAN_VAR_BANDWIDTH: |
640 |
|
{ |
641 |
|
#pragma omp parallel for |
642 |
|
for ( int j = 0; j < n; j ++ ) |
643 |
|
{ |
644 |
|
for ( int i = 0; i < m; i ++ ) |
645 |
|
{ |
646 |
|
C[ j * m + i ] += A2[ amap[ i ] ]; |
647 |
|
C[ j * m + i ] += B2[ bmap[ j ] ]; |
648 |
|
C[ j * m + i ] *= -0.5; |
649 |
|
C[ j * m + i ] *= kernel->hi[ i ]; |
650 |
|
C[ j * m + i ] *= kernel->hj[ j ]; |
651 |
|
} |
652 |
|
for ( int i = 0; i < m; i ++ ) |
653 |
|
{ |
654 |
|
C[ j * m + i ] = exp( C[ j * m + i ] ); |
655 |
|
} |
656 |
|
} |
657 |
|
break; |
658 |
|
} |
659 |
|
default: |
660 |
|
exit( 1 ); |
661 |
|
} |
662 |
|
|
663 |
|
/* |
664 |
|
* Kernel Summation |
665 |
|
*/ |
666 |
|
#ifdef USE_BLAS |
667 |
|
xgemm |
668 |
|
( |
669 |
|
"N", "N", |
670 |
|
m, nrhs, n, |
671 |
|
fone, C.data(), m, |
672 |
|
packw.data(), n, |
673 |
|
fone, packu.data(), m |
674 |
|
); |
675 |
|
#else |
676 |
|
#pragma omp parallel for |
677 |
|
for ( int i = 0; i < m; i ++ ) |
678 |
|
{ |
679 |
|
for ( int j = 0; j < nrhs; j ++ ) |
680 |
|
{ |
681 |
|
for ( int p = 0; p < n; p ++ ) |
682 |
|
{ |
683 |
|
packu[ j * m + i ] += C[ p * m + i ] * packw[ j * n + p ]; |
684 |
|
} |
685 |
|
} |
686 |
|
} |
687 |
|
#endif |
688 |
|
|
689 |
|
/* |
690 |
|
* Assemble packu back |
691 |
|
*/ |
692 |
|
#pragma omp parallel for |
693 |
|
for ( int i = 0; i < m; i ++ ) |
694 |
|
{ |
695 |
|
for ( int p = 0; p < KS_RHS; p ++ ) |
696 |
|
{ |
697 |
|
u[ umap[ i ] * KS_RHS + p ] = packu[ p * m + i ]; |
698 |
|
} |
699 |
|
} |
700 |
|
|
701 |
|
} // end void gsks_ref |
702 |
|
|
703 |
|
|
704 |
|
}; /** end namespace gsks */ |
705 |
|
}; /** end namespace hmlp */ |
706 |
|
|
707 |
|
#endif // define GSKS_HXX |