|
|
@ -48,18 +48,49 @@ public:
|
|
|
|
return node;
|
|
|
|
return node;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int CountOfAtomicCleanNode() {
|
|
|
|
|
|
|
|
int node_num = 0;
|
|
|
|
|
|
|
|
for (NodePtr &node : graph_->GetDirectNode()) {
|
|
|
|
|
|
|
|
if (node->GetType() == ATOMICADDRCLEAN) {
|
|
|
|
|
|
|
|
++node_num;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return node_num;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ComputeGraphPtr graph_;
|
|
|
|
ComputeGraphPtr graph_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// node1 -> node2 -> node3
|
|
|
|
/*
|
|
|
|
|
|
|
|
* Data Data Atomic_clean
|
|
|
|
|
|
|
|
* | | / |
|
|
|
|
|
|
|
|
* relu relu |
|
|
|
|
|
|
|
|
* | ==> | |
|
|
|
|
|
|
|
|
* relu(atomic) relu(atomic)
|
|
|
|
|
|
|
|
* | |
|
|
|
|
|
|
|
|
* netoutput netoutput
|
|
|
|
|
|
|
|
*/
|
|
|
|
TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) {
|
|
|
|
TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) {
|
|
|
|
auto node1 = NewNode("node1", DATA, 0, 1);
|
|
|
|
auto node1 = NewNode("node1", DATA, 0, 1);
|
|
|
|
|
|
|
|
|
|
|
|
auto node2 = NewNode("node2", RELU, 1, 1);
|
|
|
|
auto node2 = NewNode("node2", RELU, 1, 1);
|
|
|
|
auto node3 = NewNode("node3", NETOUTPUT, 1, 0);
|
|
|
|
auto node3 = NewNode("node3", RELU, 1, 1);
|
|
|
|
|
|
|
|
auto op_desc = node3->GetOpDesc();
|
|
|
|
|
|
|
|
vector<int64_t> atomic_input_index = {123, 456};
|
|
|
|
|
|
|
|
AttrUtils::SetListInt(op_desc, "atomic_input_index", atomic_input_index);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto node4 = NewNode("node4", NETOUTPUT, 1, 0);
|
|
|
|
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
|
|
|
|
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
|
|
|
|
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
|
|
|
|
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
|
|
|
|
|
|
|
|
GraphUtils::AddEdge(node3->GetOutDataAnchor(0), node4->GetInDataAnchor(0));
|
|
|
|
AtomicAddrCleanPass atomi_addr_clean_pass;
|
|
|
|
AtomicAddrCleanPass atomi_addr_clean_pass;
|
|
|
|
Status ret = atomi_addr_clean_pass.Run(graph_);
|
|
|
|
Status ret = atomi_addr_clean_pass.Run(graph_);
|
|
|
|
EXPECT_EQ(ret, SUCCESS);
|
|
|
|
EXPECT_EQ(ret, SUCCESS);
|
|
|
|
|
|
|
|
EXPECT_EQ(1, CountOfAtomicCleanNode());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto atomic_clean = graph_->FindNode("atomic_addr_clean");
|
|
|
|
|
|
|
|
EXPECT_NE(atomic_clean, nullptr);
|
|
|
|
|
|
|
|
auto out_ctrl_nodes = atomic_clean->GetOutControlNodes();
|
|
|
|
|
|
|
|
EXPECT_EQ(out_ctrl_nodes.size(), 2);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // namespace ge
|
|
|
|
} // namespace ge
|
|
|
|