|
|
|
@ -131,6 +131,22 @@ InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsSubgraphInputNode(const NodePtr &node) {
|
|
|
|
|
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != DATA) ||
|
|
|
|
|
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsSubgraphOutputNode(const NodePtr &node) {
|
|
|
|
|
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != NETOUTPUT) ||
|
|
|
|
|
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
|
|
|
|
|
if (src_node.GetOpDesc() == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -377,7 +393,7 @@ bool CheckIdentityIsNearSubgraph(const Node &node) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// near entrance of subgraph
|
|
|
|
|
if (in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) {
|
|
|
|
|
if (IsSubgraphInputNode(in_node)) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
// near subgraph
|
|
|
|
@ -392,7 +408,7 @@ bool CheckIdentityIsNearSubgraph(const Node &node) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// near output of subgraph
|
|
|
|
|
if (out_node->GetType() == NETOUTPUT && NodeUtils::IsSubgraphOutput(out_node)) {
|
|
|
|
|
if (IsSubgraphOutputNode(out_node)) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
// near subgraph
|
|
|
|
|