|
|
|
@ -312,8 +312,8 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
|
|
|
|
|
dim3 dimBlock(32, 32);
|
|
|
|
|
dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
|
|
|
|
|
real* dw = filterG.getData();
|
|
|
|
|
if (contextLength <= 16) {
|
|
|
|
|
KeRowConvBwWeight<32, 32, 16>
|
|
|
|
|
if (contextLength <= 32) {
|
|
|
|
|
KeRowConvBwWeight<32, 32, 32>
|
|
|
|
|
<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
|
|
|
|
|
(dw, x, dy, starts, height, width, numSeq, contextLength);
|
|
|
|
|
} else {
|
|
|
|
|