|
|
|
@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
|
|
|
|
|
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
|
|
|
|
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
|
|
|
|
|
<< ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString()
|
|
|
|
|
<< ", Pointer: " << arg.get();
|
|
|
|
|
cache_[conf] = arg;
|
|
|
|
|
<< ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
|
|
|
|
|
<< ", Pointer: " << result->abstract().get();
|
|
|
|
|
cache_[conf] = result;
|
|
|
|
|
|
|
|
|
|
// Set intermediate abstract value.
|
|
|
|
|
if (IsIntermediateAbstract(arg)) {
|
|
|
|
|
if (IsIntermediateAbstract(result->abstract())) {
|
|
|
|
|
if (conf->node()->intermediate_abstract() == nullptr) {
|
|
|
|
|
conf->node()->set_intermediate_abstract(arg);
|
|
|
|
|
MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString();
|
|
|
|
|
conf->node()->set_intermediate_abstract(result->abstract());
|
|
|
|
|
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
|
|
|
|
|
} else {
|
|
|
|
|
auto old_spec = conf->node()->intermediate_abstract();
|
|
|
|
|
auto joined_spec = IntermediateJoin(arg, old_spec);
|
|
|
|
|
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
|
|
|
|
|
conf->node()->set_intermediate_abstract(joined_spec);
|
|
|
|
|
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
|
|
|
|
|
<< arg->ToString() << "\njoined_spec:\t"
|
|
|
|
|
<< result->abstract()->ToString() << "\njoined_spec:\t"
|
|
|
|
|
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
auto value = cache_.find(conf);
|
|
|
|
|
if (value == cache_.end()) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
|
|
|
|
|
return eval->graph_context();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
auto value = cache_.GetValue(conf);
|
|
|
|
|
if (value != nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", "
|
|
|
|
|
<< value->ToString();
|
|
|
|
|
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get()
|
|
|
|
|
<< ", " << value->abstract()->ToString();
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
AnfNodePtr node = conf->node();
|
|
|
|
|
AbstractBasePtr ret_abstract = nullptr;
|
|
|
|
|
EvalResultPtr eval_result = nullptr;
|
|
|
|
|
#ifdef DEBUG
|
|
|
|
|
compute_conf_stack_.push_back(node);
|
|
|
|
|
std::ostringstream buffer;
|
|
|
|
@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node->abstract() != nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
|
|
|
|
|
ret_abstract = node->abstract();
|
|
|
|
|
eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
|
|
|
|
|
} else if (node->isa<ValueNode>()) {
|
|
|
|
|
auto value_node = node->cast<ValueNodePtr>();
|
|
|
|
|
ret_abstract = EvalValueNode(value_node, conf);
|
|
|
|
|
eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr);
|
|
|
|
|
} else if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
trace::TraceEvalCNodeEnter(conf);
|
|
|
|
|
ret_abstract = EvalCNode(cnode, conf);
|
|
|
|
|
eval_result = EvalCNode(cnode, conf);
|
|
|
|
|
trace::TraceEvalCNodeLeave();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
|
|
|
|
@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|
|
|
|
|
|
|
|
|
#ifdef DEBUG
|
|
|
|
|
compute_conf_stack_.pop_back();
|
|
|
|
|
if (ret_abstract == nullptr) {
|
|
|
|
|
if (eval_result == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
|
|
|
|
|
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString();
|
|
|
|
|
return ret_abstract;
|
|
|
|
|
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
|
|
|
|
|
return eval_result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
|
|
|
|
@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
|
|
|
|
|
return ToAbstract(value_node->value(), conf->context(), conf);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
|
|
|
|
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
|
|
|
|
|
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_conf);
|
|
|
|
|
// Keep it in a local variable, otherwise smart pointer will free it.
|
|
|
|
|
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue();
|
|
|
|
|
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
if (maybe_func == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
|
|
|
|
|
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
|
|
|
|
@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
|
|
|
|
|
return ExecuteEvaluators(infs, conf, args_conf_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
ConfigPtrList args_conf_list;
|
|
|
|
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
|
|
|
|
|
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
|
|
|
@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
|
|
|
|
return tracked_eval;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
|
|
|
|
const AnfNodeConfigPtr &out_conf,
|
|
|
|
|
const ConfigPtrList &args_conf_list) {
|
|
|
|
|
EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
|
|
|
|
const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
|
|
|
|
|
if (evaluators.size() == 1) {
|
|
|
|
|
EvaluatorPtr eval = evaluators[0];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(eval);
|
|
|
|
@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
|
|
|
|
|
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
|
|
|
|
const AnfNodeConfigPtr &out_conf,
|
|
|
|
|
const ConfigPtrList &args_conf_list) {
|
|
|
|
|
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
|
|
|
|
const AnfNodeConfigPtr &out_conf,
|
|
|
|
|
const ConfigPtrList &args_conf_list) {
|
|
|
|
|
AbstractBasePtrList out_specs;
|
|
|
|
|
if (!multi_poss_.count(evaluators[0])) {
|
|
|
|
|
multi_poss_[evaluators[0]] = evaluators[1];
|
|
|
|
@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue();
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
});
|
|
|
|
|
for (auto eval : evaluators) {
|
|
|
|
|
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
|
|
|
|
@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
|
|
|
|
eval_trace_.push_back(current_inf);
|
|
|
|
|
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(eval);
|
|
|
|
|
auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out_spec);
|
|
|
|
|
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
|
|
|
|
|
out_specs.push_back(out_spec);
|
|
|
|
|
MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
|
|
|
|
|
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
|
|
|
|
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
|
|
|
|
|
out_specs.push_back(eval_result->abstract());
|
|
|
|
|
eval_trace_.pop_back();
|
|
|
|
|
if (eval_trace_.empty()) {
|
|
|
|
|
multi_poss_.clear();
|
|
|
|
@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
|
|
|
|
// Try to travel the latest undetermined.
|
|
|
|
|
if (latest_entry != eval_trace_.rbegin()->first) {
|
|
|
|
|
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
|
|
|
|
|
auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out_spec);
|
|
|
|
|
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString();
|
|
|
|
|
return out_spec;
|
|
|
|
|
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
|
|
|
|
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
|
|
|
|
|
<< " return out_spec: " << eval_result->abstract()->ToString();
|
|
|
|
|
return eval_result;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
|
|
|
|
if (out_specs.size() == 1) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out_specs[0]);
|
|
|
|
|
// If only one result derived, then broaden it to avoid wrong constant propagation.
|
|
|
|
|
return out_specs[0]->Broaden();
|
|
|
|
|
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
|
|
|
|
|
}
|
|
|
|
|
auto joined_spec = AbstractJoin(out_specs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(joined_spec);
|
|
|
|
|
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
|
|
|
|
return joined_spec;
|
|
|
|
|
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() {
|
|
|
|
|
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
|
|
|
|
|
AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
|
|
|
|
|
return engine_.lock()->GetEvaluatedValue(self);
|
|
|
|
|
}
|
|
|
|
@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
|
|
|
|
|
return a;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
|
|
|
|
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
|
|
|
|
auto evaluator = GetPrimEvaluator(primitive, nullptr);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(evaluator);
|
|
|
|
|
if (!evaluator->isa<TrivialPrimEvaluator>()) {
|
|
|
|
@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
|
|
|
|
|
<< evaluator->ToString();
|
|
|
|
|
}
|
|
|
|
|
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
|
|
|
|
|
auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs);
|
|
|
|
|
return res_spec;
|
|
|
|
|
auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
|
|
|
|
|
return eval_result;
|
|
|
|
|
}
|
|
|
|
|
} // namespace abstract
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|