!12800 Optimize the compile performance in Parser, FG, Manager and Renormalize.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
pull/12800/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9b527cc9dd

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -610,24 +610,7 @@ void AnfExporter::OutputOrderList(std::ofstream &ofs, const FuncGraphPtr &func_g
constexpr int width = 4; constexpr int width = 4;
ofs << "# order:\n"; ofs << "# order:\n";
int i = 1; int i = 1;
auto &isolate_nodes = func_graph->isolate_nodes();
for (auto &node : order_list) { for (auto &node : order_list) {
bool is_isolate = (isolate_nodes.find(node) != isolate_nodes.end());
const std::string isolate_str = (is_isolate ? " # isolate" : "");
ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << isolate_str << '\n';
++i;
}
}
void AnfExporter::OutputIsolateNodes(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
auto &isolate_nodes = func_graph->isolate_nodes();
if (isolate_nodes.empty()) {
return;
}
constexpr int width = 4;
ofs << "# isolate nodes:\n";
int i = 1;
for (auto &node : isolate_nodes) {
ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << '\n'; ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << '\n';
++i; ++i;
} }
@ -670,7 +653,6 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
ofs << "}\n"; ofs << "}\n";
OutputOrderList(ofs, func_graph); OutputOrderList(ofs, func_graph);
OutputIsolateNodes(ofs, func_graph);
} }
void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -98,7 +98,6 @@ class AnfExporter {
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
virtual void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph); virtual void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
void OutputOrderList(std::ofstream &ofs, const FuncGraphPtr &func_graph); void OutputOrderList(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputIsolateNodes(std::ofstream &ofs, const FuncGraphPtr &func_graph);
int param_index; int param_index;
OrderedSet<FuncGraphPtr> func_graph_set{}; OrderedSet<FuncGraphPtr> func_graph_set{};

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -36,7 +36,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
// namespace to support debug trace infomation // namespace to support debug trace information
namespace trace { namespace trace {
using abstract::AbstractBasePtr; using abstract::AbstractBasePtr;
using abstract::AnalysisContextPtr; using abstract::AnalysisContextPtr;
@ -167,7 +167,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(engine_); MS_EXCEPTION_IF_NULL(engine_);
auto cfg = engine_->MakeConfig(node, cur_ctx_); auto cfg = engine_->MakeConfig(node, cur_ctx_);
auto ret = engine_->cache().GetValue(cfg); auto ret = engine_->analysis_cache().GetValue(cfg);
if (ret == nullptr) { if (ret == nullptr) {
return "Undefined"; return "Undefined";
} }
@ -180,7 +180,7 @@ AbstractBasePtr AnalyzedFuncGraphExporter::GetNodeAbstract(const AnfNodePtr &nod
} }
MS_EXCEPTION_IF_NULL(engine_); MS_EXCEPTION_IF_NULL(engine_);
auto cfg = engine_->MakeConfig(node, cur_ctx_); auto cfg = engine_->MakeConfig(node, cur_ctx_);
auto ret = engine_->cache().GetValue(cfg); auto ret = engine_->analysis_cache().GetValue(cfg);
return ret == nullptr ? nullptr : ret->abstract(); return ret == nullptr ? nullptr : ret->abstract();
} }
@ -439,7 +439,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
param_index = 1; param_index = 1;
auto tagged_func_graphs = CalcTaggedFuncGraphs(); auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output graph on the analysis stack // 1. Output graph on the analysis stack
for (const auto &node_cfg : node_cfgs) { for (const auto &node_cfg : node_cfgs) {
auto ctx = node_cfg->context(); auto ctx = node_cfg->context();
if (engine_ == nullptr) { if (engine_ == nullptr) {
@ -448,7 +448,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
if (context_map_.insert({ctx, false}).second) { if (context_map_.insert({ctx, false}).second) {
context_vec_.push_back(ctx); context_vec_.push_back(ctx);
} }
// the graph has already been printed // If the graph has already been printed
if (context_map_[ctx]) { if (context_map_[ctx]) {
continue; continue;
} }
@ -456,7 +456,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
auto fg = ctx->func_graph(); auto fg = ctx->func_graph();
// set current context // Set current context
cur_ctx_ = ctx; cur_ctx_ = ctx;
tagged_cnodes_ = tagged_func_graphs[fg]; tagged_cnodes_ = tagged_func_graphs[fg];
ExportOneFuncGraph(ofs, fg); ExportOneFuncGraph(ofs, fg);
@ -465,10 +465,10 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
tagged_cnodes_.clear(); tagged_cnodes_.clear();
// print seperator between function graphs on analyzed graph call stack and others // Print separator between function graphs on analyzed graph call stack and others
ofs << "#===============================================================================\n\n\n"; ofs << "#===============================================================================\n\n\n";
// second output other graphs // 2. Output other graphs
size_t ctx_idx = 0; size_t ctx_idx = 0;
while (ctx_idx < context_vec_.size()) { while (ctx_idx < context_vec_.size()) {
auto ctx = context_vec_[ctx_idx++]; auto ctx = context_vec_[ctx_idx++];

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -238,27 +238,6 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
return changes; return changes;
} }
bool SubstitutionList::ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const {
const auto &manager = optimizer->manager();
const auto &nodes = manager->isolate_nodes();
bool changes = false;
bool loop = true;
while (loop) {
loop = false;
std::for_each(list_.cbegin(), list_.cend(), [&](const auto &substitution) {
std::for_each(nodes.cbegin(), nodes.cend(), [&](const auto &node) {
bool change = ApplySubstitutionToIR(optimizer, node, substitution);
changes = changes || change;
loop = loop || change;
});
});
if (is_once_) {
break;
}
}
return changes;
}
bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
// Add for substitution status counting // Add for substitution status counting
size_t space = 0; size_t space = 0;
@ -336,18 +315,6 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize
} else { } else {
changes = ApplySubstitutionsToIR(optimizer, func_graph); changes = ApplySubstitutionsToIR(optimizer, func_graph);
} }
bool has_isolate = !manager->isolate_nodes().empty();
if (has_isolate) {
#ifdef ENABLE_PROFILE
double t = GetTime();
#endif
bool change = ApplySubstitutionsToIRForIsolate(optimizer);
changes = changes || change;
#ifdef ENABLE_PROFILE
MsProfile::StatTime("opt.isolate.transform." + optimizer->name(), GetTime() - t);
#endif
}
return changes; return changes;
} }
} // namespace opt } // namespace opt

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -73,7 +73,7 @@ class SubstitutionList {
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const; bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const;
std::vector<SubstitutionPtr> list_; std::vector<SubstitutionPtr> list_;
// a flag to mark this list of Substitution can only be executed only once // a flag to mark this list of Substitution can only be executed only once
bool is_once_; bool is_once_;

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -163,7 +163,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
auto &graphs = it.second; auto &graphs = it.second;
MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size(); MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
auto fg = graphs[0]; auto fg = graphs[0];
FuncGraphPtrList func_graphs = {fg}; FuncGraphVector func_graphs = {fg};
ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(), ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
std::make_shared<TraceCombileLikeGraphs>()); std::make_shared<TraceCombileLikeGraphs>());
cloner->Run(); cloner->Run();

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -37,7 +37,7 @@ FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
static bool CanBeIsolateNode(const std::string &var_name, const AnfNodePtr &node) { static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node); auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->inputs().empty()) { if (cnode == nullptr || cnode->inputs().empty()) {
// Not a valid cnode, can not be isolate node. // Not a valid cnode, can not be isolate node.
@ -46,7 +46,7 @@ static bool CanBeIsolateNode(const std::string &var_name, const AnfNodePtr &node
auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0)); auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
if (prim == nullptr) { if (prim == nullptr) {
// Not a primitive cnode, it may have side effects or not, // Not a primitive cnode, it may have side effects or not,
// we add it as an isolate node if its name is not '_' or empty. // We add it as an isolate node if its name is not '_' or empty.
// this means that code like: // this means that code like:
// _ = func_call() // _ = func_call()
// will be ignored even if func_call() has side effects. // will be ignored even if func_call() has side effects.
@ -58,7 +58,7 @@ static bool CanBeIsolateNode(const std::string &var_name, const AnfNodePtr &node
return has_effects; return has_effects;
} }
// write variable records the variable name to corresponding node // Write variable records the variable name to corresponding node
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString();
auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false)); auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false));
@ -67,18 +67,24 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
// add it as an isolate node. for example: // add it as an isolate node. for example:
// a = print(x) // a = print(x)
// a = print(y) // a = print(y)
// when we write variable 'a = print(y)', // When we write variable 'a = print(y)',
// the cnode 'print(x)' should added as an isolate node. // the cnode 'print(x)' should added as an isolate node.
if (!iter->second.second && CanBeIsolateNode(var_name, iter->second.first)) { auto is_used = iter->second.second;
func_graph_->AddIsolateNode(iter->second.first); auto hidden_node = iter->second.first;
auto is_isolated = CanBeIsolatedNode(var_name, hidden_node);
MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << hidden_node->DebugString(2) << " is hidden by "
<< node->DebugString(2) << " with the same name, var_name: " << var_name
<< ", is_isolated: " << is_isolated << ", !is_used: " << !is_used;
if (!is_used && is_isolated) {
AddIsolatedNode(hidden_node);
} }
iter->second = std::make_pair(node, false); iter->second = std::make_pair(node, false);
} }
} }
// read variable from predecessors // Read variable from predecessors
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
// get var node if it is found // Get var node if it is found
auto found = vars_.find(var); auto found = vars_.find(var);
if (found != vars_.end()) { if (found != vars_.end()) {
auto &node = found->second.first; auto &node = found->second.first;
@ -91,7 +97,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
} }
return node; return node;
} }
// get var from predecessor block ,if can't get the make a resolve node to it // Get var from predecessor block ,if can't get the make a resolve node to it
if (matured_) { if (matured_) {
// If only one predecessor block, read the definition of var from it. // If only one predecessor block, read the definition of var from it.
if (prev_blocks_.size() == 1) { if (prev_blocks_.size() == 1) {
@ -99,7 +105,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
return block->ReadVariable(var); return block->ReadVariable(var);
} else if (prev_blocks_.empty()) { } else if (prev_blocks_.empty()) {
// get namespace and make Resolve // Get namespace and make Resolve
auto it = var_to_resolve_.find(var); auto it = var_to_resolve_.find(var);
if (it != var_to_resolve_.end()) { if (it != var_to_resolve_.end()) {
return it->second; return it->second;
@ -181,7 +187,7 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb
return node; return node;
} }
// add input for the block's phi parameter // Add input for the block's phi parameter
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
std::string var = phi_nodes_[phi]; std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
@ -227,7 +233,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
} }
// Check if there is removable unnecessary phi node in this graph. // Check if there is removable unnecessary phi node in this graph.
// as per the FIRM TR 3.2, a phi node can be remove if: // As per the FIRM TR 3.2, a phi node can be remove if:
// <Quote> // <Quote>
// If all arguments of a φ-function are the same value s or the φfunction itself, // If all arguments of a φ-function are the same value s or the φfunction itself,
// then we remove the φ-function and let all users directly uses. We call such a // then we remove the φ-function and let all users directly uses. We call such a
@ -255,7 +261,7 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
if (arg_node != nullptr) { if (arg_node != nullptr) {
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with " MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with "
<< arg_node->DebugString(); << arg_node->DebugString();
// replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." // Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
WriteVariable(var, arg_node); WriteVariable(var, arg_node);
removable_phis_[phi] = arg_node; removable_phis_[phi] = arg_node;
resolve_to_removable_phis_[arg_node] = phi; resolve_to_removable_phis_[arg_node] = phi;
@ -326,6 +332,8 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node)
jumps_[target_block.get()] = jump; jumps_[target_block.get()] = jump;
target_block->AddPrevBlock(shared_from_this()); target_block->AddPrevBlock(shared_from_this());
func_graph()->set_output(jump); func_graph()->set_output(jump);
// Attach all isolated nodes.
AttachIsolatedNodesBeforeReturn();
} }
// Perform a conditional jump using switch operation. // Perform a conditional jump using switch operation.
@ -341,6 +349,8 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr
NewValueNode(false_block->func_graph())}); NewValueNode(false_block->func_graph())});
CNodePtr switch_app_new = func_graph()->NewCNodeInOrder({switch_app}); CNodePtr switch_app_new = func_graph()->NewCNodeInOrder({switch_app});
func_graph()->set_output(switch_app_new); func_graph()->set_output(switch_app_new);
// Attach all isolated nodes.
AttachIsolatedNodesBeforeReturn();
} }
// Create cnode for the assign statement like 'self.target = source'. // Create cnode for the assign statement like 'self.target = source'.
@ -349,11 +359,12 @@ void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &s
const std::string primitive_name("assign"); const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional"); const std::string module_name("mindspore.ops.functional");
ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true));
auto assign = func_graph_->NewCNodeInOrder({assign_op, target, source}); auto assign_node = func_graph_->NewCNodeInOrder({assign_op, target, source});
func_graph_->AddIsolateNode(assign); MS_LOG(DEBUG) << "Isolated node found(Assign), assign_node: " << assign_node->DebugString(2);
AddIsolatedNode(assign_node);
} }
void FunctionBlock::FindIsolateVariables() { void FunctionBlock::FindIsolatedNodes() {
// //
// Search isolate nodes from variables, for example, // Search isolate nodes from variables, for example,
// variable 'a' is an isolate node in below code: // variable 'a' is an isolate node in below code:
@ -374,7 +385,7 @@ void FunctionBlock::FindIsolateVariables() {
used.emplace(node); used.emplace(node);
} }
} }
// Add isolate nodes which is unused var but not found in used set. // Add isolated nodes which is unused var but not found in used set.
for (const auto &var : vars_) { for (const auto &var : vars_) {
auto &node = var.second.first; auto &node = var.second.first;
bool is_used = var.second.second; bool is_used = var.second.second;
@ -382,11 +393,52 @@ void FunctionBlock::FindIsolateVariables() {
continue; continue;
} }
auto &var_name = var.first; auto &var_name = var.first;
if (used.find(node) == used.end() && CanBeIsolateNode(var_name, node)) { if (used.find(node) == used.end() && CanBeIsolatedNode(var_name, node)) {
func_graph_->AddIsolateNode(node); // We don't call AddIsolatedNode(node) anymore.
// If need, to call FindIsolatedNodes() in appropriate place.
MS_LOG(ERROR) << "Isolated node found(NoUse), node: " << node->DebugString(2) << ", var_name: " << var_name;
} }
} }
} }
void FunctionBlock::AddIsolatedNode(const AnfNodePtr &target) { isolated_nodes_.add(target); }
void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
if (isolated_nodes_.size() == 0) {
return;
}
std::vector<AnfNodePtr> states;
states.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &node : isolated_nodes_) {
MS_LOG(DEBUG) << "Adding dependency, node: " << node->DebugString(2) << " in " << func_graph()->ToString();
states.emplace_back(node);
}
AnfNodePtr state = nullptr;
// If there are only make_tuple and another node in states(the states size is 2),
// do not need to make_tuple, just use the node.
if (states.size() == 2) {
state = states[1];
} else {
state = func_graph()->NewCNode(states);
}
AnfNodePtr old_output = nullptr;
auto return_node = func_graph()->get_return();
if (return_node) {
if (return_node->inputs().size() < 1) {
MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
}
old_output = return_node->input(1);
} else {
old_output = NewValueNode(kNone);
}
AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
AnfNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(2);
func_graph()->set_output(depend_node, true);
}
} // namespace parse } // namespace parse
} // namespace mindspore } // namespace mindspore

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,7 +28,7 @@
#include <utility> #include <utility>
#include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/parse/parse_base.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/ordered_map.h" #include "utils/ordered_set.h"
namespace mindspore { namespace mindspore {
namespace parse { namespace parse {
@ -71,46 +71,51 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
AnfNodePtr MakeResolveOperation(const std::string &value); AnfNodePtr MakeResolveOperation(const std::string &value);
AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol); AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; } const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }
void FindIsolateVariables(); void FindIsolatedNodes();
void AddIsolatedNode(const AnfNodePtr &target);
void AttachIsolatedNodesBeforeReturn();
private: private:
// block graph // Block graph
FuncGraphPtr func_graph_; FuncGraphPtr func_graph_;
// the block's parser // Block parser
const Parser &parser_; const Parser &parser_;
// A block is matured if all its prev_blocks is processed // A block is matured if all its prev_blocks is processed
bool matured_; bool matured_;
// store the nest-level block // Store the nest-level block.
// refer to comments in Parser::func_block_list_; // Refer to comments in Parser::func_block_list_;
std::vector<FunctionBlock *> prev_blocks_; std::vector<FunctionBlock *> prev_blocks_;
// store args and variable's node, use a bool flag to indicate if the variable is used. // Store args and variable's node, use a bool flag to indicate if the variable is used.
std::map<std::string, std::pair<AnfNodePtr, bool>> vars_; std::map<std::string, std::pair<AnfNodePtr, bool>> vars_;
// phi_nodes map the parameter node to variable, it can be resolved if the block's predecessors are processed // Map the parameter node to variable, it can be resolved if the block's predecessors are processed
std::map<ParameterPtr, std::string> phi_nodes_; std::map<ParameterPtr, std::string> phi_nodes_;
// jumps map the successor block and the function call that perform jump // Jumps map the successor block and the function call that perform jump
// refer to comments in Parser::func_block_list_ that how to break the cyclic reference // Refer to comments in Parser::func_block_list_ that how to break the cyclic reference
std::map<FunctionBlock *, CNodePtr> jumps_; std::map<FunctionBlock *, CNodePtr> jumps_;
// keeps all removable phis which will be removed in one pass. // Keep all removable phis which will be removed in one pass.
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
// Keeps the map for the resolve node to the removable phi node. // Keep the map for the resolve node to the removable phi node.
// For the case that ReadVariable returns a phi node although this phi node // For the case that ReadVariable returns a phi node although this phi node
// generated in the prev block is identified as removable. The other blocks // generated in the prev block is identified as removable. The other blocks
// should find this phi node. // should find this phi node.
std::unordered_map<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_; std::unordered_map<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_;
// hold declared global variables in function // Hold declared global variables in function
std::set<std::string> global_vars_; std::set<std::string> global_vars_;
// keeps the new made resolve symbol for the variable not found in vars_. // Keep new made resolve symbol for the variable not found in vars_.
std::unordered_map<std::string, AnfNodePtr> var_to_resolve_; std::unordered_map<std::string, AnfNodePtr> var_to_resolve_;
// Isolated nodes.
OrderedSet<AnfNodePtr> isolated_nodes_;
}; };
} // namespace parse } // namespace parse

File diff suppressed because it is too large Load Diff

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -197,7 +197,7 @@ AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const F
return cnode; return cnode;
} }
// transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node // Transform the ValueTuple or ValueList of graph/primitive node to make tuple of const graph/primitive node
bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph,
const ValueNodePtr &value_node, AnfNodePtr *const transformed) { const ValueNodePtr &value_node, AnfNodePtr *const transformed) {
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
@ -208,18 +208,18 @@ bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const Func
// (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it, // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it,
// So if has graph in list, try to replace the node with make tuple of graph value node. // So if has graph in list, try to replace the node with make tuple of graph value node.
// we do this because the graph manager won't investigate the graph inside valuetuple, // We do this because the graph manager won't investigate the graph inside valuetuple,
// change the vector of graph to be make_tuple of graph value node. // change the vector of graph to be make_tuple of graph value node.
// (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all // (2) the primitive valuetuple or valuelist may encounter to abstract error, make it all
// independent nodes. // independent nodes.
auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec);
// replace the ret ptr to be make tuple of graph value node // Replace the ret ptr to be make tuple of graph value node
*transformed = node_tuple_graphs; *transformed = node_tuple_graphs;
return true; return true;
} }
// resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager // Resolve the python obj, and if the resovled node is valuenode with graphs, add the graphs to manager.
AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj, AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, const py::object &obj,
const AnfNodePtr &node) { const AnfNodePtr &node) {
ScopeGuard scope_guard(node->scope()); ScopeGuard scope_guard(node->scope());
@ -233,7 +233,7 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
manager->AddFuncGraph(new_fg); manager->AddFuncGraph(new_fg);
} }
// if the constant node is constant of vector of graph ,add graph to manager // If the constant node is constant of vector of graph, add graph to manager.
if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) { if (IsValueNode<ValueTuple>(resolved_node) || IsValueNode<ValueList>(resolved_node)) {
(void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(), (void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast<ValueNodePtr>(),
&resolved_node); &resolved_node);

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -426,16 +426,6 @@ bool AddCacheEmbeddingPass(const ResourcePtr &res) {
return true; return true;
} }
bool MergeDupGraphPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(res->manager());
if (res->manager()->func_graphs().size() <= 1) {
return true;
}
return MergeDuplicateGraphs(res->manager());
}
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) { bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) { if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Remove value node duplications error."; MS_LOG(EXCEPTION) << "Remove value node duplications error.";

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -73,107 +73,5 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
// Meet for the first time, append node to bucket. // Meet for the first time, append node to bucket.
bucket.emplace_back(node); bucket.emplace_back(node);
} }
size_t HashOfGraph(const FuncGraphPtr &fg) {
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
MS_LOG(DEBUG) << "TopSort for:" << fg->ToString();
std::unordered_map<AnfNodePtr, std::size_t> hashes;
auto &params = fg->parameters();
for (size_t i = 0; i < params.size(); i++) {
hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
}
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
if (hashes.find(node) != hashes.end()) {
continue;
}
std::size_t h = 0;
if (node->isa<ValueNode>()) {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (IsValueNode<FuncGraph>(value_node)) {
auto v_fg = value->cast<FuncGraphPtr>();
h = value->hash();
} else if (IsValueNode<tensor::Tensor>(value_node)) {
// the tensor has same value has been replaced in duplicate value pass,
// so we use the value pointer here as an identifier
h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
} else {
h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
}
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
size_t init = 0;
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
return hash_combine(hash, hashes[node_in]);
});
} else if (node->isa<Parameter>()) {
h = node->hash();
} else {
MS_LOG(ERROR) << "Unknow node type";
}
hashes[node] = h;
}
return hashes[fg->get_return()];
}
bool IsCNodeGraph(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
return IsValueNode<FuncGraph>(inp0);
}
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
std::unordered_map<FuncGraphPtr, size_t> graph_hash;
for (auto fg : manager->func_graphs()) {
size_t h = HashOfGraph(fg);
graph_hash[fg] = h;
if (hash_graphs.find(h) == hash_graphs.end()) {
hash_graphs[h] = {fg};
} else {
hash_graphs[h].push_back(fg);
}
}
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
for (auto &fg : manager->func_graphs()) {
MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
for (auto &item : fg->nodes()) {
if (!item->isa<CNode>()) {
continue;
}
auto &inputs = item->cast<CNodePtr>()->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
if (!inputs[i]->isa<ValueNode>()) {
continue;
}
auto value_ptr = GetValueNode(inputs[i]);
auto v_fg = value_ptr->cast<FuncGraphPtr>();
if (v_fg == nullptr) {
continue;
}
auto &fg_vec = hash_graphs[graph_hash[v_fg]];
if (fg_vec.size() > 1) {
if (v_fg != fg_vec[0]) {
bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
if (is_morphic) {
auto new_node = NewValueNode(fg_vec[0]);
MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
manager->Replace(inputs[i], new_node);
}
}
}
}
}
}
return true;
}
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -28,9 +28,6 @@ using HashCache = std::unordered_map<std::size_t, std::vector<AnfNodePtr>>;
using HashValue = std::unordered_map<AnfNodePtr, std::size_t>; using HashValue = std::unordered_map<AnfNodePtr, std::size_t>;
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
size_t HashOfGraph(const FuncGraphPtr &fg);
bool IsCNodeGraph(const AnfNodePtr &node);
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager);
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -846,7 +846,7 @@ class SideEffectFinder {
const SccPtr &GetScc(const FuncGraphPtr &func_graph) const { const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
auto found = scc_map_.find(func_graph); auto found = scc_map_.find(func_graph);
if (found == scc_map_.end()) { if (found == scc_map_.end()) {
MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString(); MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString() << "." << func_graph->debug_info()->get_id();
} }
return found->second; return found->second;
} }
@ -1014,7 +1014,6 @@ class AutoMonadConverter {
HandleCNodes(); HandleCNodes();
} }
// Clean up after conversion finished. // Clean up after conversion finished.
func_graph_->ClearIsolateNodes();
func_graph_->ClearOrderList(); func_graph_->ClearOrderList();
return has_effect_cnodes_; return has_effect_cnodes_;
} }
@ -1248,9 +1247,17 @@ class AutoMonadConverter {
} }
void InsertStateDepend(const AnfNodePtr &state) { void InsertStateDepend(const AnfNodePtr &state) {
auto output = GetGraphOutput();
// It's safe to handle isolated nodes here:
// Node: Depend(output, StopGrad)
if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
// Replace Depend(orig_output, StopGrad) node with orig_output.
// After that, nodes may be eliminated if have no side effects.
output = output->cast<CNodePtr>()->input(1);
}
// Insert Depend node and set it as output. // Insert Depend node and set it as output.
auto depend = NewValueNode(prim::kPrimDepend); auto depend = NewValueNode(prim::kPrimDepend);
auto output = GetGraphOutput();
auto depend_cnode = func_graph_->NewCNode({depend, output, state}); auto depend_cnode = func_graph_->NewCNode({depend, output, state});
depend_cnode->set_abstract(output->abstract()); depend_cnode->set_abstract(output->abstract());
func_graph_->set_output(depend_cnode); func_graph_->set_output(depend_cnode);
@ -1374,12 +1381,6 @@ bool AutoMonad(const FuncGraphPtr &func_graph) {
bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag); bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag);
has_effects = has_effects || fg_has_effects; has_effects = has_effects || fg_has_effects;
} }
// Clear isolate nodes after auto-monad finished.
auto manager = func_graph->manager();
if (manager) {
manager->ClearIsolateNodes();
}
return has_effects; return has_effects;
} }
@ -1406,7 +1407,6 @@ bool ReAutoMonad(const FuncGraphPtr &func_graph) {
for (auto &fg : func_graph->func_graphs_used_total()) { for (auto &fg : func_graph->func_graphs_used_total()) {
if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) { if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
fg->ClearOrderList(); fg->ClearOrderList();
fg->ClearIsolateNodes();
} }
} }
changed = AutoMonad(func_graph); changed = AutoMonad(func_graph);
@ -1416,13 +1416,9 @@ bool ReAutoMonad(const FuncGraphPtr &func_graph) {
// After auto monad, Order List and Isolate nodes in graph and manager will be cleared. // After auto monad, Order List and Isolate nodes in graph and manager will be cleared.
} else { } else {
func_graph->ClearOrderList(); func_graph->ClearOrderList();
func_graph->ClearIsolateNodes();
for (auto &fg : func_graph->func_graphs_used_total()) { for (auto &fg : func_graph->func_graphs_used_total()) {
fg->ClearOrderList(); fg->ClearOrderList();
fg->ClearIsolateNodes();
} }
MS_EXCEPTION_IF_NULL(func_graph->manager());
func_graph->manager()->ClearIsolateNodes();
} }
return changed; return changed;
} }

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -83,11 +83,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const auto &arg = args_spec_list[i]; const auto &arg = args_spec_list[i];
const auto &node = parameters[i]; const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr)); engine->analysis_cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
} }
const AnfNodePtr &func_node = fg->get_return(); const AnfNodePtr &func_node = fg->get_return();
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString() << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
<< ", current function call depth: " << engine->function_call_depth(); << ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr ret_base = nullptr; AbstractBasePtr ret_base = nullptr;
@ -97,37 +97,20 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; << ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
} }
// Analysis for isolate nodes first, as some validation check in FuncGraph is isolate nodes;
for (const auto &node : fg->GetIsolateNodesInOrder()) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis isolate_node begin, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
auto isolate_base = engine->GetEvaluatedValue(node_conf)->abstract();
MS_LOG(DEBUG) << "Analysis isolate_node end, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << isolate_base->ToString();
}
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType { const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
if (node->func_graph() != fg || node->isa<ValueNode>()) { if (node->func_graph() != fg || node->isa<ValueNode>()) {
return EXCLUDE; return EXCLUDE;
} }
return FOLLOW; return FOLLOW;
}); });
bool isolate_node_propagate_flag = false;
for (const auto &node : all_nodes) { for (const auto &node : all_nodes) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString(); << ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->GetEvaluatedValue(node_conf); auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
ret_base = node_eval_result->abstract(); ret_base = node_eval_result->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
if (node->isa<CNode>()) {
isolate_node_propagate_flag |= node_eval_result->HasIsolateNodesPropagateCNodeFlag();
MS_LOG(DEBUG) << "Check isolate_nodes flag for node: " << node->DebugString()
<< ", abstract: " << ret_base->ToString()
<< ", flag: " << node_eval_result->HasIsolateNodesPropagateCNodeFlag();
}
} }
engine->DecreaseFunctionCallDepth(); engine->DecreaseFunctionCallDepth();
@ -138,12 +121,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
if (fg->stub()) { if (fg->stub()) {
ret_base = std::make_shared<AbstractUndetermined>(); ret_base = std::make_shared<AbstractUndetermined>();
} }
auto eval_result = std::make_shared<EvalResult>(ret_base, std::make_shared<AttrValueMap>()); return std::make_shared<EvalResult>(ret_base, nullptr);
if (isolate_node_propagate_flag) {
eval_result->SetIsolateNodesPropagateCNodeFlag(true);
eval_result->SetIsolateNodesPropagateFuncGraphFlag(true);
}
return eval_result;
} }
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
@ -280,15 +258,15 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract(); return conf->ObtainEvalResult()->abstract();
}); });
args_spec_list = NormalizeArgs(args_spec_list); args_spec_list = NormalizeArgs(args_spec_list);
args_spec_list = BroadenUndeterminedArgs(args_spec_list); args_spec_list = BroadenUndeterminedArgs(args_spec_list);
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf); trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_EXCEPTION_IF_NULL(cache_); MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
auto iter = cache_->find(args_spec_list); auto iter = evaluator_cache_map_->find(args_spec_list);
if (iter == cache_->end()) { if (iter == evaluator_cache_map_->end()) {
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
EvalResultPtr ret = Eval(engine, args_spec_list); EvalResultPtr ret = Eval(engine, args_spec_list);
if (ret->abstract() == nullptr) { if (ret->abstract() == nullptr) {
@ -296,7 +274,7 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
} }
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
(*cache_)[args_spec_list] = ret; (*evaluator_cache_map_)[args_spec_list] = ret;
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return ret; return ret;
} else { } else {
@ -315,7 +293,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr { [is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
auto abstract = conf->GetEvaluatedValue()->abstract(); auto abstract = conf->ObtainEvalResult()->abstract();
// broaden the ref_key, while infer python prim for cache // broaden the ref_key, while infer python prim for cache
if (is_py_eval && abstract->isa<AbstractRef>()) { if (is_py_eval && abstract->isa<AbstractRef>()) {
auto abs_ref = abstract->cast<AbstractRefPtr>(); auto abs_ref = abstract->cast<AbstractRefPtr>();
@ -333,7 +311,7 @@ EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const Confi
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract(); return conf->ObtainEvalResult()->abstract();
}); });
if (args_conf_list.size() == 0) { if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Size should greater than 0"; MS_LOG(EXCEPTION) << "Size should greater than 0";
@ -354,12 +332,12 @@ EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrLis
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract(); return conf->ObtainEvalResult()->abstract();
}); });
EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
// Don't lookup from cache, as different out_conf with same node but different context // Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive. // may add different entry to anfnode_config_map_, like getattr primitive.
(*cache_)[args_spec_list] = ret; (*evaluator_cache_map_)[args_spec_list] = ret;
return ret; return ret;
} }
@ -369,11 +347,11 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract(); return conf->ObtainEvalResult()->abstract();
}); });
MS_EXCEPTION_IF_NULL(cache_); MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
auto iter = cache_->find(args_spec_list); auto iter = evaluator_cache_map_->find(args_spec_list);
if (iter != cache_->end()) { if (iter != evaluator_cache_map_->end()) {
return iter->second; return iter->second;
} }
@ -386,7 +364,7 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
(*cache_)[args_spec_list] = ret; (*evaluator_cache_map_)[args_spec_list] = ret;
return ret; return ret;
} }
@ -395,11 +373,11 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract(); return conf->ObtainEvalResult()->abstract();
}); });
MS_EXCEPTION_IF_NULL(cache_); MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
auto iter = cache_->find(args_spec_list); auto iter = evaluator_cache_map_->find(args_spec_list);
if (iter != cache_->end()) { if (iter != evaluator_cache_map_->end()) {
return iter->second; return iter->second;
} }
@ -427,7 +405,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
AbstractBasePtrList jargs = {result->abstract(), bprop}; AbstractBasePtrList jargs = {result->abstract(), bprop};
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs); AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>()); auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = infer_reuslt; (*evaluator_cache_map_)[args_spec_list] = infer_reuslt;
return infer_reuslt; return infer_reuslt;
} }

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -40,7 +40,7 @@ using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;
class Evaluator : public Base { class Evaluator : public Base {
public: public:
explicit Evaluator(const std::string &id) explicit Evaluator(const std::string &id)
: cache_(std::make_shared<EvaluatorCacheMap>()), : evaluator_cache_map_(std::make_shared<EvaluatorCacheMap>()),
attr_cache_(std::make_shared<EvaluatorAttrMap>()), attr_cache_(std::make_shared<EvaluatorAttrMap>()),
identifier_(id) {} identifier_(id) {}
~Evaluator() override = default; ~Evaluator() override = default;
@ -86,10 +86,10 @@ class Evaluator : public Base {
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
EvaluatorCacheMapPtr &cache() { return cache_; } EvaluatorCacheMapPtr &evaluator_cache_map() { return evaluator_cache_map_; }
EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }
EvaluatorCacheMapPtr cache_; EvaluatorCacheMapPtr evaluator_cache_map_;
EvaluatorAttrMapPtr attr_cache_; EvaluatorAttrMapPtr attr_cache_;
std::string identifier_; std::string identifier_;

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -53,7 +53,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
AnfNodeConfigPtr out_conf) { AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>(); auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
auto &func = do_signature->function(); auto &func = do_signature->function();
if (func->isa<Primitive>()) { if (func->isa<Primitive>()) {
@ -145,7 +145,7 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
// get the forward graph // get the forward graph
MS_EXCEPTION_IF_NULL(args_spec_list[0]); MS_EXCEPTION_IF_NULL(args_spec_list[0]);
auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>(); auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
@ -244,7 +244,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
<< ", inputs size " << out_node_inputs.size(); << ", inputs size " << out_node_inputs.size();
} }
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
ScopePtr scope = kDefaultScope; ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) { if (out_conf != nullptr) {
@ -600,8 +600,8 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
} }
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
const auto &iter = cache_->find(args); const auto &iter = evaluator_cache_map_->find(args);
if (iter != cache_->end()) { if (iter != evaluator_cache_map_->end()) {
return iter->second; return iter->second;
} }
auto py_args = PreparePyInputs(prim_py_, args); auto py_args = PreparePyInputs(prim_py_, args);
@ -614,7 +614,7 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs)); auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
(*cache_)[args] = infer_result; (*evaluator_cache_map_)[args] = infer_result;
return infer_result; return infer_result;
} }
@ -936,7 +936,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]); AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(node_conf); MS_EXCEPTION_IF_NULL(node_conf);
AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
x = SensitivityTransform(x); x = SensitivityTransform(x);
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x); SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>()); AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
@ -976,7 +976,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
return nullptr; return nullptr;
} }
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>(); AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) { if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
@ -1040,7 +1040,7 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
} }
// don't lookup from cache, as different out_conf with same node but different context // don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map, like getattr primitive; // may add different entry to anfnode_config_map, like getattr primitive;
(*cache_)[args_spec_list] = ret; (*evaluator_cache_map_)[args_spec_list] = ret;
return ret; return ret;
} }
}; };
@ -1126,7 +1126,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = infer_result; (*evaluator_cache_map_)[args_spec_list] = infer_result;
return infer_result; return infer_result;
} }
@ -1161,7 +1161,7 @@ class PartialEvaluator : public Evaluator {
MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node()); MS_EXCEPTION_IF_NULL(out_conf->node());
auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
AbstractBasePtrList args_spec_list{arg0_value}; AbstractBasePtrList args_spec_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if (arg0_value->isa<AbstractError>()) { if (arg0_value->isa<AbstractError>()) {
@ -1169,7 +1169,7 @@ class PartialEvaluator : public Evaluator {
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString(); << " as func is: " << arg0_value->ToString();
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = eval_result; (*evaluator_cache_map_)[args_spec_list] = eval_result;
return eval_result; return eval_result;
} }
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0); auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
@ -1182,11 +1182,9 @@ class PartialEvaluator : public Evaluator {
} }
} }
std::vector<EvalResultPtr> eval_result_list; (void)std::transform(
(void)std::transform(args_conf_list.cbegin() + 1, args_conf_list.cend(), std::back_inserter(eval_result_list), args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &config) -> EvalResultPtr { return config->GetEvaluatedValue(); }); [](const ConfigPtr &config) -> AbstractBasePtr { return config->ObtainEvalResult()->abstract(); });
(void)std::transform(eval_result_list.cbegin(), eval_result_list.cend(), std::back_inserter(args_spec_list),
[](const EvalResultPtr &eval_result) -> AbstractBasePtr { return eval_result->abstract(); });
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
auto cnode = out_conf->node()->cast<CNodePtr>(); auto cnode = out_conf->node()->cast<CNodePtr>();
@ -1195,25 +1193,16 @@ class PartialEvaluator : public Evaluator {
MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
<< ", args_conf_list: " << mindspore::ToString(args_conf_list); << ", args_conf_list: " << mindspore::ToString(args_conf_list);
} }
auto flag = std::any_of(eval_result_list.cbegin(), eval_result_list.cend(), [](const EvalResultPtr &eval_result) {
MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << eval_result->abstract()->ToString()
<< ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag();
return eval_result->HasIsolateNodesPropagateCNodeFlag();
});
AbstractFuncAtomPtrList partial_funcs_list; AbstractFuncAtomPtrList partial_funcs_list;
auto build_partial = [args, cnode, flag, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode); auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
partial_funcs_list.push_back(new_func); partial_funcs_list.push_back(new_func);
if (atom_func->HasIsolateNodesFlag() || flag) {
new_func->SetIsolateNodesFlag(true);
}
}; };
func->Visit(build_partial); func->Visit(build_partial);
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>()); auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = eval_result; (*evaluator_cache_map_)[args_spec_list] = eval_result;
return eval_result; return eval_result;
} }

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -98,8 +98,6 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
void ProcessNode(const AnfNodePtr &node); void ProcessNode(const AnfNodePtr &node);
void ProcessCNode(const CNodePtr &new_node); void ProcessCNode(const CNodePtr &new_node);
void ProcessIsolateNodes();
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
// Get node replicated by Cloner. // Get node replicated by Cloner.
@ -114,12 +112,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
// Build a value node if ival is constant and not any-value // Build a value node if ival is constant and not any-value
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs); const AttrValueMapPtr &attrs);
// Build a replaceable node for iconf->node; it may be a replicated forward CNode in static analysis or just a // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
// replicated node. First of returned pair is the origin node or the forward cnode, second is the replaced node. // replicated node.
std::pair<AnfNodePtr, AnfNodePtr> BuildReplacedNode(const AnfNodeConfigPtr &conf); AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
// Collect CNodes which have IsolateNodes that will be replaced by a ValuedNode.
AnfNodePtr CollectCNodeWithIsolateNodes(const CNodePtr &c_node, const EvalResultPtr &c_node_eval_result,
const FuncGraphPtr &new_fg);
// Build a specialized node from given argvals; // Build a specialized node from given argvals;
AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
const AbstractBasePtrList &argvals); const AbstractBasePtrList &argvals);

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -58,7 +58,7 @@ void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
<< ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
<< ", Pointer: " << result->abstract().get(); << ", Pointer: " << result->abstract().get();
cache_[conf] = result; analysis_cache_map_[conf] = result;
// Set intermediate abstract value. // Set intermediate abstract value.
if (IsIntermediateAbstract(result->abstract())) { if (IsIntermediateAbstract(result->abstract())) {
@ -77,8 +77,8 @@ void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr
} }
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
auto value = cache_.find(conf); auto value = analysis_cache_map_.find(conf);
if (value == cache_.end()) { if (value == analysis_cache_map_.end()) {
return nullptr; return nullptr;
} }
return value->second; return value->second;
@ -124,7 +124,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
AnalysisResult result; AnalysisResult result;
MS_EXCEPTION_IF_NULL(output_conf); MS_EXCEPTION_IF_NULL(output_conf);
result.inferred = output_conf->GetEvaluatedValue(); result.inferred = output_conf->ObtainEvalResult();
result.context = root_context; result.context = root_context;
return result; return result;
} }
@ -136,25 +136,24 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
return eval->graph_context(); return eval->graph_context();
} }
EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
auto value = cache_.GetValue(conf); EvalResultPtr result = analysis_cache_.GetValue(conf);
if (value != nullptr) { if (result != nullptr) {
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString()
<< ", " << value->abstract()->ToString() << ", flag: " << value->HasIsolateNodesPropagateCNodeFlag(); << ", Value: " << result->abstract().get() << ", " << result->abstract()->ToString();
return value; return result;
} }
MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString(); MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString();
value = Eval(conf); result = Eval(conf);
if (value == nullptr) { if (result == nullptr) {
MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr"; MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr";
} }
MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString() MS_LOG(DEBUG) << "Evaluate node on demond for NodeConfig: " << conf->ToString()
<< ", Value: " << value->abstract().get() << ", " << value->abstract()->ToString() << ", result: " << result->abstract().get() << ", " << result->abstract()->ToString();
<< ", flag: " << value->HasIsolateNodesPropagateCNodeFlag(); analysis_cache_.set_value(conf, result);
cache_.set_value(conf, value); return result;
return value;
} }
EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
@ -198,8 +197,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} }
#endif #endif
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString() MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
<< ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag();
return eval_result; return eval_result;
} }
@ -251,20 +249,6 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return out; return out;
} }
static bool CheckIsolateNodesPropagateFlag(const AbstractFunctionPtr &abs_func, const ConfigPtrList &conf_list) {
if (abs_func->HasIsolateNodesFlag()) {
MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << abs_func->ToString();
return true;
}
auto flag = std::any_of(conf_list.cbegin(), conf_list.cend(), [](const ConfigPtr &conf) {
auto eval_result = conf->GetEvaluatedValue();
MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << eval_result->abstract()->ToString()
<< ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag();
return eval_result->HasIsolateNodesPropagateCNodeFlag();
});
return flag;
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
@ -280,10 +264,10 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
MS_EXCEPTION_IF_NULL(func_conf); MS_EXCEPTION_IF_NULL(func_conf);
// Keep it in a local variable, otherwise smart pointer will free it. // Keep it in a local variable, otherwise smart pointer will free it.
auto maybe_func_eval_result = func_conf->GetEvaluatedValue(); auto maybe_func_eval_result = func_conf->ObtainEvalResult();
AbstractBasePtr maybe_func = maybe_func_eval_result->abstract(); AbstractBasePtr maybe_func = maybe_func_eval_result->abstract();
if (maybe_func == nullptr) { if (maybe_func == nullptr) {
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() MS_LOG(EXCEPTION) << "No abstract, func_conf: " << func_conf->ToString()
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
} }
if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
@ -292,8 +276,7 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
} }
AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func); AbstractFunctionPtr func = dyn_cast<AbstractFunction>(maybe_func);
if (func == nullptr) { if (func == nullptr) {
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() MS_LOG(EXCEPTION) << "Not AbstractFunction: " << maybe_func->ToString() << ", func_conf: " << func_conf->ToString()
<< ", func_conf: " << func_conf->ToString()
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
} }
@ -313,21 +296,6 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
func->Visit(build_evaluator); func->Visit(build_evaluator);
auto eval_result = ExecuteEvaluators(infs, conf, args_conf_list); auto eval_result = ExecuteEvaluators(infs, conf, args_conf_list);
auto flag = CheckIsolateNodesPropagateFlag(func, args_conf_list);
if (flag != eval_result->HasIsolateNodesPropagateCNodeFlag()) {
MS_LOG(DEBUG) << "Different propagate isolate nodes flag from: " << eval_result->abstract()->ToString()
<< ", cnode flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag()
<< ", funcgraph flag: " << eval_result->HasIsolateNodesPropagateFuncGraphFlag()
<< ", check flag:" << flag;
// This eval_result may be fetch from an Evaluator's cache based on args_spec_list equality.
// But args may be come from different CNode, so propagate flag is not same,
// a new copy of eval_result should be used.
auto new_eval_result = eval_result->Clone();
// FuncGraph flag should be used for HOF call or used FuncGraph propagate.
flag = flag | new_eval_result->HasIsolateNodesPropagateFuncGraphFlag();
new_eval_result->SetIsolateNodesPropagateCNodeFlag(flag);
eval_result = new_eval_result;
}
return eval_result; return eval_result;
} }
@ -349,25 +317,25 @@ void AnalysisEngine::ClearEvaluatorCache() {
for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) { for (std::pair<AbstractFunctionPtr, EvaluatorPtr> element : constructors_) {
EvaluatorPtr evaluator = element.second; EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->cache()); MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
evaluator->cache()->clear(); evaluator->evaluator_cache_map()->clear();
} }
for (auto &element : prim_constructors_) { for (auto &element : prim_constructors_) {
EvaluatorPtr evaluator = element.second; EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->cache()); MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
evaluator->cache()->clear(); evaluator->evaluator_cache_map()->clear();
} }
for (auto &element : prim_py_evaluators_) { for (auto &element : prim_py_evaluators_) {
EvaluatorPtr evaluator = element.second; EvaluatorPtr evaluator = element.second;
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
MS_EXCEPTION_IF_NULL(evaluator->cache()); MS_EXCEPTION_IF_NULL(evaluator->evaluator_cache_map());
evaluator->cache()->clear(); evaluator->evaluator_cache_map()->clear();
} }
} }
void AnalysisEngine::Clear() { void AnalysisEngine::Clear() {
cache_.Clear(); analysis_cache_.Clear();
anfnode_config_map_.clear(); anfnode_config_map_.clear();
eval_trace_.clear(); eval_trace_.clear();
constructors_.clear(); constructors_.clear();
@ -586,7 +554,7 @@ EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, c
} }
} }
forward_count_++; forward_count_++;
auto res = GetEvaluatedValue(new_conf); auto res = ObtainEvalResultWithCache(new_conf);
forward_count_--; forward_count_--;
return res; return res;
} }
@ -651,7 +619,7 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
for (auto u_eval : undetermined_evals) { for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "check undetermined."; MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << "check undetermined.";
auto &alternate_evaluator = multi_poss_[u_eval.evaluator_]; auto &alternate_evaluator = multi_poss_[u_eval.evaluator_];
auto &eval_cache = alternate_evaluator->cache(); auto &eval_cache = alternate_evaluator->evaluator_cache_map();
const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list); const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list);
if ((!undetermined_evals.count(alt_eval_args)) && if ((!undetermined_evals.count(alt_eval_args)) &&
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) || (((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
@ -698,7 +666,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract(); return conf->ObtainEvalResult()->abstract();
}); });
for (auto eval : evaluators) { for (auto eval : evaluators) {
SetUndeterminedFlag(eval); SetUndeterminedFlag(eval);
@ -741,9 +709,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
return ProcessEvalResults(out_specs); return ProcessEvalResults(out_specs);
} }
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { EvalResultPtr AnfNodeConfig::ObtainEvalResult() {
AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>(); AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
return engine_.lock()->GetEvaluatedValue(self); return engine_.lock()->ObtainEvalResultWithCache(self);
} }
abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph,

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -46,9 +46,6 @@ namespace abstract {
using AttrValueMap = std::unordered_map<std::string, ValuePtr>; using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>; using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
inline const int kIsolateNodesPropagateCNodeFlag = 1;
inline const int kIsolateNodesPropagateFuncGraphFlag = 2;
// the class to save evaluated result: abstract value and modified attribute // the class to save evaluated result: abstract value and modified attribute
class EvalResult : public Base { class EvalResult : public Base {
public: public:
@ -58,43 +55,10 @@ class EvalResult : public Base {
AbstractBasePtr abstract() { return abstract_; } AbstractBasePtr abstract() { return abstract_; }
AttrValueMapPtr attribute() { return attribute_; } AttrValueMapPtr attribute() { return attribute_; }
std::shared_ptr<EvalResult> Clone() const {
auto cloned = std::make_shared<EvalResult>(abstract_, attribute_);
cloned->SetIsolateNodesPropagateCNodeFlag(HasIsolateNodesPropagateCNodeFlag());
cloned->SetIsolateNodesPropagateFuncGraphFlag(HasIsolateNodesPropagateFuncGraphFlag());
return cloned;
}
// The related AbstractBase is evaluated from CNode which input has isolate nodes.
// This flag is propagated to all user node.
// When a node A can be specialized to a ValueNode, we should check if that node A has this flag,
// if it has, then the original FuncGraph call should be depended, so it's side effect will not
// be lost.
bool HasIsolateNodesPropagateCNodeFlag() const {
auto iter = eval_attr_.find(kIsolateNodesPropagateCNodeFlag);
if (iter != eval_attr_.end()) {
return GetValue<bool>(iter->second);
}
return false;
}
void SetIsolateNodesPropagateCNodeFlag(bool flag) { eval_attr_[kIsolateNodesPropagateCNodeFlag] = MakeValue(flag); }
// FuncGraph itself may not have IsoloateNodes, but the used FuncGraph or HOF call may have IsolateNodes;
bool HasIsolateNodesPropagateFuncGraphFlag() const {
auto iter = eval_attr_.find(kIsolateNodesPropagateFuncGraphFlag);
if (iter != eval_attr_.end()) {
return GetValue<bool>(iter->second);
}
return false;
}
void SetIsolateNodesPropagateFuncGraphFlag(bool flag) {
eval_attr_[kIsolateNodesPropagateFuncGraphFlag] = MakeValue(flag);
}
private: private:
AbstractBasePtr abstract_; AbstractBasePtr abstract_;
// Attribute related to PrimEvaluator; // Attribute related to PrimEvaluator;
AttrValueMapPtr attribute_; AttrValueMapPtr attribute_;
std::unordered_map<int, ValuePtr> eval_attr_;
}; };
using EvalResultPtr = std::shared_ptr<EvalResult>; using EvalResultPtr = std::shared_ptr<EvalResult>;
@ -104,7 +68,7 @@ class Config : public Base {
Config() = default; Config() = default;
~Config() override = default; ~Config() override = default;
MS_DECLARE_PARENT(Config, Base); MS_DECLARE_PARENT(Config, Base);
virtual EvalResultPtr GetEvaluatedValue() = 0; virtual EvalResultPtr ObtainEvalResult() = 0;
}; };
// Config will be stored in AnalysisCache // Config will be stored in AnalysisCache
@ -132,7 +96,7 @@ class AnfNodeConfig : public Config {
~AnfNodeConfig() override = default; ~AnfNodeConfig() override = default;
MS_DECLARE_PARENT(AnfNodeConfig, Config); MS_DECLARE_PARENT(AnfNodeConfig, Config);
EvalResultPtr GetEvaluatedValue() override; EvalResultPtr ObtainEvalResult() override;
AnalysisContextPtr context() const { return context_; } AnalysisContextPtr context() const { return context_; }
@ -182,7 +146,7 @@ class VirtualConfig : public Config {
~VirtualConfig() override = default; ~VirtualConfig() override = default;
MS_DECLARE_PARENT(VirtualConfig, Config); MS_DECLARE_PARENT(VirtualConfig, Config);
EvalResultPtr GetEvaluatedValue() override { EvalResultPtr ObtainEvalResult() override {
return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>()); return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
} }
@ -195,12 +159,12 @@ class AnalysisCache {
public: public:
AnalysisCache() = default; AnalysisCache() = default;
~AnalysisCache() = default; ~AnalysisCache() = default;
void Clear() { cache_.clear(); } void Clear() { analysis_cache_map_.clear(); }
void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
private: private:
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_; std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> analysis_cache_map_;
}; };
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
@ -222,7 +186,9 @@ struct PartialAppHasher {
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public: public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
: cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) { : analysis_cache_(AnalysisCache()),
prim_constructors_(prim_evaluator_map),
func_graph_manager_(func_graph_manager) {
function_call_depth_ = 0; function_call_depth_ = 0;
forward_count_ = 0; forward_count_ = 0;
} }
@ -231,7 +197,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// func_graph: The func_graph to analyze. // func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); EvalResultPtr ObtainEvalResultWithCache(const AnfNodeConfigPtr &conf);
// Return the Evaluator for the given function. // Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
@ -241,7 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear(); void Clear();
void ClearEvaluatorCache(); void ClearEvaluatorCache();
AnalysisCache &cache() { return cache_; } AnalysisCache &analysis_cache() { return analysis_cache_; }
AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) { AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) {
return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context); return std::make_shared<AnfNodeConfig>(shared_from_this(), node, context);
} }
@ -262,7 +228,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf); EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
AnalysisCache cache_; AnalysisCache analysis_cache_;
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void ResetFunctionCallDepth() { function_call_depth_ = 0; } void ResetFunctionCallDepth() { function_call_depth_ = 0; }

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -58,11 +58,6 @@ class AbstractFuncUnion : public AbstractFunction {
bool operator==(const AbstractFunction &other) const override; bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override; std::size_t hash() const override;
AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; }
bool HasIsolateNodesFlag() const override {
bool flag = std::any_of(func_list_.cbegin(), func_list_.cend(),
[](const AbstractFunctionPtr &func) { return func->HasIsolateNodesFlag(); });
return flag;
}
private: private:
AbstractFuncAtomPtrList func_list_; AbstractFuncAtomPtrList func_list_;
@ -131,8 +126,6 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
std::string ToString() const override; std::string ToString() const override;
bool HasIsolateNodesFlag() const override { return !func_graph_->isolate_nodes().empty(); }
private: private:
FuncGraphPtr func_graph_; FuncGraphPtr func_graph_;
AnalysisContextPtr context_; AnalysisContextPtr context_;
@ -202,16 +195,12 @@ class PartialAbstractClosure : public AbstractFuncAtom {
std::size_t hash() const override; std::size_t hash() const override;
std::string ToString() const override; std::string ToString() const override;
bool HasIsolateNodesFlag() const override { return isolate_nodes_flag_; }
void SetIsolateNodesFlag(bool flag) { isolate_nodes_flag_ = flag; }
private: private:
AbstractFuncAtomPtr fn_; AbstractFuncAtomPtr fn_;
AbstractBasePtrList args_spec_list_; AbstractBasePtrList args_spec_list_;
// The CNode which this PartialAbstractClosure evaluated from. // The CNode which this PartialAbstractClosure evaluated from.
AnfNodeWeakPtr node_; AnfNodeWeakPtr node_;
// If the bound fn_ has isolate ndoes or arguments evaluated from function has isolate nodes.
bool isolate_nodes_flag_{false};
}; };
using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>; using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>;

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -207,8 +207,6 @@ class AbstractFunction : public AbstractBase {
virtual AnfNodePtr tracking_id() const { return nullptr; } virtual AnfNodePtr tracking_id() const { return nullptr; }
virtual void set_tracking_id(AnfNodePtr) {} virtual void set_tracking_id(AnfNodePtr) {}
virtual AnalysisContextPtr context() const { return nullptr; } virtual AnalysisContextPtr context() const { return nullptr; }
// Function which itself has IsolateNodes, not include used function or HOF.
virtual bool HasIsolateNodesFlag() const { return false; }
}; };
using AbstractFunctionPtrList = std::vector<AbstractFunctionPtr>; using AbstractFunctionPtrList = std::vector<AbstractFunctionPtr>;

@ -157,8 +157,8 @@ void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) { void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) { if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got " MS_EXCEPTION(ValueError) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got "
<< shape[i]; << shape[i];
} }
} }
} }

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -65,7 +65,7 @@ using CNodePtrList = std::vector<CNodePtr>;
class FuncGraph; class FuncGraph;
using FuncGraphSet = OrderedSet<FuncGraphPtr>; using FuncGraphSet = OrderedSet<FuncGraphPtr>;
using FuncGraphPtrList = std::vector<FuncGraphPtr>; using FuncGraphVector = std::vector<FuncGraphPtr>;
class Primitive; class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>; using PrimitivePtr = std::shared_ptr<Primitive>;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save