|
|
|
@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph);
|
|
|
|
|
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph,
|
|
|
|
|
graph_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
|
|
|
|
|
const std::set<CNodePtr> &all_nodes,
|
|
|
|
|
const std::map<AnfNodePtr, CNodePtr> ¶_to_written_node,
|
|
|
|
|
NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) {
|
|
|
|
|
std::vector<CNodePtr> exec_order = root_graph->execution_order();
|
|
|
|
|
while (parameter_count->HasValidElem()) {
|
|
|
|
|
auto [para, read, written] = parameter_count->GetOneValidElem();
|
|
|
|
@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
|
|
|
|
|
if (visit_source->isa<Parameter>()) {
|
|
|
|
|
parameter_count->AddReadCount(visit_source, read - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// replace parameter in node
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
for (size_t i = 0; i < node->size(); ++i) {
|
|
|
|
|
if (node->input(i) == para) {
|
|
|
|
@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// replace parameter in graph input
|
|
|
|
|
for (auto &g : graph_list) {
|
|
|
|
|
auto child_graph_inputs = g->MutableInputs();
|
|
|
|
|
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source);
|
|
|
|
|
MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph "
|
|
|
|
|
<< g->graph_id() << " inputs";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
root_graph->set_execution_order(exec_order);
|
|
|
|
|
}
|
|
|
|
|