|
|
|
@ -189,12 +189,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|
|
|
|
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
|
|
|
|
func_graph_->joined_shapes_.clear();
|
|
|
|
|
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
|
|
|
|
|
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
|
|
|
|
|
if (arg_spec->isa<AbstractRef>()) {
|
|
|
|
|
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
|
|
|
|
|
}
|
|
|
|
|
return arg_spec->GetShapeTrack();
|
|
|
|
|
});
|
|
|
|
|
std::back_inserter(func_graph_->joined_shapes_),
|
|
|
|
|
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
|
|
|
|
|
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
|
|
|
|
|
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
|
|
|
|
}
|
|
|
|
@ -212,12 +208,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|
|
|
|
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
|
|
|
|
func_graph_->joined_shapes_.clear();
|
|
|
|
|
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
|
|
|
|
|
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
|
|
|
|
|
if (arg_spec->isa<AbstractRef>()) {
|
|
|
|
|
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
|
|
|
|
|
}
|
|
|
|
|
return arg_spec->GetShapeTrack();
|
|
|
|
|
});
|
|
|
|
|
std::back_inserter(func_graph_->joined_shapes_),
|
|
|
|
|
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
|
|
|
|
|
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
|
|
|
|
|
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
|
|
|
|
}
|
|
|
|
@ -317,10 +309,17 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
|
|
|
|
|
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
|
|
|
|
AnfNodeConfigPtr) {
|
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
|
auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
|
|
|
|
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
|
|
|
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conf);
|
|
|
|
|
return conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
auto abstract = conf->GetEvaluatedValue()->abstract();
|
|
|
|
|
// broaden the ref_key, while infer python prim for cache
|
|
|
|
|
if (is_py_eval && abstract->isa<AbstractRef>()) {
|
|
|
|
|
auto abs_ref = abstract->cast<AbstractRefPtr>();
|
|
|
|
|
abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref);
|
|
|
|
|
}
|
|
|
|
|
return abstract;
|
|
|
|
|
});
|
|
|
|
|
EvalResultPtr ret = EvalPrim(engine, args_spec_list);
|
|
|
|
|
return ret;
|
|
|
|
|