improve efficiency of runtime InferVarType (#22778)

* save InferVarType changes, test=develop

* remove code comments, test=develop

* tweak code, test=develop

* fix compilation warning, update merge_ids_op split_ids_op to new interface, test=develop

* modify fused_bn_activation_op, test=develop

* fix error of fused_bn_activation_op, test=develop

* fix PADDLE_ENFORCE and unittest coverage issue, test=develop

* tweak PADDLE_ENFORCE messages, test=develop

* improve unittest coverage, test=develop

* add StaticGraphInferVarType class, test=develop

* rebase develop branch, test=develop

* fix unittest error, test=develop

* remove comments, test=develop

* improve unittest coverage, test=develop

* improve error message and improve unittest coverage, test=develop

* upgrade InferVarType API, test=develop

* tweak pyfunc error message, test=develop

* fix compilation conflict - save_combine_op, test=develop
Author: liuwei1031 (committed via GitHub)
Parent commit: a8eac7da61
Commit: 9a93f6aae0
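As the diffs below illustrate, the reworked InferVarTypeContext lets an operator's type inference act on whole input/output slots ("X", "Out") through calls such as InputTypeAnyOf, InputTypeAllOf, GetInputType, GetInputDataType, SetOutputType, SetOutputDataType and SyncTypeAndDataType, instead of fetching variable names with Input()/Output() and mutating them one by one with GetType/SetType. The following is a minimal sketch of an inference functor written against the new interface; the operator name and the include path are illustrative assumptions, while the context calls are the ones exercised in the diffs.

// Minimal sketch (not taken from this diff): a hypothetical operator's
// var-type inference using the new slot-level InferVarTypeContext calls.
#include "paddle/fluid/framework/var_type_inference.h"  // assumed header path

namespace paddle {
namespace framework {

class MyOpVarTypeInference : public VarTypeInference {
 public:
  void operator()(InferVarTypeContext* ctx) const override {
    // Default to SELECTED_ROWS, but use LOD_TENSOR if any input in slot "X"
    // is a LoDTensor -- no manual loop over variable names is needed.
    auto out_type = proto::VarType::SELECTED_ROWS;
    if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
      out_type = proto::VarType::LOD_TENSOR;
    }
    // Apply the type to every element of the "Out" slot in one call.
    ctx->SetOutputType("Out", out_type, ALL_ELEMENTS);
    // Propagate the data type of the first "X" input to the first output.
    ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
  }
};

}  // namespace framework
}  // namespace paddle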

@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
ctx->SetOutputType("Out", default_var_type);
}
};

File diff suppressed because it is too large.

@@ -24,13 +24,13 @@ namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
NOP(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
void operator()(framework::InferVarTypeContext* ctx) const override {
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
ctx->SetOutputType("Out", default_var_type);
}
};
} // namespace framework
@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
namespace paddle {
namespace framework {
class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference {
public:
void operator()(InferVarTypeContext* context) const override {}
bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
return StaticGraphVarTypeInference::HasVar(ctx, name);
}
const std::vector<std::string>& Input(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::Input(ctx, name);
}
const std::vector<std::string>& Output(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::Output(ctx, name);
}
proto::VarType::Type GetType(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::GetType(ctx, name);
}
void SetType(InferVarTypeContext* ctx, const std::string& name,
proto::VarType::Type type) const {
StaticGraphVarTypeInference::SetType(ctx, name, type);
}
proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::GetDataType(ctx, name);
}
void SetDataType(InferVarTypeContext* ctx, const std::string& name,
proto::VarType::Type type) const {
StaticGraphVarTypeInference::SetDataType(ctx, name, type);
}
std::vector<proto::VarType::Type> GetDataTypes(
InferVarTypeContext* ctx, const std::string& name) const {
return StaticGraphVarTypeInference::GetDataTypes(ctx, name);
}
void SetDataTypes(
InferVarTypeContext* ctx, const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
return StaticGraphVarTypeInference::SetDataTypes(ctx, name,
multiple_data_type);
}
std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::GetShape(ctx, name);
}
void SetShape(InferVarTypeContext* ctx, const std::string& name,
const std::vector<int64_t>& dims) const {
StaticGraphVarTypeInference::SetShape(ctx, name, dims);
}
int32_t GetLoDLevel(InferVarTypeContext* ctx, const std::string& name) const {
return StaticGraphVarTypeInference::GetLoDLevel(ctx, name);
}
void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
int32_t lod_level) const {
StaticGraphVarTypeInference::SetLoDLevel(ctx, name, lod_level);
}
};
TEST(InferVarType, sum_op) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
TEST(InferVarType, sum_op_without_infer_var_type) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
prog.MutableBlock(0)->Var("test2_out")->GetType());
}
TEST(InferVarType, multiple_api) {
ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b"});
op->SetOutput("Out", {"test2_a_out", "test2_b_out"});
block->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
block->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
block->Var("test2_a_out");
block->Var("test2_b_out");
InferVarTypeContext ctx(op, block);
ASSERT_TRUE(ctx.HasInput("X"));
ASSERT_TRUE(ctx.HasOutput("Out"));
ASSERT_EQ(2u, ctx.InputSize("X"));
ASSERT_EQ("test2_a", ctx.InputVarName("X", 0));
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetInputType("X"));
ASSERT_TRUE(ctx.InputTypeAllOf("X", proto::VarType::SELECTED_ROWS));
ASSERT_FALSE(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
ctx.SyncTypeAndDataType("X", "Out");
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
ctx.SetOutputType("Out", proto::VarType::SELECTED_ROWS, ALL_ELEMENTS);
ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR, 1);
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
ASSERT_EQ(0, ctx.GetInputDataType("X"));
ctx.SetOutputDataType("Out", proto::VarType::FP32, ALL_ELEMENTS);
ctx.SetOutputDataType("Out", proto::VarType::INT8, 1);
ASSERT_EQ(proto::VarType::FP32,
prog.MutableBlock(0)->Var("test2_a_out")->GetDataType());
ASSERT_EQ(proto::VarType::INT8,
prog.MutableBlock(0)->Var("test2_b_out")->GetDataType());
ASSERT_FALSE(ctx.IsDygraph());
// test StaticGraphVarTypeInference
TestStaticGraphVarTypeInference infer;
ASSERT_TRUE(infer.HasVar(&ctx, "test2_a"));
ASSERT_EQ(infer.Input(&ctx, "X").size(), infer.Output(&ctx, "Out").size());
ASSERT_EQ(proto::VarType::FP32, infer.GetDataType(&ctx, "test2_a_out"));
infer.SetDataType(&ctx, "test2_a_out", proto::VarType::FP64);
ASSERT_EQ(proto::VarType::FP64, infer.GetDataType(&ctx, "test2_a_out"));
ASSERT_EQ(proto::VarType::SELECTED_ROWS, infer.GetType(&ctx, "test2_a_out"));
infer.SetType(&ctx, "test2_a_out", proto::VarType::LOD_TENSOR);
ASSERT_EQ(proto::VarType::LOD_TENSOR, infer.GetType(&ctx, "test2_a_out"));
ASSERT_ANY_THROW(infer.GetDataTypes(&ctx, "test2_a_out"));
ASSERT_ANY_THROW(infer.SetDataTypes(&ctx, "test2_a_out", {}));
ASSERT_EQ(0u, infer.GetShape(&ctx, "test2_a_out").size());
infer.SetShape(&ctx, "test2_a_out", {
1, 3, 3,
});
ASSERT_EQ(3u, infer.GetShape(&ctx, "test2_a_out").size());
ASSERT_EQ(0, infer.GetLoDLevel(&ctx, "test2_a_out"));
infer.SetLoDLevel(&ctx, "test2_a_out", 2);
ASSERT_EQ(2, infer.GetLoDLevel(&ctx, "test2_a_out"));
}
TEST(InferVarType, test_enforce_check) {
InferVarTypeContext ctx(nullptr, nullptr);
ASSERT_ANY_THROW(ctx.HasInput("X"));
ASSERT_ANY_THROW(ctx.HasOutput("Out"));
ASSERT_ANY_THROW(ctx.InputSize("X"));
ASSERT_ANY_THROW(ctx.InputVarName("X"));
ASSERT_ANY_THROW(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.InputTypeAllOf("X", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.SyncTypeAndDataType("X", "Out"));
ASSERT_ANY_THROW(ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.GetInputType("X"));
ASSERT_ANY_THROW(ctx.GetOutputType("Out"));
ASSERT_ANY_THROW(ctx.GetInputDataType("X"));
ASSERT_ANY_THROW(ctx.SetOutputDataType("Out", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.GetInputDataTypes("X"));
ASSERT_ANY_THROW(ctx.SetOutputDataTypes("Out", {}));
ASSERT_ANY_THROW(ctx.GetInputShape("X"));
ASSERT_ANY_THROW(ctx.SetOutputShape("Out", {}));
ASSERT_ANY_THROW(ctx.GetInputLoDLevel("X"));
ASSERT_ANY_THROW(ctx.SetOutputLoDLevel("Out", 1));
ASSERT_ANY_THROW(ctx.InsertVar("var", proto::VarType::LOD_TENSOR));
}
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large.

@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using var_pair = std::pair<std::string, vb_vector>;
template <typename VarType>
class TestRuntimeInferVarTypeContext
: public RuntimeInferVarTypeContext<VarType> {
public:
TestRuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map)
: RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map) {}
bool HasVar(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::HasVar(name);
}
const std::vector<std::string>& InputVars(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::InputVars(name);
}
const std::vector<std::string>& OutputVars(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::OutputVars(name);
}
framework::proto::VarType::Type GetVarType(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarType(name);
}
void SetVarType(const std::string& name,
framework::proto::VarType::Type type) {
RuntimeInferVarTypeContext<VarType>::SetVarType(name, type);
}
framework::proto::VarType::Type GetVarDataType(
const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarDataType(name);
}
void SetVarDataType(const std::string& name,
framework::proto::VarType::Type type) {
RuntimeInferVarTypeContext<VarType>::SetVarDataType(name, type);
}
std::vector<framework::proto::VarType::Type> GetVarDataTypes(
const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarDataTypes(name);
}
void SetVarDataTypes(
const std::string& name,
const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
RuntimeInferVarTypeContext<VarType>::SetVarDataTypes(name,
multiple_data_type);
}
std::vector<int64_t> GetVarShape(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarShape(name);
}
void SetVarShape(const std::string& name, const std::vector<int64_t>& dims) {
RuntimeInferVarTypeContext<VarType>::SetVarShape(name, dims);
}
int32_t GetVarLoDLevel(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarLoDLevel(name);
}
void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
RuntimeInferVarTypeContext<VarType>::SetVarLoDLevel(name, lod_level);
}
};
TEST(test_layer, test_runtime_context) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
std::shared_ptr<imperative::VarBase> vin_b(
new imperative::VarBase(false, "vin_b"));
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(false, "vout"));
var_pair in_pair = var_pair("X", vb_vector(1, vin));
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
std::shared_ptr<imperative::VarBase> vout_b(
new imperative::VarBase(false, "vout_b"));
var_pair in_pair = var_pair("X", {vin, vin_b});
var_pair out_pair = var_pair("Out", {vout, vout_b});
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attrs;
auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
ins, outs, attrs);
ASSERT_TRUE(ctx->HasVar("vin"));
auto* ctx =
new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>(
ins, outs, attrs);
ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out"));
ASSERT_ANY_THROW(ctx->GetDataTypes("vin"));
ASSERT_EQ(2u, ctx->InputSize("X"));
ASSERT_EQ("vin", ctx->InputVarName("X", 0));
ASSERT_TRUE(ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR));
ASSERT_TRUE(ctx->InputTypeAllOf("X", framework::proto::VarType::LOD_TENSOR));
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetInputType("X"));
ASSERT_EQ(framework::proto::VarType::FP32, ctx->GetInputDataType("X"));
ctx->SyncTypeAndDataType("X", "Out");
ASSERT_EQ(framework::proto::VarType::FP32, vout->DataType());
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetOutputType("Out"));
ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
framework::ALL_ELEMENTS);
ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY);
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
ASSERT_EQ(framework::proto::VarType::SELECTED_ROWS, vout_b->Type());
ctx->SetOutputDataType("Out", framework::proto::VarType::FP64,
framework::ALL_ELEMENTS);
ctx->SetOutputDataType("Out", framework::proto::VarType::INT8);
ASSERT_EQ(framework::proto::VarType::INT8, vout->DataType());
ASSERT_EQ(framework::proto::VarType::FP64, vout_b->DataType());
// no throw, but do nothing
ASSERT_NO_THROW(
ctx->InsertVar("vout", framework::proto::VarType::LOD_TENSOR));
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
ASSERT_ANY_THROW(ctx->HasVar("vin"));
ASSERT_ANY_THROW(ctx->InputVars("X"));
ASSERT_ANY_THROW(ctx->OutputVars("Out"));
ASSERT_ANY_THROW(ctx->GetVarType("vin"));
ASSERT_ANY_THROW(
ctx->SetVarType("vin", framework::proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx->GetVarDataType("vin"));
ASSERT_ANY_THROW(
ctx->SetVarDataType("vout", framework::proto::VarType::FP32));
ASSERT_ANY_THROW(ctx->GetVarDataTypes("vin"));
std::vector<framework::proto::VarType::Type> NullType;
ASSERT_ANY_THROW(ctx->SetDataTypes("vin", NullType));
ASSERT_ANY_THROW(ctx->GetShape("vin"));
ASSERT_ANY_THROW(ctx->GetLoDLevel("vin"));
ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
ASSERT_ANY_THROW(ctx->SetVarDataTypes("vin", NullType));
ASSERT_ANY_THROW(ctx->GetVarShape("vin"));
ASSERT_ANY_THROW(ctx->SetVarShape("vin", {}));
ASSERT_ANY_THROW(ctx->GetVarLoDLevel("vin"));
ASSERT_ANY_THROW(ctx->SetVarLoDLevel("vin", 2));
ASSERT_TRUE(ctx->IsDygraph());
}
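The dygraph side of the change shows up in the assertions above: the old name-based entry points on RuntimeInferVarTypeContext (HasVar, InputVars, OutputVars, GetVarType, SetVarType and the data-type/shape/LoD accessors) now throw, InsertVar is accepted but does nothing, and IsDygraph() reports true, so the slot-level calls are the interface shared by static-graph and dygraph execution.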
std::string LayerDebugString(const std::string &op_type,
const NameVarBaseMap &ins,
const NameVarBaseMap &outs);
std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs);
TEST(test_layer, test_debug_string) {
platform::CPUPlace place;
@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
new imperative::VarBase(false, "vin"));
var_pair in_pair = var_pair("X", vb_vector(1, vin));
auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
}
static std::shared_ptr<imperative::GradOpNode> CreateGradNode(
size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
const imperative::NameVarBaseMap &outs,
const framework::AttributeMap &attrs, const platform::Place &place) {
size_t id, const std::string& type, const imperative::NameVarBaseMap& ins,
const imperative::NameVarBaseMap& outs,
const framework::AttributeMap& attrs, const platform::Place& place) {
auto node = std::make_shared<imperative::GradOpNode>();
auto *op = &(node->emplace_back());
auto* op = &(node->emplace_back());
op->SetId(id);
op->SetPlace(place);
op->SetType(type);
op->SetAttrMap(attrs);
for (auto &pair : ins) {
for (auto& pair : ins) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
for (auto& var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetInput(pair.first, vars, false);
}
for (auto &pair : outs) {
for (auto& pair : outs) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
for (auto& var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetOutput(pair.first, vars, false);
@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
node->InsertGradPendingNode(pending_node);
ASSERT_EQ(node->size(), 1UL);
auto *op = &(node->back());
auto* op = &(node->back());
ASSERT_GT(op->GetInsMap().size(), 0UL);
ASSERT_GT(op->GetOutsMap().size(), 0UL);
@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin"));
ASSERT_ANY_THROW(vin->MutableGradVar());
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>(
vin_with_grad->MutableGradVar()) != 0));
ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0);
ASSERT_TRUE(
dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0);
vin_with_grad->SetOverridedStopGradient(false);
ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
paddle::platform::CPUPlace cpu_place;
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(cpu_place);
auto* dev_ctx = pool.Get(cpu_place);
paddle::framework::RuntimeContext ctx({}, {});
framework::Scope scope;

@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
class ActivationOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
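The GetInputOutputWithSameType override above, and the matching changes to the batch_norm, conv, cross_entropy, elementwise, flip and fused_bn_activation ops later in this diff, now return a reference to a function-local static map instead of constructing a new unordered_map on every call, which is presumably part of the runtime InferVarType saving this PR targets. The pattern, restated with placeholder names:

// Illustrative restatement of the pattern above (the op name is a placeholder):
// the input->output mapping is built once and reused across calls.
class MyOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};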

@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
class AllcloseOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL);
ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
}
};

@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
class AssignInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out")[0];
auto input_type = ctx->GetType(ctx->Input("X")[0]);
auto input_data_type = ctx->GetDataType(ctx->Input("X")[0]);
ctx->SetType(out_var_name, input_type);
ctx->SetDataType(out_var_name, input_data_type);
ctx->SyncTypeAndDataType("X", "Out");
}
};
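In the assign_op change above, SyncTypeAndDataType("X", "Out") copies both the variable type and the data type of the first element of slot "X" to the first element of slot "Out", collapsing the separate GetType/GetDataType/SetType/SetDataType calls that were removed; the multiple_api test earlier in this diff checks exactly this single-element behaviour.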

@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};

@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
for (auto& o : ctx->Output("SentenceIds")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : ctx->Output("SentenceScores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputType("SentenceIds", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
ctx->SetOutputType("SentenceScores", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
}
};

@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class BeamSearchInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("selected_ids")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : ctx->Output("selected_scores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputType("selected_ids", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
ctx->SetOutputType("selected_scores", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
}
};

@@ -94,9 +94,8 @@ execution.
class GetPlacesInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o_name : ctx->Output("Out")) {
ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
}
ctx->SetOutputType("Out", framework::proto::VarType::PLACE_LIST,
framework::ALL_ELEMENTS);
}
};

@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
}
};
class WriteToArrayInferVarType : public framework::VarTypeInference {
class WriteToArrayInferVarType : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = ctx->Input("X")[0];
auto out_name = ctx->Output("Out")[0];
auto x_name = Input(ctx, "X")[0];
auto out_name = Output(ctx, "Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
if (ctx->HasVar(x_name)) {
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
SetType(ctx, out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
if (HasVar(ctx, x_name)) {
SetDataType(ctx, out_name, GetDataType(ctx, x_name));
}
}
};

@@ -434,18 +434,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
class WhileGradOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto p_names = ctx->Input(kX);
auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
auto p_names = Input(ctx, kX);
auto pg_ig_names = Output(ctx, framework::GradVarName(kX));
for (size_t i = 0; i < p_names.size(); ++i) {
if (ctx->HasVar(pg_ig_names[i])) {
if (HasVar(ctx, pg_ig_names[i])) {
VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
<< " type: " << ctx->GetType(p_names[i]);
ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
<< " type: " << GetType(ctx, p_names[i]);
SetType(ctx, pg_ig_names[i], GetType(ctx, p_names[i]));
SetDataType(ctx, pg_ig_names[i], GetDataType(ctx, p_names[i]));
}
}
}
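Operators that genuinely need per-variable, name-based access, such as write_to_array and while_grad above, now derive from StaticGraphVarTypeInference, whose protected helpers take the context as their first argument (the TestStaticGraphVarTypeInference class earlier in this diff exercises the full set). A minimal sketch of that pattern with a hypothetical operator:

// Hypothetical static-graph-only inference functor: name-based access goes
// through the protected StaticGraphVarTypeInference helpers rather than the
// removed ctx->GetType/ctx->SetType-by-name calls.
class CopyFirstVarTypeInference : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    auto& in_name = Input(ctx, "X")[0];      // first variable name in slot "X"
    auto& out_name = Output(ctx, "Out")[0];  // first variable name in slot "Out"
    if (HasVar(ctx, in_name)) {
      SetType(ctx, out_name, GetType(ctx, in_name));
      SetDataType(ctx, out_name, GetDataType(ctx, in_name));
    }
  }
};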

@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{
static std::unordered_map<std::string, std::string> m{
{"Input", /*->*/ "Output"}};
return m;
}
};

@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};

@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
auto input_type = ctx->GetInputType("Ids");
ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
}
};

@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
auto input_type = ctx->GetInputType("Ids");
ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
}
};

@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class ElementwiseOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};

@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
};

@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
if (var_data_type < 0) {
const auto &input_var_name = ctx->Input("X").front();
ctx->SetDataType(out_var_name, ctx->GetDataType(input_var_name));
ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
} else {
ctx->SetDataType(out_var_name, var_data_type);
ctx->SetOutputDataType("Out", var_data_type);
}
}
};

@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
};

@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
};

@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};

@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
class FusedBatchNormActOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};

Some files were not shown because too many files have changed in this diff.
