@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
return nullptr ;
}
void AnalysisCache : : set_value ( const AnfNodeConfigPtr & conf , const AbstractBasePtr & arg ) {
void AnalysisCache : : set_value ( const AnfNodeConfigPtr & conf , const EvalResultPtr & result ) {
MS_LOG ( DEBUG ) < < " AnalysisCache set for NodeConfig: " < < conf - > node ( ) - > DebugString ( )
< < " , Context: " < < conf - > context ( ) - > ToString ( ) < < " , Value: " < < arg - > ToString ( )
< < " , Pointer: " < < arg . get ( ) ;
cache_ [ conf ] = arg ;
< < " , Context: " < < conf - > context ( ) - > ToString ( ) < < " , Value: " < < result- > abstract ( ) - > ToString ( )
< < " , Pointer: " < < result- > abstract ( ) . get ( ) ;
cache_ [ conf ] = result ;
// Set intermediate abstract value.
if ( IsIntermediateAbstract ( arg ) ) {
if ( IsIntermediateAbstract ( result- > abstract ( ) ) ) {
if ( conf - > node ( ) - > intermediate_abstract ( ) = = nullptr ) {
conf - > node ( ) - > set_intermediate_abstract ( arg ) ;
MS_LOG ( DEBUG ) < < " Set intermediate abstract: " < < arg - > ToString ( ) ;
conf - > node ( ) - > set_intermediate_abstract ( result- > abstract ( ) ) ;
MS_LOG ( DEBUG ) < < " Set intermediate abstract: " < < result- > abstract ( ) - > ToString ( ) ;
} else {
auto old_spec = conf - > node ( ) - > intermediate_abstract ( ) ;
auto joined_spec = IntermediateJoin ( arg , old_spec ) ;
auto joined_spec = IntermediateJoin ( result- > abstract ( ) , old_spec ) ;
conf - > node ( ) - > set_intermediate_abstract ( joined_spec ) ;
MS_LOG ( DEBUG ) < < " Set joined intermediate abstract: \n old_spec: \t \t " < < old_spec - > ToString ( ) < < " \n new_spec: \t \t "
< < arg - > ToString ( ) < < " \n joined_spec: \t "
< < result- > abstract ( ) - > ToString ( ) < < " \n joined_spec: \t "
< < ( joined_spec ! = nullptr ? joined_spec - > ToString ( ) : " nullptr " ) ;
}
}
}
AbstractBase Ptr AnalysisCache : : GetValue ( const AnfNodeConfigPtr & conf ) {
EvalResult Ptr AnalysisCache : : GetValue ( const AnfNodeConfigPtr & conf ) {
auto value = cache_ . find ( conf ) ;
if ( value = = cache_ . end ( ) ) {
return nullptr ;
@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
return eval - > graph_context ( ) ;
}
AbstractBase Ptr AnalysisEngine : : GetEvaluatedValue ( const AnfNodeConfigPtr & conf ) {
EvalResult Ptr AnalysisEngine : : GetEvaluatedValue ( const AnfNodeConfigPtr & conf ) {
MS_EXCEPTION_IF_NULL ( conf ) ;
auto value = cache_ . GetValue ( conf ) ;
if ( value ! = nullptr ) {
MS_LOG ( DEBUG ) < < " Evaluate cache hit for NodeConfig: " < < conf - > ToString ( ) < < " , Value: " < < value . get ( ) < < " , "
< < value - > ToString ( ) ;
MS_LOG ( DEBUG ) < < " Evaluate cache hit for NodeConfig: " < < conf - > ToString ( ) < < " , Value: " < < value - > abstract ( ) . get ( )
< < " , " < < value - > abstract ( ) - > ToString ( ) ;
return value ;
}
@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
return value ;
}
AbstractBase Ptr AnalysisEngine : : Eval ( const AnfNodeConfigPtr & conf ) {
EvalResult Ptr AnalysisEngine : : Eval ( const AnfNodeConfigPtr & conf ) {
MS_EXCEPTION_IF_NULL ( conf ) ;
AnfNodePtr node = conf - > node ( ) ;
AbstractBasePtr ret_abstrac t = nullptr ;
EvalResultPtr eval_resul t = nullptr ;
# ifdef DEBUG
compute_conf_stack_ . push_back ( node ) ;
std : : ostringstream buffer ;
@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL ( node ) ;
if ( node - > abstract ( ) ! = nullptr ) {
MS_LOG ( DEBUG ) < < " Return old abstract: " < < node - > DebugString ( ) ;
ret_abstract = node - > abstract ( ) ;
eval_result = std : : make_shared < EvalResult > ( node - > abstract ( ) , std : : make_shared < AttrValueMap > ( ) ) ;
} else if ( node - > isa < ValueNode > ( ) ) {
auto value_node = node - > cast < ValueNodePtr > ( ) ;
ret_abstract = EvalValueNode ( value_node , conf ) ;
eval_result = std : : make_shared < EvalResult > ( EvalValueNode ( value_node , conf ) , nullptr ) ;
} else if ( node - > isa < CNode > ( ) ) {
auto cnode = node - > cast < CNodePtr > ( ) ;
trace : : TraceEvalCNodeEnter ( conf ) ;
ret_abstrac t = EvalCNode ( cnode , conf ) ;
eval_resul t = EvalCNode ( cnode , conf ) ;
trace : : TraceEvalCNodeLeave ( ) ;
} else {
MS_LOG ( EXCEPTION ) < < " Illegal AnfNode for evaluating, " < < node - > DebugString ( )
@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
# ifdef DEBUG
compute_conf_stack_ . pop_back ( ) ;
if ( ret_abstrac t = = nullptr ) {
if ( eval_resul t = = nullptr ) {
MS_LOG ( EXCEPTION ) < < " Compute Config failed, node: " < < node - > DebugString ( )
< < " NodeInfo: " < < trace : : GetDebugInfo ( node - > debug_info ( ) ) ;
}
# endif
MS_LOG ( DEBUG ) < < " End Eval NodeConfig " < < conf - > ToString ( ) < < " , res: " < < ret_abstract - > ToString ( ) ;
return ret_abstrac t;
MS_LOG ( DEBUG ) < < " End Eval NodeConfig " < < conf - > ToString ( ) < < " , res: " < < eval_result- > abstract ( ) - > ToString ( ) ;
return eval_resul t;
}
AbstractBasePtr AnalysisEngine : : EvalValueNode ( const ValueNodePtr & value_node , const AnfNodeConfigPtr & conf ) {
@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return ToAbstract ( value_node - > value ( ) , conf - > context ( ) , conf ) ;
}
AbstractBase Ptr AnalysisEngine : : EvalCNode ( const CNodePtr & cnode , const AnfNodeConfigPtr & conf ) {
EvalResult Ptr AnalysisEngine : : EvalCNode ( const CNodePtr & cnode , const AnfNodeConfigPtr & conf ) {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( cnode ) ;
auto & inputs = cnode - > inputs ( ) ;
@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
AnfNodeConfigPtr func_conf = MakeConfig ( func_node , context ) ;
MS_EXCEPTION_IF_NULL ( func_conf ) ;
// Keep it in a local variable, otherwise smart pointer will free it.
AbstractBasePtr maybe_func = func_conf - > GetEvaluatedValue ( ) ;
AbstractBasePtr maybe_func = func_conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
if ( maybe_func = = nullptr ) {
MS_LOG ( EXCEPTION ) < < " func_conf.GetEvaluatedValue() return null, func_conf: " < < func_conf - > ToString ( )
< < " NodeInfo: " < < trace : : GetDebugInfo ( cnode - > debug_info ( ) ) ;
@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
return ExecuteEvaluators ( infs , conf , args_conf_list ) ;
}
AbstractBase Ptr AnalysisEngine : : Execute ( const AbstractFunctionPtr & func , const AbstractBasePtrList & args_spec_list ) {
EvalResult Ptr AnalysisEngine : : Execute ( const AbstractFunctionPtr & func , const AbstractBasePtrList & args_spec_list ) {
ConfigPtrList args_conf_list ;
( void ) std : : transform ( args_spec_list . begin ( ) , args_spec_list . end ( ) , std : : back_inserter ( args_conf_list ) ,
[ ] ( const AbstractBasePtr & arg ) - > ConfigPtr { return std : : make_shared < VirtualConfig > ( arg ) ; } ) ;
@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
return tracked_eval ;
}
AbstractBasePtr AnalysisEngine : : ExecuteEvaluators ( const std : : vector < EvaluatorPtr > & evaluators ,
const AnfNodeConfigPtr & out_conf ,
const ConfigPtrList & args_conf_list ) {
EvalResultPtr AnalysisEngine : : ExecuteEvaluators ( const std : : vector < EvaluatorPtr > & evaluators ,
const AnfNodeConfigPtr & out_conf , const ConfigPtrList & args_conf_list ) {
if ( evaluators . size ( ) = = 1 ) {
EvaluatorPtr eval = evaluators [ 0 ] ;
MS_EXCEPTION_IF_NULL ( eval ) ;
@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
return ExecuteMultipleEvaluators ( evaluators , out_conf , args_conf_list ) ;
}
AbstractBase Ptr AnalysisEngine : : ExecuteMultipleEvaluators ( const std : : vector < EvaluatorPtr > & evaluators ,
const AnfNodeConfigPtr & out_conf ,
const ConfigPtrList & args_conf_list ) {
EvalResult Ptr AnalysisEngine : : ExecuteMultipleEvaluators ( const std : : vector < EvaluatorPtr > & evaluators ,
const AnfNodeConfigPtr & out_conf ,
const ConfigPtrList & args_conf_list ) {
AbstractBasePtrList out_specs ;
if ( ! multi_poss_ . count ( evaluators [ 0 ] ) ) {
multi_poss_ [ evaluators [ 0 ] ] = evaluators [ 1 ] ;
@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
for ( auto eval : evaluators ) {
auto fg_eval = eval - > cast < FuncGraphEvaluatorPtr > ( ) ;
@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
eval_trace_ . push_back ( current_inf ) ;
MS_LOG ( DEBUG ) < < " Trace Evaluator " < < eval - > ToString ( ) < < " ptr: " < < eval . get ( ) ;
MS_EXCEPTION_IF_NULL ( eval ) ;
auto out_spec = eval - > Run ( shared_from_this ( ) , args_conf_list , out_conf ) ;
MS_EXCEPTION_IF_NULL ( out_spec ) ;
MS_LOG ( DEBUG ) < < " Evaluator " < < eval - > ToString ( ) < < " return out_spec: " < < out_spec - > ToString ( ) ;
out_specs . push_back ( out_spec ) ;
MS_LOG ( DEBUG ) < < " Pop Evaluator " < < eval - > ToString ( ) ;
auto eval_result = eval - > Run ( shared_from_this ( ) , args_conf_list , out_conf ) ;
MS_EXCEPTION_IF_NULL ( eval_result - > abstract ( ) ) ;
MS_LOG ( DEBUG ) < < " Evaluator " < < eval - > ToString ( ) < < " return out_spec: " < < eval_result - > abstract ( ) - > ToString ( ) ;
out_specs . push_back ( eval_result - > abstract ( ) ) ;
eval_trace_ . pop_back ( ) ;
if ( eval_trace_ . empty ( ) ) {
multi_poss_ . clear ( ) ;
@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
// Try to travel the latest undetermined.
if ( latest_entry ! = eval_trace_ . rbegin ( ) - > first ) {
MS_LOG ( DEBUG ) < < " Direct Run Evaluator " < < eval - > ToString ( ) ;
auto out_spec = latest_entry - > Run ( shared_from_this ( ) , args_conf_list , out_conf ) ;
MS_EXCEPTION_IF_NULL ( out_spec ) ;
MS_LOG ( DEBUG ) < < " Evaluator " < < latest_entry - > ToString ( ) < < " return out_spec: " < < out_spec - > ToString ( ) ;
return out_spec ;
auto eval_result = latest_entry - > Run ( shared_from_this ( ) , args_conf_list , out_conf ) ;
MS_EXCEPTION_IF_NULL ( eval_result - > abstract ( ) ) ;
MS_LOG ( DEBUG ) < < " Evaluator " < < latest_entry - > ToString ( )
< < " return out_spec: " < < eval_result - > abstract ( ) - > ToString ( ) ;
return eval_result ;
}
}
}
@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
if ( out_specs . size ( ) = = 1 ) {
MS_EXCEPTION_IF_NULL ( out_specs [ 0 ] ) ;
// If only one result derived, then broaden it to avoid wrong constant propagation.
return out_specs[ 0 ] - > Broaden ( ) ;
return std: : make_shared < EvalResult > ( out_specs[ 0 ] - > Broaden ( ) , std : : make_shared < AttrValueMap > ( ) ) ;
}
auto joined_spec = AbstractJoin ( out_specs ) ;
MS_EXCEPTION_IF_NULL ( joined_spec ) ;
MS_LOG ( DEBUG ) < < " Multiple evaluators joined: " < < joined_spec - > ToString ( ) ;
return joined_spec;
return std: : make_shared < EvalResult > ( joined_spec, std : : make_shared < AttrValueMap > ( ) ) ;
}
AbstractBase Ptr AnfNodeConfig : : GetEvaluatedValue ( ) {
EvalResult Ptr AnfNodeConfig : : GetEvaluatedValue ( ) {
AnfNodeConfigPtr self = shared_from_base < AnfNodeConfig > ( ) ;
return engine_ . lock ( ) - > GetEvaluatedValue ( self ) ;
}
@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
return a ;
}
AbstractBase Ptr EvalOnePrim ( const PrimitivePtr & primitive , const AbstractBasePtrList & arg_specs ) {
EvalResult Ptr EvalOnePrim ( const PrimitivePtr & primitive , const AbstractBasePtrList & arg_specs ) {
auto evaluator = GetPrimEvaluator ( primitive , nullptr ) ;
MS_EXCEPTION_IF_NULL ( evaluator ) ;
if ( ! evaluator - > isa < TrivialPrimEvaluator > ( ) ) {
@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
< < evaluator - > ToString ( ) ;
}
auto trivial_evaluator = dyn_cast < TrivialPrimEvaluator > ( evaluator ) ;
auto res_spec = trivial_evaluator - > EvalPrim ( nullptr , arg_specs ) ;
return res_spec ;
auto eval_result = trivial_evaluator - > EvalPrim ( nullptr , arg_specs ) ;
return eval_result ;
}
} // namespace abstract
} // namespace mindspore