From 1b6a2a09e8115f163161586413f052f5053fa82d Mon Sep 17 00:00:00 2001 From: lujun Date: Mon, 25 Mar 2019 13:03:13 +0800 Subject: [PATCH 1/3] fix mix input type error, test=develop --- paddle/fluid/operators/load_combine_op.cc | 1 + paddle/fluid/operators/load_op.cc | 1 + paddle/fluid/operators/save_combine_op.cc | 8 ++++++++ 3 files changed, 10 insertions(+) diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 2948cf71a9..63d3f809f2 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL( ops::LoadCombineOpKernel, ops::LoadCombineOpKernel, ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel, ops::LoadCombineOpKernel); diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 2d8e6ca854..656728c609 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL( load, ops::LoadOpKernel, ops::LoadOpKernel, ops::LoadOpKernel, + ops::LoadOpKernel, ops::LoadOpKernel); diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 62b1e09737..5c4be7a7f3 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -24,6 +24,13 @@ class SaveCombineOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + ctx.GetPlace()); + } }; class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { @@ -71,4 +78,5 @@ REGISTER_OP_CPU_KERNEL( ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, + ops::SaveCombineOpKernel, ops::SaveCombineOpKernel); From 18aa59493efacbdcf0b05941d0b7cfd970f4550a Mon Sep 17 00:00:00 2001 From: lujun Date: Mon, 25 Mar 2019 14:43:55 +0800 Subject: [PATCH 2/3] fix mix input type error, test=develop --- paddle/fluid/operators/save_combine_op.cc | 19 ++++++++++++------- paddle/fluid/operators/save_combine_op.cu | 4 +--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 5c4be7a7f3..ac1c2dde3b 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -23,14 +23,21 @@ class SaveCombineOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override {} protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, ctx.GetPlace()); } + // TODO(lujun): The override here is just to bypass transform + // in operator impl, which is not elegant enough. + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } }; class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { @@ -61,7 +68,7 @@ to a file on disk. "(string)" "The \"file_path\" where the LoDTensor variables will be saved.") .AddCustomChecker( - [](const std::string &path) { return !path.empty(); }); + [](const std::string& path) { return !path.empty(); }); } }; @@ -77,6 +84,4 @@ REGISTER_OP_CPU_KERNEL( save_combine, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); + ops::SaveCombineOpKernel); diff --git a/paddle/fluid/operators/save_combine_op.cu b/paddle/fluid/operators/save_combine_op.cu index bc4478b51b..78607823a0 100644 --- a/paddle/fluid/operators/save_combine_op.cu +++ b/paddle/fluid/operators/save_combine_op.cu @@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL( save_combine, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); + ops::SaveCombineOpKernel); From bc4d1c7246786aa6cfe3221f5f40718ee21ae721 Mon Sep 17 00:00:00 2001 From: lujun Date: Mon, 25 Mar 2019 18:58:13 +0800 Subject: [PATCH 3/3] fix mix input type error, test=develop --- paddle/fluid/operators/save_combine_op.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index ac1c2dde3b..953e2655d1 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -19,6 +19,8 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + class SaveCombineOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel;