|
|
|
@ -205,14 +205,14 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
|
|
|
|
|
REGISTER_OPERATOR(lstm_cudnn, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
|
|
|
|
|
REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
cudnn_lstm,
|
|
|
|
|
lstm_cudnn,
|
|
|
|
|
ops::CudnnLSTMKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
cudnn_lstm_grad,
|
|
|
|
|
lstm_cudnn_grad,
|
|
|
|
|
ops::CudnnLSTMGradKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|