|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
/**
|
|
|
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
|
|
|
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
|
|
|
|
*
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
@ -83,11 +83,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|
|
|
|
const auto &arg = args_spec_list[i];
|
|
|
|
|
const auto &node = parameters[i];
|
|
|
|
|
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
|
|
|
|
|
engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
|
|
|
|
|
engine->analysis_cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
|
|
|
|
|
}
|
|
|
|
|
const AnfNodePtr &func_node = fg->get_return();
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
|
|
|
|
|
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString()
|
|
|
|
|
<< ", current function call depth: " << engine->function_call_depth();
|
|
|
|
|
AbstractBasePtr ret_base = nullptr;
|
|
|
|
@ -97,37 +97,20 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|
|
|
|
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
|
|
|
|
|
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
|
|
|
|
|
}
|
|
|
|
|
// Analysis for isolate nodes first, as some validation check in FuncGraph is isolate nodes;
|
|
|
|
|
for (const auto &node : fg->GetIsolateNodesInOrder()) {
|
|
|
|
|
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis isolate_node begin, func graph: " << fg.get() << fg->ToString()
|
|
|
|
|
<< ", node_conf: " << node_conf->ToString();
|
|
|
|
|
auto isolate_base = engine->GetEvaluatedValue(node_conf)->abstract();
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis isolate_node end, func graph: " << fg.get() << fg->ToString()
|
|
|
|
|
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << isolate_base->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
|
|
|
|
|
if (node->func_graph() != fg || node->isa<ValueNode>()) {
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
return FOLLOW;
|
|
|
|
|
});
|
|
|
|
|
bool isolate_node_propagate_flag = false;
|
|
|
|
|
for (const auto &node : all_nodes) {
|
|
|
|
|
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
|
|
|
|
|
<< ", node_conf: " << node_conf->ToString();
|
|
|
|
|
auto node_eval_result = engine->GetEvaluatedValue(node_conf);
|
|
|
|
|
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
|
|
|
|
|
ret_base = node_eval_result->abstract();
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
|
|
|
|
|
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg << "/" << fg->ToString()
|
|
|
|
|
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
isolate_node_propagate_flag |= node_eval_result->HasIsolateNodesPropagateCNodeFlag();
|
|
|
|
|
MS_LOG(DEBUG) << "Check isolate_nodes flag for node: " << node->DebugString()
|
|
|
|
|
<< ", abstract: " << ret_base->ToString()
|
|
|
|
|
<< ", flag: " << node_eval_result->HasIsolateNodesPropagateCNodeFlag();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
engine->DecreaseFunctionCallDepth();
|
|
|
|
|
|
|
|
|
@ -138,12 +121,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|
|
|
|
if (fg->stub()) {
|
|
|
|
|
ret_base = std::make_shared<AbstractUndetermined>();
|
|
|
|
|
}
|
|
|
|
|
auto eval_result = std::make_shared<EvalResult>(ret_base, std::make_shared<AttrValueMap>());
|
|
|
|
|
if (isolate_node_propagate_flag) {
|
|
|
|
|
eval_result->SetIsolateNodesPropagateCNodeFlag(true);
|
|
|
|
|
eval_result->SetIsolateNodesPropagateFuncGraphFlag(true);
|
|
|
|
|
}
|
|
|
|
|
return eval_result;
|
|
|
|
|
return std::make_shared<EvalResult>(ret_base, nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
|
|
|
@ -280,15 +258,15 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
return conf->ObtainEvalResult()->abstract();
|
|
|
|
|
});
|
|
|
|
|
args_spec_list = NormalizeArgs(args_spec_list);
|
|
|
|
|
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
|
|
|
|
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
|
|
|
|
|
MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cache_);
|
|
|
|
|
auto iter = cache_->find(args_spec_list);
|
|
|
|
|
if (iter == cache_->end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
|
|
|
|
|
auto iter = evaluator_cache_map_->find(args_spec_list);
|
|
|
|
|
if (iter == evaluator_cache_map_->end()) {
|
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
|
|
|
|
|
EvalResultPtr ret = Eval(engine, args_spec_list);
|
|
|
|
|
if (ret->abstract() == nullptr) {
|
|
|
|
@ -296,7 +274,7 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
|
|
|
|
|
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
|
|
|
|
|
(*cache_)[args_spec_list] = ret;
|
|
|
|
|
(*evaluator_cache_map_)[args_spec_list] = ret;
|
|
|
|
|
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
|
|
|
|
return ret;
|
|
|
|
|
} else {
|
|
|
|
@ -315,7 +293,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
auto abstract = conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
auto abstract = conf->ObtainEvalResult()->abstract();
|
|
|
|
|
// broaden the ref_key, while infer python prim for cache
|
|
|
|
|
if (is_py_eval && abstract->isa<AbstractRef>()) {
|
|
|
|
|
auto abs_ref = abstract->cast<AbstractRefPtr>();
|
|
|
|
@ -333,7 +311,7 @@ EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const Confi
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
return conf->ObtainEvalResult()->abstract();
|
|
|
|
|
});
|
|
|
|
|
if (args_conf_list.size() == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Size should greater than 0";
|
|
|
|
@ -354,12 +332,12 @@ EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrLis
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
return conf->ObtainEvalResult()->abstract();
|
|
|
|
|
});
|
|
|
|
|
EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
|
|
|
|
|
// Don't lookup from cache, as different out_conf with same node but different context
|
|
|
|
|
// may add different entry to anfnode_config_map_, like getattr primitive.
|
|
|
|
|
(*cache_)[args_spec_list] = ret;
|
|
|
|
|
(*evaluator_cache_map_)[args_spec_list] = ret;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -369,11 +347,11 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
return conf->ObtainEvalResult()->abstract();
|
|
|
|
|
});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cache_);
|
|
|
|
|
auto iter = cache_->find(args_spec_list);
|
|
|
|
|
if (iter != cache_->end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
|
|
|
|
|
auto iter = evaluator_cache_map_->find(args_spec_list);
|
|
|
|
|
if (iter != evaluator_cache_map_->end()) {
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -386,7 +364,7 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
|
|
|
|
|
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
|
|
|
|
EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
|
|
|
|
|
|
|
|
|
|
(*cache_)[args_spec_list] = ret;
|
|
|
|
|
(*evaluator_cache_map_)[args_spec_list] = ret;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -395,11 +373,11 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
return conf->ObtainEvalResult()->abstract();
|
|
|
|
|
});
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cache_);
|
|
|
|
|
auto iter = cache_->find(args_spec_list);
|
|
|
|
|
if (iter != cache_->end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator_cache_map_);
|
|
|
|
|
auto iter = evaluator_cache_map_->find(args_spec_list);
|
|
|
|
|
if (iter != evaluator_cache_map_->end()) {
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -427,7 +405,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
|
|
|
|
AbstractBasePtrList jargs = {result->abstract(), bprop};
|
|
|
|
|
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
|
|
|
|
|
auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
|
|
|
|
|
(*cache_)[args_spec_list] = infer_reuslt;
|
|
|
|
|
(*evaluator_cache_map_)[args_spec_list] = infer_reuslt;
|
|
|
|
|
return infer_reuslt;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|