Fix const key

pull/899/head
zhangxiaokun 4 years ago
parent 464aff111e
commit 8c256bda4d

@ -23,7 +23,7 @@ namespace ge {
constexpr uint32_t kZeroIndex = 0;
constexpr uint32_t kCaseInputBase = 1;
constexpr uint32_t kInvalidParent = 0x7fffffffU;
const char *const kMbatchNodeNameMark = "_ascend_mbatch_batch_";
const string kMbatchNodeNameMark = "_ascend_mbatch_batch_";
bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) {
if ((src_node == nullptr) && (dst_node == nullptr)) {
@ -164,11 +164,16 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
string node_full_name = peer_node->GetName();
size_t pos = node_full_name.find(kMbatchNodeNameMark);
if (pos == string::npos) {
GELOGE(FAILED, "Cannot find: %s of multi-batch in node: %s", kMbatchNodeNameMark, node_full_name.c_str());
GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
return FAILED;
}
string fixed_name = node_full_name.substr(0, pos);
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length());
if (pos != string::npos) {
fixed_name += node_full_name.substr(pos);
}
peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx()));
}
@ -336,14 +341,19 @@ Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node,
/// @param [in] outputs: Parent index of Node output.
/// @return 0: SUCCESS / others: FAILED
///
Status SubgraphConstMigrationPass::DetachParallelNode(const map<string, NodePtr> &const_nodes,
Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph,
const map<string, NodePtr> &const_nodes,
const NodePtr &const_node, const NodePtr &data_node) {
// Break Data and Move node.
const auto &in_anchor = const_node->GetInControlAnchor();
const auto out_anchors = in_anchor->GetPeerOutControlAnchors();
for (const auto out_anchor : out_anchors) {
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed");
GELOGI("Remove Edge: %s %s", out_anchor->GetOwnerNode()->GetName().c_str(), const_node->GetName().c_str());
const auto owner_node = out_anchor->GetOwnerNode();
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str());
if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) {
graph->RemoveNode(owner_node);
}
}
const auto &ctrl_anchor = const_node->GetOutControlAnchor();
@ -454,7 +464,7 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph
return FAILED;
}
if (DetachParallelNode(item.second, move_node, it_data->second) != SUCCESS) {
if (DetachParallelNode(subgraph, item.second, move_node, it_data->second) != SUCCESS) {
GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index);
return FAILED;
}

@ -119,8 +119,8 @@ class SubgraphConstMigrationPass : public GraphPass {
/// @param [in] outputs: Parent index of Node output.
/// @return 0: SUCCESS / others: FAILED
///
Status DetachParallelNode(const map<string, NodePtr> &const_nodes, const NodePtr &const_node,
const NodePtr &data_node);
Status DetachParallelNode(const ComputeGraphPtr &graph, const map<string, NodePtr> &const_nodes,
const NodePtr &const_node, const NodePtr &data_node);
///
/// @ingroup ge

Loading…
Cancel
Save