|
|
|
@ -24,13 +24,13 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
class NOP : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
NOP(const std::string &type, const VariableNameMap &inputs,
|
|
|
|
|
const VariableNameMap &outputs, const AttributeMap &attrs)
|
|
|
|
|
NOP(const std::string& type, const VariableNameMap& inputs,
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const Scope &scope,
|
|
|
|
|
const platform::Place &place) const override {}
|
|
|
|
|
void RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SumOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
|
|
|
|
|
class SumOpVarTypeInference : public VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferVarTypeContext *ctx) const override {
|
|
|
|
|
auto &inputs = ctx->Input("X");
|
|
|
|
|
void operator()(framework::InferVarTypeContext* ctx) const override {
|
|
|
|
|
auto default_var_type = proto::VarType::SELECTED_ROWS;
|
|
|
|
|
|
|
|
|
|
bool any_input_is_lod_tensor = std::any_of(
|
|
|
|
|
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
|
|
|
|
|
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
|
|
|
|
|
});
|
|
|
|
|
if (any_input_is_lod_tensor) {
|
|
|
|
|
if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
|
|
|
|
|
default_var_type = proto::VarType::LOD_TENSOR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_var_name = ctx->Output("Out").front();
|
|
|
|
|
ctx->SetType(out_var_name, default_var_type);
|
|
|
|
|
ctx->SetOutputType("Out", default_var_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace framework
|
|
|
|
@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(InferVarTypeContext* context) const override {}
|
|
|
|
|
|
|
|
|
|
bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::HasVar(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Input(InferVarTypeContext* ctx,
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::Input(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string>& Output(InferVarTypeContext* ctx,
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::Output(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type GetType(InferVarTypeContext* ctx,
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::GetType(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetType(InferVarTypeContext* ctx, const std::string& name,
|
|
|
|
|
proto::VarType::Type type) const {
|
|
|
|
|
StaticGraphVarTypeInference::SetType(ctx, name, type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::GetDataType(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDataType(InferVarTypeContext* ctx, const std::string& name,
|
|
|
|
|
proto::VarType::Type type) const {
|
|
|
|
|
StaticGraphVarTypeInference::SetDataType(ctx, name, type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<proto::VarType::Type> GetDataTypes(
|
|
|
|
|
InferVarTypeContext* ctx, const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::GetDataTypes(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDataTypes(
|
|
|
|
|
InferVarTypeContext* ctx, const std::string& name,
|
|
|
|
|
const std::vector<proto::VarType::Type>& multiple_data_type) {
|
|
|
|
|
return StaticGraphVarTypeInference::SetDataTypes(ctx, name,
|
|
|
|
|
multiple_data_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
|
|
|
|
|
const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::GetShape(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetShape(InferVarTypeContext* ctx, const std::string& name,
|
|
|
|
|
const std::vector<int64_t>& dims) const {
|
|
|
|
|
StaticGraphVarTypeInference::SetShape(ctx, name, dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int32_t GetLoDLevel(InferVarTypeContext* ctx, const std::string& name) const {
|
|
|
|
|
return StaticGraphVarTypeInference::GetLoDLevel(ctx, name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
|
|
|
|
|
int32_t lod_level) const {
|
|
|
|
|
StaticGraphVarTypeInference::SetLoDLevel(ctx, name, lod_level);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
TEST(InferVarType, sum_op) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
auto *op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
auto* op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
op->SetType("sum");
|
|
|
|
|
op->SetInput("X", {"test_a", "test_b", "test_c"});
|
|
|
|
|
op->SetOutput("Out", {"test_out"});
|
|
|
|
@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
|
|
|
|
|
|
|
|
|
|
TEST(InferVarType, sum_op_without_infer_var_type) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
auto *op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
auto* op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
op->SetType("sum_without_infer_var_type");
|
|
|
|
|
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
|
|
|
|
|
op->SetOutput("Out", {"test2_out"});
|
|
|
|
@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
|
|
|
|
|
prog.MutableBlock(0)->Var("test2_out")->GetType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(InferVarType, multiple_api) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
|
|
|
|
|
auto* block = prog.MutableBlock(0);
|
|
|
|
|
auto* op = block->AppendOp();
|
|
|
|
|
op->SetType("sum_without_infer_var_type");
|
|
|
|
|
op->SetInput("X", {"test2_a", "test2_b"});
|
|
|
|
|
op->SetOutput("Out", {"test2_a_out", "test2_b_out"});
|
|
|
|
|
|
|
|
|
|
block->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
|
|
|
|
|
block->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
|
|
|
|
|
block->Var("test2_a_out");
|
|
|
|
|
block->Var("test2_b_out");
|
|
|
|
|
|
|
|
|
|
InferVarTypeContext ctx(op, block);
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE(ctx.HasInput("X"));
|
|
|
|
|
ASSERT_TRUE(ctx.HasOutput("Out"));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(2u, ctx.InputSize("X"));
|
|
|
|
|
ASSERT_EQ("test2_a", ctx.InputVarName("X", 0));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetInputType("X"));
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE(ctx.InputTypeAllOf("X", proto::VarType::SELECTED_ROWS));
|
|
|
|
|
ASSERT_FALSE(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
|
|
|
|
|
|
|
|
|
|
ctx.SyncTypeAndDataType("X", "Out");
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
|
|
|
|
|
ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
|
|
|
|
|
|
|
|
|
|
ctx.SetOutputType("Out", proto::VarType::SELECTED_ROWS, ALL_ELEMENTS);
|
|
|
|
|
ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR, 1);
|
|
|
|
|
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
|
|
|
|
|
ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(0, ctx.GetInputDataType("X"));
|
|
|
|
|
|
|
|
|
|
ctx.SetOutputDataType("Out", proto::VarType::FP32, ALL_ELEMENTS);
|
|
|
|
|
ctx.SetOutputDataType("Out", proto::VarType::INT8, 1);
|
|
|
|
|
ASSERT_EQ(proto::VarType::FP32,
|
|
|
|
|
prog.MutableBlock(0)->Var("test2_a_out")->GetDataType());
|
|
|
|
|
ASSERT_EQ(proto::VarType::INT8,
|
|
|
|
|
prog.MutableBlock(0)->Var("test2_b_out")->GetDataType());
|
|
|
|
|
|
|
|
|
|
ASSERT_FALSE(ctx.IsDygraph());
|
|
|
|
|
|
|
|
|
|
// test StaticGraphVarTypeInference
|
|
|
|
|
TestStaticGraphVarTypeInference infer;
|
|
|
|
|
ASSERT_TRUE(infer.HasVar(&ctx, "test2_a"));
|
|
|
|
|
ASSERT_EQ(infer.Input(&ctx, "X").size(), infer.Output(&ctx, "Out").size());
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(proto::VarType::FP32, infer.GetDataType(&ctx, "test2_a_out"));
|
|
|
|
|
infer.SetDataType(&ctx, "test2_a_out", proto::VarType::FP64);
|
|
|
|
|
ASSERT_EQ(proto::VarType::FP64, infer.GetDataType(&ctx, "test2_a_out"));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(proto::VarType::SELECTED_ROWS, infer.GetType(&ctx, "test2_a_out"));
|
|
|
|
|
infer.SetType(&ctx, "test2_a_out", proto::VarType::LOD_TENSOR);
|
|
|
|
|
ASSERT_EQ(proto::VarType::LOD_TENSOR, infer.GetType(&ctx, "test2_a_out"));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(infer.GetDataTypes(&ctx, "test2_a_out"));
|
|
|
|
|
ASSERT_ANY_THROW(infer.SetDataTypes(&ctx, "test2_a_out", {}));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(0u, infer.GetShape(&ctx, "test2_a_out").size());
|
|
|
|
|
infer.SetShape(&ctx, "test2_a_out", {
|
|
|
|
|
1, 3, 3,
|
|
|
|
|
});
|
|
|
|
|
ASSERT_EQ(3u, infer.GetShape(&ctx, "test2_a_out").size());
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(0, infer.GetLoDLevel(&ctx, "test2_a_out"));
|
|
|
|
|
infer.SetLoDLevel(&ctx, "test2_a_out", 2);
|
|
|
|
|
ASSERT_EQ(2, infer.GetLoDLevel(&ctx, "test2_a_out"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(InferVarType, test_enforce_check) {
|
|
|
|
|
InferVarTypeContext ctx(nullptr, nullptr);
|
|
|
|
|
ASSERT_ANY_THROW(ctx.HasInput("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.HasOutput("Out"));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.InputSize("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.InputVarName("X"));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.InputTypeAllOf("X", proto::VarType::LOD_TENSOR));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.SyncTypeAndDataType("X", "Out"));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.GetInputType("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.GetOutputType("Out"));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.GetInputDataType("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.SetOutputDataType("Out", proto::VarType::LOD_TENSOR));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.GetInputDataTypes("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.SetOutputDataTypes("Out", {}));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.GetInputShape("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.SetOutputShape("Out", {}));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.GetInputLoDLevel("X"));
|
|
|
|
|
ASSERT_ANY_THROW(ctx.SetOutputLoDLevel("Out", 1));
|
|
|
|
|
|
|
|
|
|
ASSERT_ANY_THROW(ctx.InsertVar("var", proto::VarType::LOD_TENSOR));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|