!4030 replace unused parameter in graph inputs

Merge pull request !4030 from laiyongqiang/replace_parameter
pull/4030/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c55c0e0f0c

@ -261,13 +261,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
}
}
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph);
EraseAssign(std::make_shared<ReferenceCounter>(parameter_count), all_nodes, para_to_written_node, root_graph,
graph_list);
}
void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count,
const std::set<CNodePtr> &all_nodes,
const std::map<AnfNodePtr, CNodePtr> &para_to_written_node,
NotNull<KernelGraphPtr> root_graph) {
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list) {
std::vector<CNodePtr> exec_order = root_graph->execution_order();
while (parameter_count->HasValidElem()) {
auto [para, read, written] = parameter_count->GetOneValidElem();
@ -292,6 +293,8 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
if (visit_source->isa<Parameter>()) {
parameter_count->AddReadCount(visit_source, read - 1);
}
// replace parameter in node
for (auto &node : all_nodes) {
for (size_t i = 0; i < node->size(); ++i) {
if (node->input(i) == para) {
@ -300,6 +303,14 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete
}
}
}
// replace parameter in graph input
for (auto &g : graph_list) {
auto child_graph_inputs = g->MutableInputs();
std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), para, source);
MS_LOG_INFO << "Replace parameter " << para->DebugString() << " by " << source->DebugString() << " in graph "
<< g->graph_id() << " inputs";
}
}
root_graph->set_execution_order(exec_order);
}

@ -47,7 +47,7 @@ class AscendControlParser {
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseAssign(std::shared_ptr<ReferenceCounter> parameter_count, const std::set<CNodePtr> &all_nodes,
const std::map<AnfNodePtr, CNodePtr> &para_to_written_node,
NotNull<KernelGraphPtr> root_graph);
NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,

@ -153,9 +153,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// add make_tuple to the output graph
@ -178,7 +175,10 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
debugger_->PreExecute(root_graph);
}
SetSummaryNodes(root_graph.get());
// alloc mem
// Alloc memory for child graph's inputs
AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo));
memo.clear();
// Alloc memory for root graph's inputs and node's outputs, workspace
MemoryAlloc(root_graph.get());
// generate and load task into device
Load(root_graph);

@ -337,6 +337,8 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) {
MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
}
MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope()
<< " index: " << index << " size: " << tensor_size;
AnfAlgo::SetOutputAddr(address, index, item.get());
}
}

Loading…
Cancel
Save