diff --git a/paddle/fluid/framework/details/graph_test_base.h b/paddle/fluid/framework/details/graph_test_base.h index 126959bcd8..2fae684516 100644 --- a/paddle/fluid/framework/details/graph_test_base.h +++ b/paddle/fluid/framework/details/graph_test_base.h @@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker { class DummyVarTypeInference : public VarTypeInference { public: - void operator()(const OpDesc& op_desc, BlockDesc* block) const override { - auto& inputs = op_desc.Input("X"); - auto type = block->Var(inputs.front())->GetType(); - auto out_var_name = op_desc.Output("Out").front(); - block->Var(out_var_name)->SetType(type); + void operator()(framework::InferVarTypeContext& ctx) const override { + auto& inputs = ctx.Input("X"); + auto type = ctx.GetType(inputs.front()); + auto out_var_name = ctx.Output("Out").front(); + ctx.SetType(out_var_name, type); } }; diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 0901e59f97..79281863f6 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -127,9 +127,9 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { - info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) { + info->infer_var_type_ = [](InferVarTypeContext& context) { T inference; - inference(fwd_op, block); + inference(context); }; } }; diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index 7ed2f96eb2..2940f3ceeb 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker { class SumOpVarTypeInference : public VarTypeInference { public: - void operator()(const OpDesc &op_desc, BlockDesc *block) const override { - auto &inputs = op_desc.Input("X"); + void operator()(InferVarTypeContext &ctx) const override { + auto &inputs = ctx.Input("X"); auto default_var_type = proto::VarType::SELECTED_ROWS; bool any_input_is_lod_tensor = std::any_of( - inputs.begin(), inputs.end(), [block](const std::string &name) { - return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR; + inputs.begin(), inputs.end(), [ctx](const std::string &name) { + return ctx.GetType(name) == proto::VarType::LOD_TENSOR; }); if (any_input_is_lod_tensor) { default_var_type = proto::VarType::LOD_TENSOR; } - auto out_var_name = op_desc.Output("Out").front(); - block->Var(out_var_name)->SetType(default_var_type); + auto out_var_name = ctx.Output("Out").front(); + ctx.SetType(out_var_name, default_var_type); } }; @@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker { class DummyOpVarTypeInference : public VarTypeInference { public: - void operator()(const OpDesc &op_desc, BlockDesc *block) const override {} + void operator()(framework::InferVarTypeContext &ctx) const override {} }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 0e7b0cbeb9..aae0eafe6c 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/shape_inference.h" +#include "paddle/fluid/framework/var_type_inference.h" namespace paddle { namespace framework { @@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const { // var type inference. Hence, we don't do any "default" setting here. auto &info = OpInfoMap::Instance().Get(this->Type()); if (info.infer_var_type_) { - info.infer_var_type_(*this, block); + InferVarTypeContext context(this, block); + info.infer_var_type_(context); } } diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index d02c699b97..a774f9ff49 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -27,6 +27,7 @@ namespace framework { class OperatorBase; class OpDesc; class InferShapeContext; +class InferVarTypeContext; class BlockDesc; class Variable; @@ -53,7 +54,7 @@ using GradOpMakerFN = std::function>( const std::vector& grad_block)>; using InferVarTypeFN = - std::function; + std::function; using InferShapeFN = std::function; diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h index 64236b78d2..ed52e1ad81 100644 --- a/paddle/fluid/framework/var_type_inference.h +++ b/paddle/fluid/framework/var_type_inference.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/type_defs.h" @@ -21,26 +22,113 @@ limitations under the License. */ namespace paddle { namespace framework { +class OpDesc; +class BlockDesc; +// default infer var type context +class InferVarTypeContext { + public: + InferVarTypeContext(const OpDesc* op, BlockDesc* block) + : op_(op), block_(block) {} + + Attribute GetAttr(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(op_); + return op_->GetAttr(name); + } + + inline bool HasVar(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(block_); + return block_->FindVarRecursive(name) != nullptr; + } + + inline bool HasInput(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(op_); + return op_->Inputs().count(name) > 0; + } + + inline bool HasOutput(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(op_); + return op_->Outputs().count(name) > 0; + } + + inline const std::vector& Input(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(op_); + return op_->Input(name); + } + + inline const std::vector& Output(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(op_); + return op_->Output(name); + } + + inline proto::VarType::Type GetType(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(block_); + return block_->FindRecursiveOrCreateVar(name).GetType(); + } + + inline void SetType(const std::string& name, proto::VarType::Type type) { + PADDLE_ENFORCE_NOT_NULL(block_); + block_->FindRecursiveOrCreateVar(name).SetType(type); + } + + inline proto::VarType::Type GetDataType(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(block_); + return block_->FindRecursiveOrCreateVar(name).GetDataType(); + } + + inline void SetDataType(const std::string& name, proto::VarType::Type type) { + PADDLE_ENFORCE_NOT_NULL(block_); + block_->FindRecursiveOrCreateVar(name).SetDataType(type); + } + + inline std::vector GetShape(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(block_); + return block_->FindRecursiveOrCreateVar(name).GetShape(); + } + + inline void SetShape(const std::string& name, + const std::vector& dims) { + PADDLE_ENFORCE_NOT_NULL(block_); + block_->FindRecursiveOrCreateVar(name).SetShape(dims); + } + + inline int32_t GetLoDLevel(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL(block_); + return block_->FindRecursiveOrCreateVar(name).GetLoDLevel(); + } + + inline void SetLoDLevel(const std::string& name, int32_t lod_level) { + PADDLE_ENFORCE_NOT_NULL(block_); + block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level); + } + + private: + const OpDesc* op_; + BlockDesc* block_; +}; + +// infer var type context for imperative mode +class RuntimeInferVarTypeContext : public InferVarTypeContext { + public: + RuntimeInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {} +}; + class VarTypeInference { public: virtual ~VarTypeInference() {} - virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0; + virtual void operator()(InferVarTypeContext& context) const = 0; // NOLINT }; class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const final { + void operator()(framework::InferVarTypeContext& ctx) const final { // NOLINT auto in_out_var_names = this->GetInputOutputWithSameType(); for (auto& i_o_n : in_out_var_names) { - auto& x_name = op_desc.Input(i_o_n.first).at(0); - auto& out_name = op_desc.Output(i_o_n.second).at(0); + auto& x_name = ctx.Input(i_o_n.first).at(0); + auto& out_name = ctx.Output(i_o_n.second).at(0); - auto& x = block->FindRecursiveOrCreateVar(x_name); - auto& out = block->FindRecursiveOrCreateVar(out_name); - out.SetType(x.GetType()); - out.SetDataType(x.GetDataType()); + ctx.SetType(out_name, ctx.GetType(x_name)); + ctx.SetDataType(out_name, ctx.GetDataType(x_name)); } } diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc index 2a75394fca..d7d3e0a033 100644 --- a/paddle/fluid/framework/var_type_inference_test.cc +++ b/paddle/fluid/framework/var_type_inference_test.cc @@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker { class SumOpVarTypeInference : public VarTypeInference { public: - void operator()(const OpDesc &op_desc, BlockDesc *block) const override { - auto &inputs = op_desc.Input("X"); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto &inputs = ctx.Input("X"); auto default_var_type = proto::VarType::SELECTED_ROWS; bool any_input_is_lod_tensor = std::any_of( - inputs.begin(), inputs.end(), [block](const std::string &name) { - return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR; + inputs.begin(), inputs.end(), [ctx](const std::string &name) { + return ctx.GetType(name) == proto::VarType::LOD_TENSOR; }); if (any_input_is_lod_tensor) { default_var_type = proto::VarType::LOD_TENSOR; } - auto out_var_name = op_desc.Output("Out").front(); - block->Var(out_var_name)->SetType(default_var_type); + auto out_var_name = ctx.Output("Out").front(); + ctx.SetType(out_var_name, default_var_type); } }; } // namespace framework diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index cf78c83297..703edcad11 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -178,10 +178,10 @@ Beam Search Decode Operator. This Operator constructs the full hypotheses for each source sentence by walking back along the LoDTensorArray Input(ids) whose lods can be used to restore the path in the beam search tree. -The Output(SentenceIds) and Output(SentenceScores) separately contain the -generated id sequences and the corresponding scores. The shapes and lods of the -two LodTensor are same. The lod level is 2 and the two levels separately -indicate how many hypotheses each source sentence has and how many ids each +The Output(SentenceIds) and Output(SentenceScores) separately contain the +generated id sequences and the corresponding scores. The shapes and lods of the +two LodTensor are same. The lod level is 2 and the two levels separately +indicate how many hypotheses each source sentence has and how many ids each hypothesis has. )DOC"); } @@ -203,15 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase { class BeamSearchDecodeInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - for (auto& o : op_desc.Output("SentenceIds")) { - auto& sentence_ids = block->FindRecursiveOrCreateVar(o); - sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR); + void operator()(framework::InferVarTypeContext& ctx) const override { + for (auto& o : ctx.Output("SentenceIds")) { + ctx.SetType(o, framework::proto::VarType::LOD_TENSOR); } - for (auto& o : op_desc.Output("SentenceScores")) { - auto& sentence_scores = block->FindRecursiveOrCreateVar(o); - sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR); + for (auto& o : ctx.Output("SentenceScores")) { + ctx.SetType(o, framework::proto::VarType::LOD_TENSOR); } } }; diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index fa6b09b4e7..8958d00a68 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -65,7 +65,7 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(true); AddComment(R"DOC( -This operator does the search in beams for one time step. +This operator does the search in beams for one time step. Specifically, it selects the top-K candidate word ids of current step from Input(ids) according to their Input(scores) for all source sentences, where K is Attr(beam_size) and Input(ids), Input(scores) are predicted results @@ -120,15 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel { class BeamSearchInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - for (auto &o : op_desc.Output("selected_ids")) { - auto &selected_ids = block->FindRecursiveOrCreateVar(o); - selected_ids.SetType(framework::proto::VarType::LOD_TENSOR); + void operator()(framework::InferVarTypeContext &ctx) const override { + for (auto &o : ctx.Output("selected_ids")) { + ctx.SetType(o, framework::proto::VarType::LOD_TENSOR); } - for (auto &o : op_desc.Output("selected_scores")) { - auto &selected_scores = block->FindRecursiveOrCreateVar(o); - selected_scores.SetType(framework::proto::VarType::LOD_TENSOR); + for (auto &o : ctx.Output("selected_scores")) { + ctx.SetType(o, framework::proto::VarType::LOD_TENSOR); } } }; diff --git a/paddle/fluid/operators/controlflow/get_places_op.cc b/paddle/fluid/operators/controlflow/get_places_op.cc index 1a157688f3..0258739d6d 100644 --- a/paddle/fluid/operators/controlflow/get_places_op.cc +++ b/paddle/fluid/operators/controlflow/get_places_op.cc @@ -93,11 +93,9 @@ execution. class GetPlacesInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - for (auto &o_name : op_desc.Output("Out")) { - block->FindRecursiveOrCreateVar(o_name).SetType( - framework::proto::VarType::PLACE_LIST); + void operator()(framework::InferVarTypeContext &ctx) const override { + for (auto &o_name : ctx.Output("Out")) { + ctx.SetType(o_name, framework::proto::VarType::PLACE_LIST); } } }; diff --git a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc index fa18ade323..041eef602e 100644 --- a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc @@ -100,16 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase { class WriteToArrayInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto x_name = op_desc.Input("X")[0]; - auto out_name = op_desc.Output("Out")[0]; + void operator()(framework::InferVarTypeContext &ctx) const override { + auto x_name = ctx.Input("X")[0]; + auto out_name = ctx.Output("Out")[0]; VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY"; - auto &out = block->FindRecursiveOrCreateVar(out_name); - out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY); - auto *x = block->FindVarRecursive(x_name); - if (x != nullptr) { - out.SetDataType(x->GetDataType()); + ctx.SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY); + if (ctx.HasVar(x_name)) { + ctx.SetDataType(out_name, ctx.GetDataType(x_name)); } } }; diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc index da0185b8c4..0a269c7575 100644 --- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -114,11 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel { class MergeIdsOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto *input_var = block->Var(op_desc.Input("Ids")[0]); - for (auto &out_var : op_desc.Output("Out")) { - block->Var(out_var)->SetType(input_var->GetType()); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto input_type = ctx.GetType(ctx.Input("Ids")[0]); + for (auto &out_var : ctx.Output("Out")) { + ctx.SetType(out_var, input_type); } } }; diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc index f61d387fbe..2932a202a5 100644 --- a/paddle/fluid/operators/distributed_ops/split_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc @@ -71,11 +71,10 @@ class SplitIdsOp : public framework::OperatorWithKernel { class SplitIdsOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto *input_var = block->Var(op_desc.Input("Ids")[0]); - for (auto &out_var : op_desc.Output("Out")) { - block->Var(out_var)->SetType(input_var->GetType()); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto input_type = ctx.GetType(ctx.Input("Ids")[0]); + for (auto &out_var : ctx.Output("Out")) { + ctx.SetType(out_var, input_type); } } }; diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index c86430524e..eb5996d50e 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -39,12 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel { class FillConstantOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { + void operator()(framework::InferVarTypeContext& ctx) const override { auto data_type = static_cast( - boost::get(op_desc.GetAttr("dtype"))); - auto& out_var_name = op_desc.Output("Out").front(); - block->Var(out_var_name)->SetDataType(data_type); + boost::get(ctx.GetAttr("dtype"))); + auto& out_var_name = ctx.Output("Out").front(); + ctx.SetDataType(out_var_name, data_type); } }; diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index a0026427e2..27a761c29f 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -137,22 +137,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { class FusedEmbeddingSeqPoolOpGradVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); - auto attr = op_desc.GetAttr("is_sparse"); + void operator()(framework::InferVarTypeContext& ctx) const override { + auto out_var_name = ctx.Output(framework::GradVarName("W")).front(); + auto attr = ctx.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { VLOG(3) << "fused_embedding_seq_pool_grad op " << framework::GradVarName("W") << " is set to SelectedRows"; - block->Var(out_var_name) - ->SetType(framework::proto::VarType::SELECTED_ROWS); + ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS); } else { VLOG(3) << "fused_embedding_seq_pool_grad op " << framework::GradVarName("W") << " is set to LoDTensor"; - block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); } - block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0])); } }; diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc index a4ae19d9c1..5388e65497 100644 --- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -81,15 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows. class GetTensorFromSelectedRowsOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const final { - auto out_var_name = op_desc.Output("Out").front(); - auto in_var_name = op_desc.Input("X").front(); - - auto out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto in_var = block->FindRecursiveOrCreateVar(in_var_name); - out_var.SetType(framework::proto::VarType::LOD_TENSOR); - out_var.SetDataType(in_var.GetDataType()); + void operator()(framework::InferVarTypeContext &ctx) const { // NOLINT + auto out_var_name = ctx.Output("Out").front(); + auto in_var_name = ctx.Input("X").front(); + + ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); + ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name)); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 6ca6f0bc04..508c99b953 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -197,38 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { class HierarchicalSigmoidGradOpGradVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front(); - auto bias_grad_var_name_vec = - op_desc.Output(framework::GradVarName("Bias")); + void operator()(framework::InferVarTypeContext& ctx) const override { + auto w_grad_var_name = ctx.Output(framework::GradVarName("W")).front(); + auto bias_grad_var_name_vec = ctx.Output(framework::GradVarName("Bias")); std::string bias_grad_var_name; bool hasBias = false; if (bias_grad_var_name_vec.size()) { hasBias = true; - bias_grad_var_name = - op_desc.Output(framework::GradVarName("Bias")).front(); + bias_grad_var_name = ctx.Output(framework::GradVarName("Bias")).front(); } - auto attr = op_desc.GetAttr("is_sparse"); + auto attr = ctx.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") << " is set to SelectedRows"; - block->Var(w_grad_var_name) - ->SetType(framework::proto::VarType::SELECTED_ROWS); + ctx.SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS); } else { VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W") << " is set to LoDTensor"; - block->Var(w_grad_var_name) - ->SetType(framework::proto::VarType::LOD_TENSOR); + ctx.SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR); } if (hasBias) { VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("Bias") << " is set to LoDTensor"; - block->Var(bias_grad_var_name) - ->SetType(framework::proto::VarType::LOD_TENSOR); + ctx.SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR); } - block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType()); + ctx.SetDataType(w_grad_var_name, ctx.GetDataType(ctx.Input("W")[0])); } }; diff --git a/paddle/fluid/operators/lod_rank_table_op.cc b/paddle/fluid/operators/lod_rank_table_op.cc index 166952fe23..a7bbb49827 100644 --- a/paddle/fluid/operators/lod_rank_table_op.cc +++ b/paddle/fluid/operators/lod_rank_table_op.cc @@ -64,11 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase { class LoDRankTableInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - for (auto &o : op_desc.Output("Out")) { - block->FindRecursiveOrCreateVar(o).SetType( - framework::proto::VarType::LOD_RANK_TABLE); + void operator()(framework::InferVarTypeContext &ctx) const override { + for (auto &o : ctx.Output("Out")) { + ctx.SetType(o, framework::proto::VarType::LOD_RANK_TABLE); } } }; diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index 9b91cf5260..4fd45db67b 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -201,10 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase { class LoDTensorToArrayInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - for (auto &out_var : op_desc.Output("Out")) { - block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY); + void operator()(framework::InferVarTypeContext &ctx) const override { + for (auto &out_var : ctx.Output("Out")) { + ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY); } } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 0029932bc0..a59ff23f93 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -147,22 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output(framework::GradVarName("W")).front(); - auto attr = op_desc.GetAttr("is_sparse"); + void operator()(framework::InferVarTypeContext& ctx) const override { + auto out_var_name = ctx.Output(framework::GradVarName("W")).front(); + auto attr = ctx.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") << " is set to SelectedRows"; - block->Var(out_var_name) - ->SetType(framework::proto::VarType::SELECTED_ROWS); + ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS); } else { VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W") << " is set to LoDTensor"; - block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); + ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); } - block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); + ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0])); } }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 256da34912..3c3d79cc7b 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -237,23 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel { class NCEOpGradVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front(); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto weight_grad = ctx.Output(framework::GradVarName("Weight")).front(); - auto attr = op_desc.GetAttr("is_sparse"); + auto attr = ctx.GetAttr("is_sparse"); bool is_sparse = boost::get(attr); if (is_sparse) { VLOG(3) << "nce_op_grad op " << weight_grad << " and " << " is set to SelectedRows"; - block->Var(weight_grad) - ->SetType(framework::proto::VarType::SELECTED_ROWS); + ctx.SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS); } else { VLOG(3) << "nce_op_grad op " << weight_grad << " and " << " is set to LoDTensor"; - block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR); + ctx.SetType(weight_grad, framework::proto::VarType::LOD_TENSOR); } - block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType()); + ctx.SetDataType(weight_grad, ctx.GetDataType(ctx.Input("Input")[0])); } }; diff --git a/paddle/fluid/operators/ngraph/ngraph_engine_op.cc b/paddle/fluid/operators/ngraph/ngraph_engine_op.cc index f941f917c8..a88ddf33a0 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine_op.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine_op.cc @@ -37,8 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker { class NgraphEngineInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override {} + void operator()(framework::InferVarTypeContext &ctx) const override {} }; } // namespace operators diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index 574a03680b..668fa889ac 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -56,9 +56,9 @@ This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each weight using a local learning rate: $$ -local\_lr = \eta * +local\_lr = \eta * \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\ -velocity = mu * velocity + +velocity = mu * velocity + local\_lr * (grad + \beta * param) \\ param = param - velocity. \\ $$ @@ -72,8 +72,7 @@ use L2 regularizers in case of using LARS. class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override {} + void operator()(framework::InferVarTypeContext &ctx) const override {} }; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc index cde238c076..1be423da5b 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.cc +++ b/paddle/fluid/operators/optimizers/momentum_op.cc @@ -21,18 +21,14 @@ using Tensor = framework::Tensor; class MomentumOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto input_var = op_desc.Input("Param")[0]; - for (auto& out_var : op_desc.Output("ParamOut")) { - if (block->FindRecursiveOrCreateVar(input_var).GetType() == - framework::proto::VarType::SELECTED_ROWS) { - block->FindRecursiveOrCreateVar(out_var).SetType( - framework::proto::VarType::SELECTED_ROWS); - } else if (block->FindRecursiveOrCreateVar(input_var).GetType() == + void operator()(framework::InferVarTypeContext& ctx) const override { + auto& input_var = ctx.Input("Param")[0]; + for (auto& out_var : ctx.Output("ParamOut")) { + if (ctx.GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) { + ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS); + } else if (ctx.GetType(input_var) == framework::proto::VarType::LOD_TENSOR) { - block->FindRecursiveOrCreateVar(out_var).SetType( - framework::proto::VarType::LOD_TENSOR); + ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR); } else { PADDLE_THROW( "Only support LodTensor and SelectedRows, Unexpected Input Type."); diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index 690381a67f..cac3d9b68f 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -50,20 +50,18 @@ class SGDOp : public framework::OperatorWithKernel { class SGDOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto input_var_n = op_desc.Input("Param")[0]; - auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType(); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto &input_var_n = ctx.Input("Param")[0]; + auto in_var_type = ctx.GetType(input_var_n); PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || in_var_type == framework::proto::VarType::LOD_TENSOR, "The input Var's type should be LoDtensor or SelectedRows," " but the received var(%s)'s type is %s", input_var_n, in_var_type); - for (auto &out_var_n : op_desc.Output("ParamOut")) { - auto &out_var = block->FindRecursiveOrCreateVar(out_var_n); - if (out_var.GetType() != in_var_type) { - out_var.SetType(in_var_type); + for (auto &out_var_n : ctx.Output("ParamOut")) { + if (ctx.GetType(out_var_n) != in_var_type) { + ctx.SetType(out_var_n, in_var_type); } } } diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 53eff2de3e..f630ad678f 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -91,15 +91,12 @@ static void CallPythonFunc(py::object *callable, } } -class PyFuncOpVarTypInference : public framework::VarTypeInference { +class PyFuncOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op, - framework::BlockDesc *block) const override { - auto &outs = op.Outputs(); - bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty()); + void operator()(framework::InferVarTypeContext &ctx) const override { + bool has_out = (ctx.HasOutput("Out") && !ctx.Output("Out").empty()); - auto &ins = op.Inputs(); - bool has_in = (ins.count("X") > 0 && !ins.at("X").empty()); + bool has_in = (ctx.HasInput("X") && !ctx.Input("Out").empty()); /** * X or Out can be empty, so that py_func can be more flexible @@ -107,7 +104,7 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference { */ PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist"); - PADDLE_ENFORCE_GE(boost::get(op.GetAttr(kForwardPythonCallableId)), 0, + PADDLE_ENFORCE_GE(boost::get(ctx.GetAttr(kForwardPythonCallableId)), 0, "Function id cannot be less than 0"); if (!has_out) return; @@ -118,7 +115,7 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference { * the corresponding forward variable */ const std::string kGradVarSuffix = framework::kGradVarSuffix; - auto &out_var_names = outs.at("Out"); + auto &out_var_names = ctx.Output("Out"); for (auto &out_var_name : out_var_names) { if (out_var_name == framework::kEmptyVarName || out_var_name.size() < kGradVarSuffix.size()) { @@ -128,18 +125,17 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference { size_t len = out_var_name.size() - kGradVarSuffix.size(); if (out_var_name.substr(len) == kGradVarSuffix) { auto fwd_var_name = out_var_name.substr(0, len); - auto *out_var_desc = block->FindVarRecursive(out_var_name); - auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name); - PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found", - out_var_name); - PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found", - fwd_var_name); + PADDLE_ENFORCE(ctx.HasVar(out_var_name), + "Backward variable %s not found", out_var_name); + PADDLE_ENFORCE(ctx.HasVar(fwd_var_name), + "Backward variable %s not found", fwd_var_name); VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" << fwd_var_name << ")"; - out_var_desc->SetShape(fwd_var_desc->GetShape()); - out_var_desc->SetDataType(fwd_var_desc->GetDataType()); - out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel()); - out_var_desc->SetType(fwd_var_desc->GetType()); + + ctx.SetShape(out_var_name, ctx.GetShape(fwd_var_name)); + ctx.SetDataType(out_var_name, ctx.GetDataType(fwd_var_name)); + ctx.SetLoDLevel(out_var_name, ctx.GetLoDLevel(fwd_var_name)); + ctx.SetType(out_var_name, ctx.GetType(fwd_var_name)); } } } @@ -309,5 +305,5 @@ class PyFuncOp : public framework::OperatorBase { namespace ops = paddle::operators; REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, - ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference, + ops::PyFuncOpVarTypeInference, ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker); diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index 85394b336f..915325905b 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -85,10 +85,10 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase { AddComment(R"DOC( CreateCustomReader Operator - A custom reader can be used for input data preprocessing. - A custom reader holds its own sub-block, which will be executed in CPU - in its 'ReadNext()' function. Users can configurate their own - preprocessing pipelines by inserting operators into custom reader's + A custom reader can be used for input data preprocessing. + A custom reader holds its own sub-block, which will be executed in CPU + in its 'ReadNext()' function. Users can configurate their own + preprocessing pipelines by inserting operators into custom reader's sub-block. )DOC"); } @@ -123,23 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase { class CustomReaderInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]); - PADDLE_ENFORCE_NOT_NULL(out_reader); - out_reader->SetType(framework::proto::VarType::READER); + void operator()(const framework::InferVarTypeContext& ctx) const override { + auto& out_var_name = ctx.Output("Out")[0]; + PADDLE_ENFORCE(ctx.HasVar(out_var_name)); + ctx.SetType(out_var_name, framework::proto::VarType::READER); auto sink_var_names = - boost::get>(op_desc.GetAttr("sink_var_names")); + boost::get>(ctx.GetAttr("sink_var_names")); const auto* sub_block = - boost::get(op_desc.GetAttr("sub_block")); + boost::get(ctx.GetAttr("sub_block")); std::vector res_data_types; for (const std::string& var_name : sink_var_names) { framework::VarDesc* var = sub_block->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(var); res_data_types.emplace_back(var->GetDataType()); } - out_reader->SetDataTypes(res_data_types); + ctx.SetDataTypes(out_var_name, res_data_types); } }; diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc index 846b2ed77e..9a98d68e13 100644 --- a/paddle/fluid/operators/reader/read_op.cc +++ b/paddle/fluid/operators/reader/read_op.cc @@ -51,19 +51,16 @@ class ReadInferShape : public framework::InferShapeBase { class ReadInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - bool infer_out = boost::get(op_desc.GetAttr("infer_out")); + void operator()(const framework::InferVarTypeContext& ctx) const override { + bool infer_out = boost::get(ctx.GetAttr("infer_out")); if (infer_out) { - std::string reader_name = op_desc.Input("Reader")[0]; - std::vector out_names = op_desc.Output("Out"); - framework::VarDesc* reader = block->FindVarRecursive(reader_name); - auto dtypes = reader->GetDataTypes(); + std::string reader_name = ctx.Input("Reader")[0]; + std::vector out_names = ctx.Output("Out"); + auto dtypes = ctx.GetDataTypes(reader_name); PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); for (size_t i = 0; i < dtypes.size(); ++i) { - framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); - out.SetType(framework::proto::VarType::LOD_TENSOR); - out.SetDataType(dtypes[i]); + ctx.SetType(out_names[i], framework::proto::VarType::LOD_TENSOR); + ctx.SetDataType(out_names[i], dtypes[i]); } } } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 25c3e7d77b..58b0dfd555 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -59,8 +59,7 @@ class FileReaderInferShape : public framework::InferShapeBase { class FileReaderInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override; + void operator()(framework::InferVarTypeContext& ctx) const override; }; // general infershape for decorated reader @@ -72,8 +71,7 @@ class DecoratedReaderInferShape : public framework::InferShapeBase { // general var type inference for decorated reader class DecoratedReaderInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override; + void operator()(framework::InferVarTypeContext& ctx) const override; }; class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index fcc598f4f1..45da2ac4c6 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -159,12 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file class SaveOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front(); - auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto out_var_name = ctx.Output(LOOKUP_TABLE_PATH).front(); + ctx.SetType(out_var_name, framework::proto::VarType::RAW); } }; diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 4ea77ed30d..d2f05c42a7 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -69,17 +69,13 @@ $$Out = scale*(X + bias)$$ class ScaleOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto &in_var_name = op_desc.Input("X").front(); - auto &in_var = detail::Ref(block->FindVarRecursive(in_var_name)); - - auto out_var_name = op_desc.Output("Out").front(); - auto *out_var = block->FindVarRecursive(out_var_name); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto &in_var_name = ctx.Input("X").front(); + auto out_var_name = ctx.Output("Out").front(); if (in_var_name != out_var_name) { - out_var->SetType(in_var.GetType()); - out_var->SetDataType(in_var.GetDataType()); + ctx.SetType(out_var_name, ctx.GetType(in_var_name)); + ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name)); } } }; diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc index 0e7b1463d1..e950f30a42 100644 --- a/paddle/fluid/operators/split_selected_rows_op.cc +++ b/paddle/fluid/operators/split_selected_rows_op.cc @@ -60,10 +60,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel { class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - for (auto &out_var : op_desc.Output("Out")) { - block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS); + void operator()(framework::InferVarTypeContext &ctx) const override { + for (auto &out_var : ctx.Output("Out")) { + ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS); } } }; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 7abfbbd3cb..d674711392 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -159,24 +159,20 @@ the LoD information with the first input. class SumOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto& inputs = op_desc.Input("X"); + void operator()(framework::InferVarTypeContext& ctx) const override { + auto& inputs = ctx.Input("X"); auto var_type = framework::proto::VarType::SELECTED_ROWS; - for (auto& name : op_desc.Input("X")) { - VLOG(10) << name << " " - << block->FindRecursiveOrCreateVar(name).GetType(); + for (auto& name : ctx.Input("X")) { + VLOG(10) << name << " " << ctx.GetType(name); } bool any_input_is_lod_tensor = std::any_of( - inputs.begin(), inputs.end(), [block](const std::string& name) { - return block->FindRecursiveOrCreateVar(name).GetType() == - framework::proto::VarType::LOD_TENSOR; + inputs.begin(), inputs.end(), [ctx](const std::string& name) { + return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR; }); - auto is_tensor_array = [block](const std::string& name) { - return block->FindRecursiveOrCreateVar(name).GetType() == - framework::proto::VarType::LOD_TENSOR_ARRAY; + auto is_tensor_array = [ctx](const std::string& name) { + return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY; }; bool any_input_is_tensor_array = @@ -188,8 +184,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { if (!all_inputs_are_tensor_array) { std::ostringstream os; for (auto& each : inputs) { - os << " " << each << " type is " - << block->FindRecursiveOrCreateVar(each).GetType() << "\n"; + os << " " << each << " type is " << ctx.GetType(each) << "\n"; } PADDLE_ENFORCE(all_inputs_are_tensor_array, "Not all inputs are tensor array:\n%s", os.str()); @@ -199,11 +194,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference { var_type = framework::proto::VarType::LOD_TENSOR; } - auto out_var_name = op_desc.Output("Out").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - out_var.SetType(var_type); - auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front())); - out_var.SetDataType(in_var.GetDataType()); + auto out_var_name = ctx.Output("Out").front(); + ctx.SetType(out_var_name, var_type); + ctx.SetDataType(out_var_name, ctx.GetDataType(inputs.front())); } }; diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc index 58a74ec2c1..d7f67ccb2f 100644 --- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc +++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc @@ -177,10 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase { class LoDTensorArray2TensorGradInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) { - block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY); + void operator()(framework::InferVarTypeContext &ctx) const override { + for (auto &out_var : ctx.Output(framework::GradVarName("X"))) { + ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY); } } }; diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc index a8c86de9f9..845629d40f 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc @@ -46,8 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { class TensorRTEngineInferVarType : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override {} + void operator()(framework::InferVarTypeContext &ctx) const override {} }; } // namespace operators diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index e3132ae76f..b3a8b6a141 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -112,17 +112,15 @@ uniform distribution. The random result is in set [min, max]. class UniformRandomOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto out_var_name = op_desc.Output("Out").front(); + void operator()(framework::InferVarTypeContext &ctx) const override { + auto out_var_name = ctx.Output("Out").front(); auto var_data_type = static_cast( - boost::get(op_desc.GetAttr("dtype"))); + boost::get(ctx.GetAttr("dtype"))); - auto out_var = block->FindRecursiveOrCreateVar(out_var_name); - if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) { - out_var.SetType(framework::proto::VarType::LOD_TENSOR); + if (ctx.GetType(out_var_name) != framework::proto::VarType::SELECTED_ROWS) { + ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR); } - out_var.SetDataType(var_data_type); + ctx.SetDataType(out_var_name, var_data_type); } };