|
|
@ -26,8 +26,8 @@
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace abstract {
|
|
|
|
namespace abstract {
|
|
|
|
namespace {
|
|
|
|
namespace {
|
|
|
|
void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
|
|
|
|
void EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
|
|
|
|
const AnfNodeConfigPtr &out_conf) {
|
|
|
|
const AnfNodeConfigPtr &out_conf) {
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator);
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator);
|
|
|
|
if (out_conf != nullptr) {
|
|
|
|
if (out_conf != nullptr) {
|
|
|
|
MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
|
|
|
|
MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
|
|
|
@ -37,7 +37,7 @@ void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
|
|
|
|
void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator);
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator);
|
|
|
|
if (out_conf != nullptr) {
|
|
|
|
if (out_conf != nullptr) {
|
|
|
|
auto node = out_conf->node();
|
|
|
|
auto node = out_conf->node();
|
|
|
@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
|
|
|
|
return sorted_nodes;
|
|
|
|
return sorted_nodes;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
|
|
|
AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
|
|
|
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
|
|
|
|
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
|
|
|
|
MS_EXCEPTION_IF_NULL(fg);
|
|
|
|
MS_EXCEPTION_IF_NULL(fg);
|
|
|
|
std::size_t nargs = fg->parameters().size();
|
|
|
|
std::size_t nargs = fg->parameters().size();
|
|
|
@ -124,7 +124,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ret_base);
|
|
|
|
MS_EXCEPTION_IF_NULL(ret_base);
|
|
|
|
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString();
|
|
|
|
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString();
|
|
|
|
return ret_base;
|
|
|
|
return ret_base;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -155,7 +155,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|
|
|
<< ", context: " << parent_context_->ToString();
|
|
|
|
<< ", context: " << parent_context_->ToString();
|
|
|
|
auto last_context = parent_context_->Filter(func_graph_);
|
|
|
|
auto last_context = parent_context_->Filter(func_graph_);
|
|
|
|
if (last_context && last_context->func_graph() == func_graph_) {
|
|
|
|
if (last_context && last_context->func_graph() == func_graph_) {
|
|
|
|
MS_LOG(DEBUG) << "Find last infer context: " << last_context->ToString();
|
|
|
|
MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString();
|
|
|
|
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
|
|
|
|
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
|
|
|
|
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
|
|
|
|
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
|
|
|
|
// Join the last eval arguments and current arguments to check if there are loop variant.
|
|
|
|
// Join the last eval arguments and current arguments to check if there are loop variant.
|
|
|
@ -248,26 +248,26 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
|
|
|
|
});
|
|
|
|
});
|
|
|
|
args_spec_list = NormalizeArgs(args_spec_list);
|
|
|
|
args_spec_list = NormalizeArgs(args_spec_list);
|
|
|
|
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
|
|
|
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
|
|
|
trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf);
|
|
|
|
trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
|
|
|
|
InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
|
|
|
EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
|
|
|
MS_EXCEPTION_IF_NULL(cache_);
|
|
|
|
MS_EXCEPTION_IF_NULL(cache_);
|
|
|
|
auto iter = cache_->find(args_spec_list);
|
|
|
|
auto iter = cache_->find(args_spec_list);
|
|
|
|
if (iter == cache_->end()) {
|
|
|
|
if (iter == cache_->end()) {
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Infer().";
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
|
|
|
|
AbstractBasePtr ret = Infer(engine, args_spec_list);
|
|
|
|
AbstractBasePtr ret = Eval(engine, args_spec_list);
|
|
|
|
if (ret == nullptr) {
|
|
|
|
if (ret == nullptr) {
|
|
|
|
InferFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
|
|
|
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
|
|
|
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
|
|
|
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_EXCEPTION_IF_NULL(ret);
|
|
|
|
MS_EXCEPTION_IF_NULL(ret);
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
|
|
|
|
(*cache_)[args_spec_list] = ret;
|
|
|
|
(*cache_)[args_spec_list] = ret;
|
|
|
|
trace::TraceGraphInferLeave(shared_from_base<Evaluator>());
|
|
|
|
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_EXCEPTION_IF_NULL(iter->second);
|
|
|
|
MS_EXCEPTION_IF_NULL(iter->second);
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
|
|
|
|
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
|
|
|
|
trace::TraceGraphInferLeave(shared_from_base<Evaluator>());
|
|
|
|
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
|
|
|
return iter->second;
|
|
|
|
return iter->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -378,7 +378,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
|
|
|
|
return jtuple;
|
|
|
|
return jtuple;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr VirtualEvaluator::Infer(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
|
|
|
|
AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
|
|
|
|
if (args_spec_list.size() != args_spec_list_.size()) {
|
|
|
|
if (args_spec_list.size() != args_spec_list_.size()) {
|
|
|
|
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
|
|
|
|
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
|
|
|
|
<< ", arguments no: " << args_spec_list.size();
|
|
|
|
<< ", arguments no: " << args_spec_list.size();
|
|
|
|