diff --git a/paddle/operators/lstm_unit_op.cu b/paddle/operators/lstm_unit_op.cu index fe45360bb3..6e5e497899 100644 --- a/paddle/operators/lstm_unit_op.cu +++ b/paddle/operators/lstm_unit_op.cu @@ -35,7 +35,7 @@ __device__ Dtype cuda_tanh(const Dtype x) { } template <typename T> -__global__ void LSTMUnitKernel(const int nthreads, const int dim, const int t, +__global__ void LSTMUnitKernel(const int nthreads, const int dim, const T* C_prev, const T* X, T* C, T* H, const T forget_bias) { CUDA_1D_KERNEL_LOOP(index, nthreads) { @@ -159,9 +159,9 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel { int n = N * D; int grid = (n + block - 1) / block; - LSTMUnitGradientKernel<T><<<N * D, block>>>(n, D, C_prev, X, C, H, C_diff, - H_diff, C_prev_diff, X_diff, - T forget_bias) + LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff, + H_diff, C_prev_diff, X_diff, + forget_bias); } };