You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
608 lines
24 KiB
608 lines
24 KiB
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "pipeline/static_analysis/program_specialize.h"
|
|
|
|
#include <algorithm>
|
|
#include <exception>
|
|
#include "./common.h"
|
|
#include "operator/ops.h"
|
|
#include "operator/composite/do_signature.h"
|
|
#include "utils/graph_utils.h"
|
|
#include "utils/profile.h"
|
|
#include "debug/trace.h"
|
|
|
|
namespace mindspore {
|
|
namespace abstract {
|
|
namespace {
|
|
// Returns the abstract to use for the node behind `conf`: the node's intermediate abstract when
// one is set, otherwise whatever conf->GetEvaluatedValue() yields.
// `conf` and its node are dereferenced without a null check.
inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
  const AbstractBasePtr intermediate = conf->node()->intermediate_abstract();
  if (intermediate != nullptr) {
    return intermediate;
  }
  return conf->GetEvaluatedValue();
}
|
|
|
|
// Wraps value `v` in a new ValueNode and attaches `abs_base` as its abstract.
// `abs_base` must be non-null: it is dereferenced for the debug log below.
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
  AnfNodePtr wrapped = NewValueNode(v);
  wrapped->set_abstract(abs_base);
  MS_LOG(DEBUG) << "Create ValueNode: " << wrapped->ToString() << ", with abstract: " << abs_base->ToString();
  return wrapped;
}
|
|
|
|
// Reports whether `parent` is `fg` itself or is reached by following fg->parent() links.
// When the chain runs out (null), the answer is true only if `parent` is null as well.
bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
  for (; fg != nullptr; fg = fg->parent()) {
    if (fg == parent) {
      return true;
    }
  }
  // The chain is exhausted, so only a null `parent` still matches.
  return parent == nullptr;
}
|
|
} // namespace
|
|
|
|
// Entry point of program specialization: specializes the topmost graph `fg` under `context`
// and returns the specialized graph. Both arguments must be non-null.
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  const auto top_graph = context->func_graph();
  MS_LOG(DEBUG) << "Specialize topmost function graph: " << top_graph->ToString();
  return SpecializeFuncGraph(fg, context);
}
|
|
|
|
// Returns the specialized graph for (`fg`, `context`), creating and running a FuncGraphSpecializer
// the first time a given context->SpecializeKey() is requested and reusing it afterwards.
FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);

  const auto cached = specializations_.find(context->SpecializeKey());
  if (cached != specializations_.end()) {
    return cached->second->specialized_func_graph();
  }

  auto graph_specializer = std::make_shared<FuncGraphSpecializer>(this, fg, context);
  FuncGraphPtr specialized = graph_specializer->specialized_func_graph();
  // Register before running: Run() can reach SpecializeFuncGraph() again (through
  // BuildSpecializedNodeInner), and such nested requests for this key must find the entry.
  specializations_[context->SpecializeKey()] = graph_specializer;
  graph_specializer->Run();
  return specialized;
}
|
|
|
|
std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
|
|
MS_EXCEPTION_IF_NULL(context);
|
|
auto iter = specializations_.find(context->SpecializeKey());
|
|
if (iter != specializations_.end()) {
|
|
return iter->second;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
// Returns the decimal text of a counter that starts at 1 and grows by one per call.
// The counter is a plain function-local static with no synchronization.
std::string GetNextCounter() {
  static int next_clone_id = 1;
  return std::to_string(next_clone_id++);
}
|
|
|
|
// Prepares the specialization of `fg` under `context`:
//   - takes the parent specializer from the program specializer `s`, keyed by context->parent();
//   - clones `fg` with SpecializerClone and keeps the cloner's node map and the cloned graph;
//   - seeds the work list with the return node of `fg` followed by its parameters.
// `s`, `fg` and `context` are dereferenced without null checks.
FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
                                           const AnalysisContextPtr &context)
    : specializer_(s), func_graph_(fg), context_(context) {
  parent_ = s->GetFuncGraphSpecializer(context->parent());
  engine_ = s->engine();

  cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
  repl_node_ = cloner_->cloned_node();
  specialized_func_graph_ = cloner_->cloned_func_graph()[fg];

  todo_.push_back(fg->get_return());
  for (const auto &param : fg->parameters()) {
    todo_.push_back(param);
  }
}
|
|
|
|
// Clones `node` through the cloner of the specializer that handles the node's func graph and
// returns the clone. For a CNode, every non-ValueNode input is replicated recursively and the
// clone's inputs are replaced by those replicas.
//   - A ValueNode is returned unchanged.
//   - A node that already has an entry in the owning specializer's replication map returns it.
// Raises via MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL when no specializer in the parent chain
// handles the node's graph, when the clone of a CNode is not a CNode, or when cloning did not
// register a replica distinct from the original.
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return node;
  }

  // Walk up the parent chain to the specializer whose graph owns `node`.
  FuncGraphPtr fg = node->func_graph();
  std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  while (fg != nullptr && fg != specializer->func_graph_) {
    specializer = specializer->parent_;
    if (specializer == nullptr) {
      // The chain ended without a match: fail loudly instead of dereferencing a null parent.
      MS_LOG(EXCEPTION) << "No specializer found for the func graph of node: " << node->DebugString();
    }
  }
  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  MS_EXCEPTION_IF_NULL(specializer->cloner_);

  // If had replicated, just return that.
  auto iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    return iter->second;
  }

  auto new_node = specializer->cloner_->CloneDisconnected(node);
  MS_EXCEPTION_IF_NULL(new_node);
  if (node->isa<CNode>()) {
    if (!new_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
    }
    auto c_node = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(c_node);
    auto inputs = c_node->inputs();
    std::vector<AnfNodePtr> new_inputs;
    new_inputs.reserve(inputs.size());
    // ValueNode inputs are shared as-is; everything else is replicated (recursively).
    (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
                         [this](const AnfNodePtr &inp) -> AnfNodePtr {
                           if (inp->isa<ValueNode>()) {
                             return inp;
                           }
                           return ReplicateDisconnectedNode(inp);
                         });
    auto c_new_node = new_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(c_new_node);
    c_new_node->set_inputs(new_inputs);
  }

  // Cloning is expected to have registered the replica in the replication map.
  iter = specializer->repl_node_->find(node);
  if (iter == specializer->repl_node_->end()) {
    MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  }
  if (iter->second == node) {
    MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  }
  return new_node;
}
|
|
|
|
// Returns the replica recorded for `node` in the replication map of the specializer that handles
// the node's func graph, or `node` itself when no replica is recorded.
// Raises via MS_LOG(EXCEPTION) when no specializer in the parent chain handles the node's graph.
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  FuncGraphPtr fg = node->func_graph();

  // Walk up the parent chain to the specializer whose graph owns `node`.
  std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
  while (fg != nullptr && fg != specializer->func_graph_) {
    specializer = specializer->parent_;
    if (specializer == nullptr) {
      // The chain ended without a match: fail loudly instead of dereferencing a null parent.
      MS_LOG(EXCEPTION) << "No specializer found for the func graph of node: " << node->DebugString();
    }
  }

  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  auto iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    return iter->second;
  }
  return node;
}
|
|
|
|
// Specializes this func graph: runs FirstPass() and then SecondPass(), logging the original and
// the cloned graph before and after the two passes.
void FuncGraphSpecializer::Run() {
  MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
                << ", cloned func graph name: " << specialized_func_graph_->ToString()
                << ", func graph: " << func_graph_->get_return()->DebugString();

  FirstPass();
  SecondPass();

  MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
                << ", cloned func graph name: " << specialized_func_graph_->ToString()
                << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
}
|
|
|
|
void FuncGraphSpecializer::FirstPass() {
|
|
while (todo_.size()) {
|
|
AnfNodePtr node = todo_.back();
|
|
todo_.pop_back();
|
|
if (node->func_graph() == nullptr) {
|
|
// do nothing for ValueNode
|
|
continue;
|
|
}
|
|
if (node->func_graph() != func_graph_) {
|
|
if (parent_ == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
|
}
|
|
parent_->AddTodoItem(node);
|
|
parent_->FirstPass();
|
|
AnfNodePtr new_node = parent_->GetReplicatedNode(node);
|
|
if (node->isa<CNode>()) {
|
|
parent_->ProcessCNode(new_node->cast<CNodePtr>());
|
|
}
|
|
continue;
|
|
}
|
|
if (marked_.count(node) > 0) {
|
|
continue;
|
|
}
|
|
(void)marked_.insert(node);
|
|
ProcessNode(node);
|
|
}
|
|
}
|
|
|
|
// Specialize CNode in func graphs
|
|
void FuncGraphSpecializer::SecondPass() {
|
|
for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
|
|
if (node->isa<CNode>()) {
|
|
ProcessCNode(node->cast<CNodePtr>());
|
|
}
|
|
}
|
|
}
|
|
|
|
void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
ScopeGuard scope_guard(node->scope());
|
|
AnfNodeConfigPtr conf = MakeConfig(node);
|
|
AnfNodePtr new_node = GetReplicatedNode(node);
|
|
MS_EXCEPTION_IF_NULL(new_node);
|
|
|
|
if (new_node->func_graph() != specialized_func_graph_) {
|
|
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
|
|
<< ", new_node: " << new_node->DebugString()
|
|
<< ", new_node->func_graph(): " << new_node->func_graph()->ToString()
|
|
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();
|
|
return;
|
|
}
|
|
new_node->set_abstract(GetEvaluatedValueWrap(conf));
|
|
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
|
|
|
|
if (node->isa<CNode>()) {
|
|
auto c_old = node->cast<CNodePtr>();
|
|
auto c_new = new_node->cast<CNodePtr>();
|
|
auto new_inputs = c_new->inputs();
|
|
auto old_inputs = c_old->inputs();
|
|
for (size_t i = 0; i < old_inputs.size(); ++i) {
|
|
auto node_input = old_inputs[i];
|
|
AnfNodeConfigPtr iconf = MakeConfig(node_input);
|
|
AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
|
|
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
|
|
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
|
|
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
|
|
if (replace_node == nullptr) {
|
|
replace_node = BuildReplacedNode(iconf);
|
|
MS_EXCEPTION_IF_NULL(replace_node);
|
|
replace_node->set_abstract(ival);
|
|
MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
|
|
} else {
|
|
MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
|
|
<< ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
|
|
}
|
|
if (new_inputs[i] != replace_node) {
|
|
new_inputs[i] = replace_node;
|
|
MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
|
|
}
|
|
}
|
|
c_new->set_inputs(new_inputs);
|
|
}
|
|
}
|
|
|
|
// Follows the engine's anfnode_config_map from `conf` until a config with no further mapping is
// reached, replicating the node of every config visited on the way. The node of the final config
// is queued on the work list, and its replica (or the node itself when no replica is recorded)
// is returned.
AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);

  // Graph name for log messages; a node may have no func graph (see the null check on
  // repl->func_graph() below), so the name is built without dereferencing null.
  auto graph_name = [](const AnfNodePtr &n) -> std::string {
    const FuncGraphPtr owner = n->func_graph();
    if (owner == nullptr) {
      return "nullptr";
    }
    return owner->ToString();
  };

  auto conf_iter = engine_->anfnode_config_map().find(conf);
  AnfNodeConfigPtr new_conf = conf;
  while (conf_iter != engine_->anfnode_config_map().end()) {
    MS_LOG(DEBUG) << "Origin conf: graph(" << graph_name(new_conf->node()) << ", node("
                  << new_conf->node()->DebugString() << ")";
    new_conf = conf_iter->second;
    MS_EXCEPTION_IF_NULL(new_conf);
    // Log the config just switched to (the previous code logged the unchanged `conf` here).
    MS_LOG(DEBUG) << "Replaced conf: graph(" << graph_name(new_conf->node()) << ", node("
                  << new_conf->node()->DebugString() << ")";
    (void)ReplicateDisconnectedNode(new_conf->node());
    conf_iter = engine_->anfnode_config_map().find(new_conf);
  }

  todo_.push_back(new_conf->node());
  auto repl = GetReplicatedNode(new_conf->node());
  if (repl->func_graph()) {
    MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
                  << ") to replace origin:" << new_conf->node()->DebugString();
  } else {
    MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
                  << ") to replace origin: " << new_conf->node()->DebugString();
  }
  return repl;
}
|
|
|
|
namespace {
|
|
const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
|
|
const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
|
|
|
|
inline bool CanSpecializeNode(const AnfNodePtr &node) {
|
|
if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
} // namespace
|
|
|
|
// Builds the node that replaces callee `node`, whose abstract `abs` must be an AbstractFunction,
// for a call with argument abstracts `argvals`.
// When BuildSpecializedNodeInner() yields no node, the reported status decides the outcome:
//   - dead  -> a ValueNode of kDeadNode carrying an AbstractError;
//   - poly  -> a ValueNode of kPolyNode carrying an AbstractError;
//   - other -> MS_LOG(EXCEPTION).
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                                      const AbstractBasePtrList &argvals) {
  MS_EXCEPTION_IF_NULL(abs);
  AbstractFunctionPtr abs_func = dyn_cast<AbstractFunction>(abs);
  MS_EXCEPTION_IF_NULL(abs_func);

  AbstractFunctionPtr unique_func = abs_func->GetUnique();
  SpecializeStatusCode status;
  ScopeGuard scope_guard(node->scope());
  AnfNodePtr specialized = BuildSpecializedNodeInner(abs, unique_func, argvals, &status);
  if (specialized != nullptr) {
    return specialized;
  }

  if (status == kSpecializeFindUniqueArgvalDead) {
    specialized = BuildValueNode(kDeadNode, std::make_shared<AbstractError>(kDeadNode, node));
    MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  } else if (status == kSpecializeFindUniqueArgvalPoly) {
    specialized = BuildValueNode(kPolyNode, std::make_shared<AbstractError>(kPolyNode, node));
    MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
  } else {
    MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
                      << ", abstract: " << abs->ToString();
  }
  return specialized;
}
|
|
|
|
// Resolves abstract function `func` (whose abstract is `abs`) for arguments `args` into a ValueNode:
//   - TypedPrimitiveAbstractClosure: a ValueNode of its primitive with abstract `abs`;
//   - PrimitiveAbstractClosure: a ValueNode of the primitive whose abstract is a
//     TypedPrimitiveAbstractClosure built from the unique argument/output abstracts;
//   - otherwise the evaluator must be a BaseFuncGraphEvaluator: the graph of the context made for
//     the unique arguments is specialized and wrapped in a ValueNode with abstract `abs`.
// `*errcode` is set to kSpecializeSuccess up front. When FindUniqueArgvals() fails, its status is
// stored in `*errcode` and nullptr is returned.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
                                                           const AbstractBasePtrList &args,
                                                           SpecializeStatusCode *errcode) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(func);
  MS_EXCEPTION_IF_NULL(errcode);
  *errcode = kSpecializeSuccess;

  // A typed primitive needs no evaluator lookup.
  auto typed_prim = dyn_cast<TypedPrimitiveAbstractClosure>(func);
  if (typed_prim != nullptr) {
    return BuildValueNode(typed_prim->prim(), abs);
  }

  EvaluatorPtr evaluator = engine_->GetEvaluatorFor(func);
  MS_EXCEPTION_IF_NULL(evaluator);
  AbstractBasePtrList argvals = evaluator->NormalizeArgs(args);

  std::pair<AbstractBasePtrList, AbstractBasePtr> unique;
  const SpecializeStatusCode status = FindUniqueArgvals(func, evaluator, argvals, &unique);
  if (status != kSpecializeSuccess) {
    *errcode = status;
    return nullptr;
  }
  argvals = unique.first;
  AbstractBasePtr unique_output = unique.second;

  auto plain_prim = dyn_cast<PrimitiveAbstractClosure>(func);
  if (plain_prim != nullptr) {
    auto typed_closure = std::make_shared<TypedPrimitiveAbstractClosure>(plain_prim->prim(), argvals, unique_output);
    return BuildValueNode(plain_prim->prim(), typed_closure);
  }

  if (!evaluator->isa<BaseFuncGraphEvaluator>()) {
    MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << evaluator->ToString();
  }
  auto graph_evaluator = dyn_cast<BaseFuncGraphEvaluator>(evaluator);

  // The function must carry a context whose graph is visible from the graph being specialized.
  if (func->context() == nullptr) {
    MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  } else if (!IsVisible(func_graph_, func->context()->func_graph())) {
    MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
  }

  AnalysisContextPtr callee_context = graph_evaluator->MakeContext(engine_, argvals);
  MS_LOG(DEBUG) << "Specialize function graph: " << callee_context->func_graph()->ToString()
                << ", args: " << argvals.size()
                << ", graph: " << callee_context->func_graph()->get_return()->DebugString();
  FuncGraphPtr specialized_graph = specializer_->SpecializeFuncGraph(callee_context->func_graph(), callee_context);
  return BuildValueNode(specialized_graph, abs);
}
|
|
|
|
// Returns the cache map recorded for `eval` in evalcaches_. On the first request for an evaluator,
// eval->cache() is recorded there and returned.
const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
  auto recorded = evalcaches_.find(eval);
  if (recorded != evalcaches_.end()) {
    return recorded->second;
  }
  evalcaches_[eval] = eval->cache();
  return eval->cache();
}
|
|
|
|
std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
|
|
const EvaluatorPtr &eval) {
|
|
MS_EXCEPTION_IF_NULL(eval);
|
|
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
|
|
AbstractBasePtr ret = nullptr;
|
|
AbstractBasePtrList broaded_argvals;
|
|
for (auto &argvals_map : *evalcaches_[eval]) {
|
|
auto argvals = argvals_map.first;
|
|
broaded_argvals.clear();
|
|
|
|
(void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
|
|
[](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
|
|
(void)choices.insert(broaded_argvals);
|
|
MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
|
|
}
|
|
|
|
if (1 == choices.size()) {
|
|
ConfigPtrList args_conf_list;
|
|
(void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
|
|
[](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
|
|
|
|
// if broaden return null
|
|
ret = eval->Run(engine_, args_conf_list, nullptr);
|
|
EvaluatorCacheMapPtr real = std::make_shared<EvaluatorCacheMap>();
|
|
|
|
(*real)[broaded_argvals] = ret;
|
|
evalcaches_[eval] = real;
|
|
return std::make_pair(broaded_argvals, ret);
|
|
} else {
|
|
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
|
|
return std::make_pair(AbstractBasePtrList(), nullptr);
|
|
}
|
|
}
|
|
|
|
void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|
MS_EXCEPTION_IF_NULL(new_node);
|
|
if (specializer_->seen().count(new_node) > 0) {
|
|
return;
|
|
}
|
|
specializer_->AddSeen(new_node);
|
|
|
|
auto new_inputs = new_node->inputs();
|
|
if (new_inputs.empty()) {
|
|
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
|
|
}
|
|
AnfNodePtr func = new_inputs[0];
|
|
MS_EXCEPTION_IF_NULL(func);
|
|
|
|
// First element is func so arg start from 1
|
|
std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
|
|
// CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
|
|
while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
|
|
std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
|
|
// First element is partial, second is func so arg is start from 2
|
|
(void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
|
|
func = inputs[1];
|
|
new_inputs = args;
|
|
(void)new_inputs.insert(new_inputs.begin(), func);
|
|
}
|
|
|
|
AbstractBasePtrList argvals;
|
|
MS_EXCEPTION_IF_NULL(new_inputs[0]);
|
|
AbstractBasePtr fnval = new_inputs[0]->abstract();
|
|
MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
|
|
<< new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
|
|
|
|
// First element is func so function arguments start from 1
|
|
for (size_t i = 1; i < new_inputs.size(); ++i) {
|
|
argvals.push_back(new_inputs[i]->abstract());
|
|
MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
|
|
<< new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
|
|
}
|
|
|
|
if (CanSpecializeNode(func)) {
|
|
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
|
|
}
|
|
|
|
for (size_t i = 0; i < argvals.size();) {
|
|
size_t next = i + 1;
|
|
if (CanSpecializeNode(args[i])) {
|
|
new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
|
|
}
|
|
// support for partial(Multitype) which Multitype should not be inferred to POLY.
|
|
// after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only,
|
|
// so even with partial parameter, it will specialize to that graph.
|
|
// Maybe a better idea should inline graph with partial node first, then it will have full
|
|
// parameter list to infer and specialize.
|
|
MS_EXCEPTION_IF_NULL(new_inputs[next]);
|
|
if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) &&
|
|
IsPrimitive(func, prim::kPrimPartial)) {
|
|
new_inputs[next] = args[i];
|
|
}
|
|
i = next;
|
|
}
|
|
|
|
new_node->set_inputs(new_inputs);
|
|
}
|
|
|
|
namespace {
|
|
void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
|
|
MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
|
|
int i = 0;
|
|
for (const auto &item : evaluator_cache_map) {
|
|
MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
|
|
}
|
|
}
|
|
|
|
bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
|
|
if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
|
|
MS_LOG(DEBUG) << "High order primitive return POLY.";
|
|
return true;
|
|
}
|
|
if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
|
|
auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
|
|
auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
|
|
if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
|
|
auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
|
|
if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
|
|
MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
} // end anonymous namespace
|
|
|
|
// Picks the single (arguments, output) cache entry of `eval` to specialize `func` with and stores
// it in `*result`. Resolution order:
//   1. an exact match for `argvals` in the evaluator's own cache;
//   2. an exact match in the cache recorded by GetEvalCache();
//   3. the only entry of that cache, if it holds exactly one;
//   4. an empty cache                      -> kSpecializeFindUniqueArgvalDead;
//   5. IsPolyFunc(func, argvals)           -> kSpecializeFindUniqueArgvalPoly;
//   6. a single broadened argument list from BuildFromBroadedArgsVal();
//   7. otherwise                           -> kSpecializeFindUniqueArgvalPoly.
// Returns kSpecializeSuccess whenever `*result` was filled with a usable entry.
SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
                                                             const AbstractBasePtrList &argvals,
                                                             std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
  MS_EXCEPTION_IF_NULL(func);
  MS_EXCEPTION_IF_NULL(eval);
  MS_EXCEPTION_IF_NULL(result);

  // Share the evaluator's cache through its pointer instead of copying the whole map per call.
  const EvaluatorCacheMapPtr evaluator_cache = eval->cache();
  MS_EXCEPTION_IF_NULL(evaluator_cache);
  const auto exact = evaluator_cache->find(argvals);
  if (exact != evaluator_cache->end()) {
    *result = std::make_pair(argvals, exact->second);
    return kSpecializeSuccess;
  }
  DumpEvaluatorCache(*evaluator_cache, argvals);

  const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
  MS_EXCEPTION_IF_NULL(choices);

  const auto chosen = choices->find(argvals);
  if (chosen != choices->end()) {
    *result = std::make_pair(argvals, chosen->second);
    return kSpecializeSuccess;
  }
  if (choices->size() == 1) {
    MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
    *result = std::make_pair(choices->begin()->first, choices->begin()->second);
    return kSpecializeSuccess;
  }
  if (choices->empty()) {
    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
    return kSpecializeFindUniqueArgvalDead;
  }
  if (IsPolyFunc(func, argvals)) {
    return kSpecializeFindUniqueArgvalPoly;
  }

  MS_LOG(DEBUG) << "Try to find generalized argvals.";
  *result = BuildFromBroadedArgsVal(eval);
  if (!result->first.empty()) {
    return kSpecializeSuccess;
  }
  MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
  return kSpecializeFindUniqueArgvalPoly;
}
|
|
|
|
// Tries to turn the abstract `ival` of `origin_node` into a ValueNode; returns nullptr when that
// is not possible.
//   - Non-function abstract: the value from ival->BuildValue(), unless it is an AnyValue.
//   - Function abstract: the primitive, meta func graph or func graph behind a
//     PrimitiveAbstractClosure, MetaFuncGraphAbstractClosure or FuncGraphAbstractClosure.
//     An AbstractFuncUnion or any other kind of function abstract yields nullptr. A func graph
//     that has a parent is only used when `origin_node` is itself a FuncGraph value node and that
//     parent is visible from the graph being specialized.
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

  AbstractFunctionPtr abs_func = dyn_cast<AbstractFunction>(ival);
  if (abs_func == nullptr) {
    ValuePtr built = ival->BuildValue();
    if (built->isa<AnyValue>()) {
      return nullptr;
    }
    return BuildValueNode(built, ival);
  }

  // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
  if (abs_func->isa<AbstractFuncUnion>()) {
    return nullptr;
  }

  ValuePtr value = nullptr;
  if (abs_func->isa<PrimitiveAbstractClosure>()) {
    value = dyn_cast<PrimitiveAbstractClosure>(abs_func)->prim();
  } else if (abs_func->isa<MetaFuncGraphAbstractClosure>()) {
    value = dyn_cast<MetaFuncGraphAbstractClosure>(abs_func)->meta_func_graph();
  } else if (abs_func->isa<FuncGraphAbstractClosure>()) {
    value = dyn_cast<FuncGraphAbstractClosure>(abs_func)->func_graph();
  } else {
    return nullptr;
  }

  if (value->isa<FuncGraph>()) {
    const FuncGraphPtr graph_parent = value->cast<FuncGraphPtr>()->parent();
    const bool usable =
      graph_parent == nullptr || (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, graph_parent));
    if (!usable) {
      return nullptr;
    }
  }
  return BuildValueNode(value, ival);
}
|
|
|
|
// Asks the engine for the config of `node` in this specializer's analysis context.
AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
  AnfNodeConfigPtr node_conf = engine_->MakeConfig(node, context_);
  return node_conf;
}
|
|
} // namespace abstract
|
|
} // namespace mindspore
|