|
|
|
@ -144,12 +144,15 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
|
|
|
|
|
int yoff = start + j;
|
|
|
|
|
|
|
|
|
|
// transpose
|
|
|
|
|
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
|
|
|
|
|
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0;
|
|
|
|
|
sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
|
|
|
|
|
x[yoff * width + xoff] : 0.0;
|
|
|
|
|
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ?
|
|
|
|
|
dy[yoff * width + xoff] : 0.0;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
if (tidy < (context - 1)) {
|
|
|
|
|
yoff = yoff - context + 1;
|
|
|
|
|
sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0;
|
|
|
|
|
sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ?
|
|
|
|
|
dy[yoff * width + xoff] : 0.0;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
@ -199,11 +202,13 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
|
|
|
|
|
int yoff = start + j;
|
|
|
|
|
|
|
|
|
|
// transpose
|
|
|
|
|
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
|
|
|
|
|
sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
|
|
|
|
|
x[yoff * width + xoff] : 0.0;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
for (int t = 0; t < context; t++) {
|
|
|
|
|
sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
|
|
|
|
|
sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start &&
|
|
|
|
|
yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
|
|
|
|
|