fix operator.cmake

test=develop
add_cudnn_lstm
chengduozh 6 years ago
parent 679d8fc6fe
commit af8c2cec13

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/lstm_cudnn_op.h" #include "paddle/fluid/operators/cudnn_lstm_op.h"
#include <string> #include <string>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
@ -205,12 +205,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(lstm_cudnn, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker, REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp); REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
lstm_cudnn, cudnn_lstm,
ops::CudnnLSTMKernel<paddle::platform::CPUDeviceContext, float>); ops::CudnnLSTMKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/lstm_cudnn_op.h" #include "paddle/fluid/operators/cudnn_lstm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle { namespace paddle {
@ -487,8 +487,8 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
lstm_cudnn, cudnn_lstm,
ops::CudnnLSTMGPUKernel<paddle::platform::CUDADeviceContext, float>); ops::CudnnLSTMGPUKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
lstm_cudnn_grad, cudnn_lstm_grad,
ops::CudnnLSTMGPUGradKernel<paddle::platform::CUDADeviceContext, float>); ops::CudnnLSTMGPUGradKernel<paddle::platform::CUDADeviceContext, float>);
Loading…
Cancel
Save