|
|
|
@ -36,7 +36,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() {
|
|
|
|
|
AddInput("X", "").AsDuplicable();
|
|
|
|
|
AddOutput("Out", "");
|
|
|
|
|
AddOutput("Out", "").AsDuplicable();
|
|
|
|
|
AddComment("");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -59,11 +59,27 @@ class SumOpVarTypeInference : public VarTypeInference {
|
|
|
|
|
block->Var(out_var_name)->SetType(default_var_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class DummyOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() {
|
|
|
|
|
AddInput("X", "").AsDuplicable();
|
|
|
|
|
AddOutput("Out", "").AsDuplicable();
|
|
|
|
|
AddComment("");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class DummyOpVarTypeInference : public VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
|
|
|
|
|
};
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker,
|
|
|
|
|
paddle::framework::SumOpVarTypeInference);
|
|
|
|
|
REGISTER_OPERATOR(dummy, paddle::framework::NOP, paddle::framework::SumOpMaker,
|
|
|
|
|
paddle::framework::SumOpVarTypeInference);
|
|
|
|
|
REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
|
|
|
|
|
paddle::framework::SumOpMaker);
|
|
|
|
|
|
|
|
|
@ -110,5 +126,83 @@ TEST(GraphTest, Basic) {
|
|
|
|
|
}
|
|
|
|
|
ASSERT_EQ(nodes.size(), 5);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(GraphTest, WriteAfterRead) {
|
|
|
|
|
// void Test() {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
auto *op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
op->SetType("sum");
|
|
|
|
|
op->SetInput("X", {"a"});
|
|
|
|
|
op->SetOutput("Out", {"b"});
|
|
|
|
|
op->SetAttr("op_role", 1);
|
|
|
|
|
|
|
|
|
|
op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
op->SetType("dummy");
|
|
|
|
|
op->SetInput("X", {"c"});
|
|
|
|
|
op->SetOutput("Out", {"a"});
|
|
|
|
|
op->SetAttr("op_role", 1);
|
|
|
|
|
|
|
|
|
|
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
|
|
|
|
|
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
|
|
|
|
|
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
|
|
|
|
|
ir::Node *control_dep1 = nullptr;
|
|
|
|
|
ir::Node *control_dep2 = nullptr;
|
|
|
|
|
for (ir::Node *n : g->Nodes()) {
|
|
|
|
|
if (n->Name() == "sum") {
|
|
|
|
|
ASSERT_EQ(n->outputs[0]->Name(), "b");
|
|
|
|
|
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
|
|
|
|
|
control_dep1 = n->outputs[1];
|
|
|
|
|
ASSERT_EQ(n->outputs.size(), 2);
|
|
|
|
|
}
|
|
|
|
|
if (n->Name() == "dummy") {
|
|
|
|
|
ASSERT_EQ(n->inputs[0]->Name(), "c");
|
|
|
|
|
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
|
|
|
|
|
control_dep2 = n->inputs[1];
|
|
|
|
|
ASSERT_EQ(n->inputs.size(), 2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ASSERT_EQ(control_dep1, control_dep2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(GraphTest, WriteAfterWrite) {
|
|
|
|
|
// void Test() {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
auto *op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
op->SetType("sum");
|
|
|
|
|
op->SetInput("X", {"a"});
|
|
|
|
|
op->SetOutput("Out", {"b"});
|
|
|
|
|
op->SetAttr("op_role", 1);
|
|
|
|
|
|
|
|
|
|
op = prog.MutableBlock(0)->AppendOp();
|
|
|
|
|
op->SetType("dummy");
|
|
|
|
|
op->SetInput("X", {"c"});
|
|
|
|
|
op->SetOutput("Out", {"b"});
|
|
|
|
|
op->SetAttr("op_role", 1);
|
|
|
|
|
|
|
|
|
|
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
|
|
|
|
|
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
|
|
|
|
|
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
|
|
|
|
|
ir::Node *control_dep1 = nullptr;
|
|
|
|
|
ir::Node *control_dep2 = nullptr;
|
|
|
|
|
for (ir::Node *n : g->Nodes()) {
|
|
|
|
|
if (n->Name() == "sum") {
|
|
|
|
|
ASSERT_EQ(n->outputs[0]->Name(), "b");
|
|
|
|
|
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
|
|
|
|
|
ASSERT_EQ(n->outputs.size(), 2);
|
|
|
|
|
control_dep1 = n->outputs[1];
|
|
|
|
|
}
|
|
|
|
|
if (n->Name() == "dummy") {
|
|
|
|
|
ASSERT_EQ(n->inputs[0]->Name(), "c");
|
|
|
|
|
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
|
|
|
|
|
control_dep2 = n->inputs[1];
|
|
|
|
|
ASSERT_EQ(n->inputs.size(), 2);
|
|
|
|
|
ASSERT_EQ(control_dep1, control_dep2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|