Fix const key

pull/899/head
zhangxiaokun 5 years ago
parent 464aff111e
commit 8c256bda4d

@ -23,7 +23,7 @@ namespace ge {
constexpr uint32_t kZeroIndex = 0; constexpr uint32_t kZeroIndex = 0;
constexpr uint32_t kCaseInputBase = 1; constexpr uint32_t kCaseInputBase = 1;
constexpr uint32_t kInvalidParent = 0x7fffffffU; constexpr uint32_t kInvalidParent = 0x7fffffffU;
const char *const kMbatchNodeNameMark = "_ascend_mbatch_batch_"; const string kMbatchNodeNameMark = "_ascend_mbatch_batch_";
bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) { bool IsSameConstNode(const NodePtr &src_node, const NodePtr &dst_node) {
if ((src_node == nullptr) && (dst_node == nullptr)) { if ((src_node == nullptr) && (dst_node == nullptr)) {
@ -164,11 +164,16 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
string node_full_name = peer_node->GetName(); string node_full_name = peer_node->GetName();
size_t pos = node_full_name.find(kMbatchNodeNameMark); size_t pos = node_full_name.find(kMbatchNodeNameMark);
if (pos == string::npos) { if (pos == string::npos) {
GELOGE(FAILED, "Cannot find: %s of multi-batch in node: %s", kMbatchNodeNameMark, node_full_name.c_str()); GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
return FAILED; return FAILED;
} }
string fixed_name = node_full_name.substr(0, pos); string fixed_name = node_full_name.substr(0, pos);
pos = node_full_name.find("_", pos + kMbatchNodeNameMark.length());
if (pos != string::npos) {
fixed_name += node_full_name.substr(pos);
}
peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx())); peer_name_list.insert(fixed_name + ":" + std::to_string(in_anchor->GetIdx()));
} }
@ -336,14 +341,19 @@ Status SubgraphConstMigrationPass::AppendParallelNode(const NodePtr &func_node,
/// @param [in] outputs: Parent index of Node output. /// @param [in] outputs: Parent index of Node output.
/// @return 0: SUCCESS / others: FAILED /// @return 0: SUCCESS / others: FAILED
/// ///
Status SubgraphConstMigrationPass::DetachParallelNode(const map<string, NodePtr> &const_nodes, Status SubgraphConstMigrationPass::DetachParallelNode(const ComputeGraphPtr &graph,
const map<string, NodePtr> &const_nodes,
const NodePtr &const_node, const NodePtr &data_node) { const NodePtr &const_node, const NodePtr &data_node) {
// Break Data and Move node. // Break Data and Move node.
const auto &in_anchor = const_node->GetInControlAnchor(); const auto &in_anchor = const_node->GetInControlAnchor();
const auto out_anchors = in_anchor->GetPeerOutControlAnchors(); const auto out_anchors = in_anchor->GetPeerOutControlAnchors();
for (const auto out_anchor : out_anchors) { for (const auto out_anchor : out_anchors) {
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed"); GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, in_anchor), "Remove edge failed");
GELOGI("Remove Edge: %s %s", out_anchor->GetOwnerNode()->GetName().c_str(), const_node->GetName().c_str()); const auto owner_node = out_anchor->GetOwnerNode();
GELOGI("Remove Edge: %s %s", owner_node->GetName().c_str(), const_node->GetName().c_str());
if (owner_node->GetInAllNodes().empty() && owner_node->GetOutAllNodes().empty()) {
graph->RemoveNode(owner_node);
}
} }
const auto &ctrl_anchor = const_node->GetOutControlAnchor(); const auto &ctrl_anchor = const_node->GetOutControlAnchor();
@ -454,7 +464,7 @@ Status SubgraphConstMigrationPass::MoveNodeToParent(const ComputeGraphPtr &graph
return FAILED; return FAILED;
} }
if (DetachParallelNode(item.second, move_node, it_data->second) != SUCCESS) { if (DetachParallelNode(subgraph, item.second, move_node, it_data->second) != SUCCESS) {
GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index); GELOGE(FAILED, "Data: %s not found, index: %u", move_node->GetName().c_str(), parent_index);
return FAILED; return FAILED;
} }

@ -119,8 +119,8 @@ class SubgraphConstMigrationPass : public GraphPass {
/// @param [in] outputs: Parent index of Node output. /// @param [in] outputs: Parent index of Node output.
/// @return 0: SUCCESS / others: FAILED /// @return 0: SUCCESS / others: FAILED
/// ///
Status DetachParallelNode(const map<string, NodePtr> &const_nodes, const NodePtr &const_node, Status DetachParallelNode(const ComputeGraphPtr &graph, const map<string, NodePtr> &const_nodes,
const NodePtr &data_node); const NodePtr &const_node, const NodePtr &data_node);
/// ///
/// @ingroup ge /// @ingroup ge

Loading…
Cancel
Save