diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc
index 86632fc9fb..e63d57be57 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cc
@@ -12,12 +12,8 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/cudnn_lstm_op.h"
 #include <string>
-
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/cudnn_helper.h"
-#endif
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -201,18 +197,22 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+class NotImpleKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(
+        "CPU is not support for this kernel now. Will be add in the future");
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp);
-
-REGISTER_OP_CPU_KERNEL(
-    cudnn_lstm,
-    ops::CudnnLSTMKernel<paddle::platform::CPUDeviceContext, float>);
+REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
 
-REGISTER_OP_CPU_KERNEL(
-    lstm_cudnn_grad,
-    ops::CudnnLSTMGradKernel<paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
+REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);
diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
index 811975a9f3..cad62de754 100644
--- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc
+++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc
@@ -12,7 +12,8 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/cudnn_lstm_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 
 namespace paddle {
@@ -246,7 +247,7 @@ struct CudnnRNNCache {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -343,7 +344,7 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -380,7 +381,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
     auto init_c_dims = init_c->dims();
     in_grad->mutable_data<T>(ctx.GetPlace());
     weight_grad->mutable_data<T>(ctx.GetPlace());
-    math::SetConstant<DeviceContext, T> zero;
+    math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
     zero(dev_ctx, in_grad, static_cast<T>(0.0));
     zero(dev_ctx, weight_grad, static_cast<T>(0.0));
 
@@ -486,9 +487,5 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    cudnn_lstm,
-    ops::CudnnLSTMGPUKernel<paddle::platform::CUDADeviceContext, float>);
-REGISTER_OP_CUDA_KERNEL(
-    cudnn_lstm_grad,
-    ops::CudnnLSTMGPUGradKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel<float>);
+REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel<float>);
diff --git a/paddle/fluid/operators/cudnn_lstm_op.h b/paddle/fluid/operators/cudnn_lstm_op.h
deleted file mode 100644
index fc329cc239..0000000000
--- a/paddle/fluid/operators/cudnn_lstm_op.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <string>
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/detail/activation_functions.h"
-#include "paddle/fluid/operators/math/lstm_compute.h"
-#include "paddle/fluid/operators/math/sequence2batch.h"
-
-namespace paddle {
-namespace operators {
-
-using LoDTensor = framework::LoDTensor;
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class CudnnLSTMKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_THROW(
-        "CPU is not support for this kernel now. Will be add in the future");
-  }
-};
-
-template <typename DeviceContext, typename T>
-class CudnnLSTMGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {}
-};
-
-}  // namespace operators
-}  // namespace paddle