|
|
|
@ -23,7 +23,7 @@ namespace ge {
|
|
|
|
|
constexpr uint32_t kZeroIndex = 0;
|
|
|
|
|
constexpr uint32_t kCaseInputBase = 1;
|
|
|
|
|
constexpr uint32_t kInvalidParent = 0x7fffffffU;
|
|
|
|
|
const char *const kMbatchNodeNameMark = "_ascend_mbatch_batch_";
|
|
|
|
|
const string kMbatchNodeNameMark = "_ascend_mbatch_batch_";
|
|
|
|
|
|
|
|
|
|
bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) {
|
|
|
|
|
if ((src_node == nullptr) && (dst_node == nullptr)) {
|
|
|
|
@ -164,11 +164,16 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
|
|
|
|
|
string node_full_name = peer_node->GetName();
|
|
|
|
|
size_t pos = node_full_name.find(kMbatchNodeNameMark);
|
|
|
|
|
if (pos == string::npos) {
|
|
|
|
|
GELOGE(FAILED, "Cannot find: %s of multi-batch in node: %s", kMbatchNodeNameMark, node_full_name.c_str());
|
|
|
|
|
GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
string fixed_name = node_full_name.substr(0, pos);
|
|
|
|
|
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length());
|
|
|
|
|
if (pos != string::npos) {
|
|
|
|
|
fixed_name += node_full_name.substr(pos);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -336,14 +341,19 @@ Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node,
|
|
|
|
|
/// @param [in] outputs: Parent index of Node output.
|
|
|
|
|
/// @return 0: SUCCESS / others: FAILED
|
|
|
|
|
///
|
|
|
|
|
Status SubgraphConstMigrationPass::DetachParallelNode(const map<string, NodePtr> &const_nodes,
|
|
|
|
|
Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph,
|
|
|
|
|
const map<string, NodePtr> &const_nodes,
|
|
|
|
|
const NodePtr &const_node, const NodePtr &data_node) {
|
|
|
|
|
// Break Data and Move node.
|
|
|
|
|
const auto &in_anchor = const_node->GetInControlAnchor();
|
|
|
|
|
const auto out_anchors = in_anchor->GetPeerOutControlAnchors();
|
|
|
|
|
for (const auto out_anchor : out_anchors) {
|
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed");
|
|
|
|
|
GELOGI("Remove Edge: %s %s", out_anchor->GetOwnerNode()->GetName().c_str(), const_node->GetName().c_str());
|
|
|
|
|
const auto owner_node = out_anchor->GetOwnerNode();
|
|
|
|
|
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str());
|
|
|
|
|
if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) {
|
|
|
|
|
graph->RemoveNode(owner_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &ctrl_anchor = const_node->GetOutControlAnchor();
|
|
|
|
@ -454,7 +464,7 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (DetachParallelNode(item.second, move_node, it_data->second) != SUCCESS) {
|
|
|
|
|
if (DetachParallelNode(subgraph, item.second, move_node, it_data->second) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index);
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|