@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co . , Ltd
* Copyright 2019 - 2021 Huawei Technologies Co . , Ltd
*
* Licensed under the Apache License , Version 2.0 ( the " License " ) ;
* you may not use this file except in compliance with the License .
@ -83,11 +83,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
const auto & arg = args_spec_list [ i ] ;
const auto & node = parameters [ i ] ;
AnfNodeConfigPtr conf = engine - > MakeConfig ( node , graph_context_ ) ;
engine - > cache( ) . set_value ( conf , std : : make_shared < EvalResult > ( arg , nullptr ) ) ;
engine - > analysis_ cache( ) . set_value ( conf , std : : make_shared < EvalResult > ( arg , nullptr ) ) ;
}
const AnfNodePtr & func_node = fg - > get_return ( ) ;
MS_LOG ( DEBUG ) < < " Analysis FuncGraph begin, func graph: " < < fg . get ( ) < < fg - > ToString ( )
MS_LOG ( DEBUG ) < < " Analysis FuncGraph begin, func graph: " < < fg < < " / " < < fg - > ToString ( )
< < " , context: " < < graph_context_ - > ToString ( ) < < " , return node: " < < func_node - > DebugString ( )
< < " , current function call depth: " < < engine - > function_call_depth ( ) ;
AbstractBasePtr ret_base = nullptr ;
@ -97,37 +97,20 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
< < MsContext : : GetInstance ( ) - > get_param < uint32_t > ( MS_CTX_MAX_CALL_DEPTH )
< < " , please call 'context.set_context(max_call_depth=value)' to adjust this value. " ;
}
// Analysis for isolate nodes first, as some validation check in FuncGraph is isolate nodes;
for ( const auto & node : fg - > GetIsolateNodesInOrder ( ) ) {
AnfNodeConfigPtr node_conf = engine - > MakeConfig ( node , graph_context_ ) ;
MS_LOG ( DEBUG ) < < " Analysis isolate_node begin, func graph: " < < fg . get ( ) < < fg - > ToString ( )
< < " , node_conf: " < < node_conf - > ToString ( ) ;
auto isolate_base = engine - > GetEvaluatedValue ( node_conf ) - > abstract ( ) ;
MS_LOG ( DEBUG ) < < " Analysis isolate_node end, func graph: " < < fg . get ( ) < < fg - > ToString ( )
< < " , node_conf: " < < node_conf - > ToString ( ) < < " , abstract: " < < isolate_base - > ToString ( ) ;
}
const auto & all_nodes = TopoSort ( func_node , SuccIncoming , [ & fg ] ( const AnfNodePtr & node ) - > IncludeType {
if ( node - > func_graph ( ) ! = fg | | node - > isa < ValueNode > ( ) ) {
return EXCLUDE ;
}
return FOLLOW ;
} ) ;
bool isolate_node_propagate_flag = false ;
for ( const auto & node : all_nodes ) {
AnfNodeConfigPtr node_conf = engine - > MakeConfig ( node , graph_context_ ) ;
MS_LOG ( DEBUG ) < < " Analysis node begin, func graph: " < < fg . get ( ) < < fg - > ToString ( )
MS_LOG ( DEBUG ) < < " Analysis node begin, func graph: " < < fg < < " / " < < fg - > ToString ( )
< < " , node_conf: " < < node_conf - > ToString ( ) ;
auto node_eval_result = engine - > GetEvaluatedValu e( node_conf ) ;
auto node_eval_result = engine - > ObtainEvalResultWithCach e( node_conf ) ;
ret_base = node_eval_result - > abstract ( ) ;
MS_LOG ( DEBUG ) < < " Analysis node end, func graph: " < < fg . get ( ) < < fg - > ToString ( )
MS_LOG ( DEBUG ) < < " Analysis node end, func graph: " < < fg < < " / " < < fg - > ToString ( )
< < " , node_conf: " < < node_conf - > ToString ( ) < < " , abstract: " < < ret_base - > ToString ( ) ;
if ( node - > isa < CNode > ( ) ) {
isolate_node_propagate_flag | = node_eval_result - > HasIsolateNodesPropagateCNodeFlag ( ) ;
MS_LOG ( DEBUG ) < < " Check isolate_nodes flag for node: " < < node - > DebugString ( )
< < " , abstract: " < < ret_base - > ToString ( )
< < " , flag: " < < node_eval_result - > HasIsolateNodesPropagateCNodeFlag ( ) ;
}
}
engine - > DecreaseFunctionCallDepth ( ) ;
@ -138,12 +121,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
if ( fg - > stub ( ) ) {
ret_base = std : : make_shared < AbstractUndetermined > ( ) ;
}
auto eval_result = std : : make_shared < EvalResult > ( ret_base , std : : make_shared < AttrValueMap > ( ) ) ;
if ( isolate_node_propagate_flag ) {
eval_result - > SetIsolateNodesPropagateCNodeFlag ( true ) ;
eval_result - > SetIsolateNodesPropagateFuncGraphFlag ( true ) ;
}
return eval_result ;
return std : : make_shared < EvalResult > ( ret_base , nullptr ) ;
}
AbstractBasePtrList FuncGraphEvaluator : : NormalizeArgs ( const AbstractBasePtrList & args_spec_list ) const {
@ -280,15 +258,15 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
return conf - > ObtainEvalResult ( ) - > abstract ( ) ;
} ) ;
args_spec_list = NormalizeArgs ( args_spec_list ) ;
args_spec_list = BroadenUndeterminedArgs ( args_spec_list ) ;
trace : : TraceGraphEvalEnter ( shared_from_base < Evaluator > ( ) , out_conf ) ;
MS_LOG ( DEBUG ) < < EvalEntryLogging ( shared_from_base < Evaluator > ( ) , args_spec_list , out_conf ) ;
MS_EXCEPTION_IF_NULL ( cache_) ;
auto iter = cache_- > find ( args_spec_list ) ;
if ( iter = = cache_- > end ( ) ) {
MS_EXCEPTION_IF_NULL ( evaluator_ cache_map _) ;
auto iter = evaluator_ cache_map _- > find ( args_spec_list ) ;
if ( iter = = evaluator_ cache_map _- > end ( ) ) {
MS_LOG ( DEBUG ) < < evaluator_name < < " cache miss, call Eval(). " ;
EvalResultPtr ret = Eval ( engine , args_spec_list ) ;
if ( ret - > abstract ( ) = = nullptr ) {
@ -296,7 +274,7 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
MS_LOG ( EXCEPTION ) < < " Evaluator " < < evaluator_name < < " result is nullptr. " ;
}
MS_LOG ( DEBUG ) < < evaluator_name < < " set cache. return: " < < ret - > abstract ( ) - > ToString ( ) < < " . " ;
( * cache_) [ args_spec_list ] = ret ;
( * evaluator_ cache_map _) [ args_spec_list ] = ret ;
trace : : TraceGraphEvalLeave ( shared_from_base < Evaluator > ( ) ) ;
return ret ;
} else {
@ -315,7 +293,7 @@ EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ is_py_eval ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
auto abstract = conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
auto abstract = conf - > ObtainEvalResult ( ) - > abstract ( ) ;
// broaden the ref_key, while infer python prim for cache
if ( is_py_eval & & abstract - > isa < AbstractRef > ( ) ) {
auto abs_ref = abstract - > cast < AbstractRefPtr > ( ) ;
@ -333,7 +311,7 @@ EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const Confi
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
return conf - > ObtainEvalResult ( ) - > abstract ( ) ;
} ) ;
if ( args_conf_list . size ( ) = = 0 ) {
MS_LOG ( EXCEPTION ) < < " Size should greater than 0 " ;
@ -354,12 +332,12 @@ EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrLis
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
return conf - > ObtainEvalResult ( ) - > abstract ( ) ;
} ) ;
EvalResultPtr ret = sub_evaluator_ - > Run ( engine , args_conf_list , out_conf ) ;
// Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive.
( * cache_) [ args_spec_list ] = ret ;
( * evaluator_ cache_map _) [ args_spec_list ] = ret ;
return ret ;
}
@ -369,11 +347,11 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
return conf - > ObtainEvalResult ( ) - > abstract ( ) ;
} ) ;
MS_EXCEPTION_IF_NULL ( cache_) ;
auto iter = cache_- > find ( args_spec_list ) ;
if ( iter ! = cache_- > end ( ) ) {
MS_EXCEPTION_IF_NULL ( evaluator_ cache_map _) ;
auto iter = evaluator_ cache_map _- > find ( args_spec_list ) ;
if ( iter ! = evaluator_ cache_map _- > end ( ) ) {
return iter - > second ;
}
@ -386,7 +364,7 @@ EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtr
[ ] ( const AbstractBasePtr & arg ) - > ConfigPtr { return std : : make_shared < VirtualConfig > ( arg ) ; } ) ;
EvalResultPtr ret = evaluator_ - > Run ( engine , partial_args_conf_list , out_conf ) ;
( * cache_) [ args_spec_list ] = ret ;
( * evaluator_ cache_map _) [ args_spec_list ] = ret ;
return ret ;
}
@ -395,11 +373,11 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
return conf - > ObtainEvalResult ( ) - > abstract ( ) ;
} ) ;
MS_EXCEPTION_IF_NULL ( cache_) ;
auto iter = cache_- > find ( args_spec_list ) ;
if ( iter ! = cache_- > end ( ) ) {
MS_EXCEPTION_IF_NULL ( evaluator_ cache_map _) ;
auto iter = evaluator_ cache_map _- > find ( args_spec_list ) ;
if ( iter ! = evaluator_ cache_map _- > end ( ) ) {
return iter - > second ;
}
@ -427,7 +405,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
AbstractBasePtrList jargs = { result - > abstract ( ) , bprop } ;
AbstractBasePtr jtuple = std : : make_shared < AbstractTuple > ( jargs ) ;
auto infer_reuslt = std : : make_shared < EvalResult > ( jtuple , std : : make_shared < AttrValueMap > ( ) ) ;
( * cache_) [ args_spec_list ] = infer_reuslt ;
( * evaluator_ cache_map _) [ args_spec_list ] = infer_reuslt ;
return infer_reuslt ;
}