|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License. */
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/lstm_cudnn_op.h"
|
|
|
|
#include "paddle/fluid/operators/cudnn_lstm_op.h"
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
@ -205,12 +205,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OPERATOR(lstm_cudnn, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
|
|
|
|
REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp);
|
|
|
|
REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp);
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
lstm_cudnn,
|
|
|
|
cudnn_lstm,
|
|
|
|
ops::CudnnLSTMKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
ops::CudnnLSTMKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|