@@ -96,11 +96,6 @@ void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
   const size_t height = in.getHeight();
   const size_t width = in.getWidth();
-  LOG(INFO) << numSeq;
-  LOG(INFO) << contextLength;
-  LOG(INFO) << height;
-  LOG(INFO) << width;
   real* y = out.getData();
   const real* x = in.getData();
   const real* w = filter.getData();
@@ -108,7 +103,6 @@ void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
   dim3 dimBlock(32, 32);
   dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
-  LOG(INFO) << dimGrid.x;
   if (contextLength <= 32) {
     KeRowConv<32, 32><<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
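// Note on the backward kernels below (annotation, assuming the usual row /
// "lookahead" convolution semantics of this operator): per sequence and per
// feature column d, the forward pass computes
//   out[i][d] = sum_{t=0}^{context-1} w[t][d] * in[i + t][d],
// so the weight gradient the backward kernels accumulate is
//   dw[t][d] += dy[i][d] * x[i + t][d],  equivalently  dw[t][d] += x[r][d] * dy[r - t][d].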
@@ -131,12 +125,12 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
   const int blky = blockDim.y;
   const int gidx = blockIdx.x * blockDim.x;

-  __shared__ real sh_x[BLOCK_H][BLOCK_W];
-  __shared__ real sh_dy[BLOCK_H][BLOCK_W];
+  __shared__ real sh_x[BLOCK_W][BLOCK_H];
+  __shared__ real sh_dy[BLOCK_W][BLOCK_H + CONTEXT - 1];
   __shared__ real sh_dw[CONTEXT][BLOCK_W];

-  for (int t = tidy; t < context; t += blky) {
-    sh_dw[t][tidx] = 0.0;
+  if (tidy < context) {
+    sh_dw[tidy][tidx] = 0.0;
   }
   __syncthreads();
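// With the reshaped tiles above, sh_dy carries CONTEXT - 1 extra rows per
// feature column, so every dy row needed by a 32-row tile of x (including rows
// from the preceding tile) can be read from shared memory; the matching halo
// loads appear in the next hunk. The single guarded write replacing the strided
// init loop is enough to zero sh_dw because this kernel appears to be
// instantiated only with CONTEXT <= BLOCK_H (e.g. the KeRowConvBwWeight<32, 32, 16>
// launch in RowConvGrad further down).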
@@ -144,21 +138,31 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
     const int start = starts[i];
     const int end = starts[i + 1];
     const int steps = end - start;
-    for (int j = tidy; j < steps; j += BLOCK_H) {
+    const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H;
+    for (int j = tidy; j < size; j += BLOCK_H) {
       int xoff = gidx + tidx;
       int yoff = start + j;

       // transpose
-      sh_x[tidx][tidy] = xoff < width && yoff < end ? x[yoff * width + xoff] : 0.0;
-      sh_dy[tidx][tidy] = xoff < width && yoff < end ? dy[yoff * width + xoff] : 0.0;
+      sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
+      sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0;
+      __syncthreads();
+      if (tidy < (context - 1)) {
+        yoff = yoff - context + 1;
+        sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0;
+      }
       __syncthreads();

       for (int t = 0; t < context; t++) {
-        real val = tidx + t < blockDim.x ? sh_x[tidy][tidx + t] * sh_dy[tidy][tidx]: 0.0;
+        real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
         __syncthreads();
         // warp size and blockDim.x is 32.
-        for (int offset = 16; offset > 0; offset /= 2) {
-          val += __shfl_down(val, offset);
-        }
+        val += __shfl_down(val, 16);
+        val += __shfl_down(val, 8);
+        val += __shfl_down(val, 4);
+        val += __shfl_down(val, 2);
+        val += __shfl_down(val, 1);
         __syncthreads();
         if (tidx == 0) {
           sh_dw[t][tidy] += val;
         }
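// An illustrative scalar sketch of what the rewritten j/t loops accumulate for
// one sequence [start, end), under the dw[t][d] += x[r][d] * dy[r - t][d]
// convention noted above (sketch only, not part of the patch):
//
//   for (int t = 0; t < context; ++t)
//     for (int r = start; r < end; ++r)
//       if (r - t >= start)
//         for (int d = 0; d < width; ++d)
//           dw[t * width + d] += x[r * width + d] * dy[(r - t) * width + d];
//
// The old guard `tidx + t < blockDim.x` zeroed every product whose x row and
// dy row fall in different 32-row tiles; keeping the context - 1 preceding dy
// rows in sh_dy restores those cross-tile terms. The unrolled __shfl_down
// calls sum val across the 32 lanes of the warp into lane 0 (blockDim.x is
// fixed at 32); __shfl_down is the pre-CUDA-9 intrinsic, on CUDA 9+ the
// equivalent would be __shfl_down_sync with a full lane mask.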
@@ -167,7 +171,7 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
     }
   }

-  for (int t = tidy; t < context && (gidx + tidx) < width; t += blky) {
+  for (int t = tidy; (t < context) && ((gidx + tidx) < width); t += blky) {
     dw[t * width + gidx + tidx] += sh_dw[t][tidx];
   }
 }
@@ -188,21 +192,30 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
     const int start = starts[i];
     const int end = starts[i + 1];
     const int steps = end - start;
-    for (int j = 0; j < steps; j += BLOCK_H) {
+    const int size = ((steps + BLOCK_H - 1)/BLOCK_H) * BLOCK_H;
+    for (int j = tidy; j < size; j += BLOCK_H) {
       int xoff = gidx + tidx;
       int yoff = start + j;

       // transpose
-      sh_x[tidx][tidy] = xoff < width && yoff < end ? x[yoff * width + xoff] : 0.0;
-      sh_dy[tidx][tidy] = xoff < width && yoff < end ? dy[yoff * width + xoff] : 0.0;
+      sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0;
       __syncthreads();

       for (int t = 0; t < context; t++) {
-        real val = tidx + t < blockDim.x ? sh_x[tidy][tidx + t] * sh_dy[tidy][tidx]: 0.0;
+        sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
+        __syncthreads();
+        real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
         __syncthreads();
         // warp size and blockDim.x is 32.
-        for (int offset = 16; offset > 0; offset /= 2) {
-          val += __shfl_down(val, offset);
-        }
+        val += __shfl_down(val, 16);
+        val += __shfl_down(val, 8);
+        val += __shfl_down(val, 4);
+        val += __shfl_down(val, 2);
+        val += __shfl_down(val, 1);
         __syncthreads();

         if (tidx == 0 && (gidx + tidy) < width) {
           dw[t*width + gidx + tidy] += val;
         }
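// KeRowConvBwWeight2, presumably the fallback for context lengths the
// templated kernel does not cover, takes a simpler route after this change:
// instead of keeping a context - 1 row halo of dy in shared memory, it reloads
// the dy row shifted by t from global memory on every iteration of the t loop
// (the new sh_dy assignment inside the loop), trading extra global reads and
// an extra __syncthreads() per t for a dy tile without the halo.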
@@ -293,13 +306,12 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
   const real* dy = outG.getData();
   const real* x = in.getData();
   const real* w = filter.getData();
-  real* dx = inG.getData();
-  real* dw = filterG.getData();
   const int* starts = seq.getData();

   if (filterG) {
     dim3 dimBlock(32, 32);
     dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
+    real* dw = filterG.getData();
     if (contextLength <= 16) {
       KeRowConvBwWeight<32, 32, 16>
         <<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
@@ -309,8 +321,10 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
         <<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
         (dw, x, dy, starts, height, width, numSeq, contextLength);
     }
   }

   if (inG) {
+    real* dx = inG.getData();
     dim3 dimBlock2(32, 32);
     dim3 dimGrid2(DIVUP(width, dimBlock2.x), 1);
     if (contextLength <= 64) {
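// Fetching dw and dx via getData() only inside the `if (filterG)` and
// `if (inG)` blocks (this hunk and the one above) replaces the unconditional
// fetches removed at the top of RowConvGrad, presumably so that a gradient
// matrix that was not passed in is never touched when that gradient is not
// requested.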
@@ -322,6 +336,7 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
         <<<dimGrid2, dimBlock2, 0, STREAM_DEFAULT>>>
         (dx, w, dy, starts, height, width, numSeq, contextLength);
     }
   }

   CHECK_SYNC("RowConvGrad");
 }