fix assign-node-wipe bug

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/3125/head
zhoufeng 5 years ago
parent 8300802b95
commit 3d4d434fac

@ -18,9 +18,12 @@
#include <utility> #include <utility>
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
#include <string>
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "utils/union_find_set.h" #include "utils/union_find_set.h"
#include "runtime/device/ascend/ascend_label_assign.h" #include "runtime/device/ascend/ascend_label_assign.h"
#include "utils/context/ms_context.h"
#include "debug/anf_ir_dump.h"
static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1; static constexpr size_t kCNodeCallArg = 1;
@ -248,10 +251,14 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
} }
MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); MS_LOG(INFO) << "Erase " << assign_node->DebugString(5);
EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order));
auto source = assign_node->input(kCNodeAssignSource);
auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first; MS_EXCEPTION_IF_NULL(source);
parameter_count.AddReadCount(source, -1); auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first;
parameter_count.AddWriteCount(para, -1); parameter_count.AddWriteCount(para, -1);
parameter_count.AddReadCount(para, -1);
if (visit_source->isa<Parameter>()) {
parameter_count.AddReadCount(visit_source, read - 1);
}
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
for (size_t i = 0; i < node->size(); ++i) { for (size_t i = 0; i < node->size(); ++i) {
if (node->input(i) == para) { if (node->input(i) == para) {
@ -260,8 +267,6 @@ void AscendControlParser::EraseParameter(NotNull<KernelGraphPtr> root_graph,
} }
} }
} }
parameter_count.AddReadCount(source, 1);
parameter_count.AddReadCount(para, -1);
} }
root_graph->set_execution_order(exec_order); root_graph->set_execution_order(exec_order);
} }
@ -318,6 +323,17 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
(void)RecurseGraph(root_graph, NOT_NULL(&memo)); (void)RecurseGraph(root_graph, NOT_NULL(&memo));
EraseParameter(root_graph, memo); EraseParameter(root_graph, memo);
EraseLabel(root_graph); EraseLabel(root_graph);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
if (context_ptr->save_graphs_flag()) {
std::string file_path = save_graphs_path + "/after_erase_label_and_parameter.ir";
DumpIR(file_path, root_graph.get());
}
} }
std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode( std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode(

Loading…
Cancel
Save