@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
return sorted_nodes ;
return sorted_nodes ;
}
}
AbstractBase Ptr BaseFuncGraphEvaluator : : Eval ( AnalysisEnginePtr engine , const AbstractBasePtrList & args_spec_list ) {
EvalResult Ptr BaseFuncGraphEvaluator : : Eval ( AnalysisEnginePtr engine , const AbstractBasePtrList & args_spec_list ) {
FuncGraphPtr fg = GetFuncGraph ( engine , args_spec_list ) ;
FuncGraphPtr fg = GetFuncGraph ( engine , args_spec_list ) ;
MS_EXCEPTION_IF_NULL ( fg ) ;
MS_EXCEPTION_IF_NULL ( fg ) ;
std : : size_t nargs = fg - > parameters ( ) . size ( ) ;
std : : size_t nargs = fg - > parameters ( ) . size ( ) ;
@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const auto & arg = args_spec_list [ i ] ;
const auto & arg = args_spec_list [ i ] ;
const auto & node = parameters [ i ] ;
const auto & node = parameters [ i ] ;
AnfNodeConfigPtr conf = engine - > MakeConfig ( node , graph_context_ ) ;
AnfNodeConfigPtr conf = engine - > MakeConfig ( node , graph_context_ ) ;
engine - > cache ( ) . set_value ( conf , arg) ;
engine - > cache ( ) . set_value ( conf , std: : make_shared < EvalResult > ( arg, nullptr ) ) ;
}
}
const AnfNodePtr & func_node = fg - > get_return ( ) ;
const AnfNodePtr & func_node = fg - > get_return ( ) ;
@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const auto & node = * it ;
const auto & node = * it ;
AnfNodeConfigPtr node_conf = engine - > MakeConfig ( node , graph_context_ ) ;
AnfNodeConfigPtr node_conf = engine - > MakeConfig ( node , graph_context_ ) ;
MS_LOG ( DEBUG ) < < " Analysis node begin, func graph: " < < fg - > ToString ( ) < < " , node_conf: " < < node_conf - > ToString ( ) ;
MS_LOG ( DEBUG ) < < " Analysis node begin, func graph: " < < fg - > ToString ( ) < < " , node_conf: " < < node_conf - > ToString ( ) ;
ret_base = engine - > GetEvaluatedValue ( node_conf ) ;
ret_base = engine - > GetEvaluatedValue ( node_conf ) - > abstract ( ) ;
MS_LOG ( DEBUG ) < < " Analysis node end, func graph: " < < fg - > ToString ( ) < < " , node_conf: " < < node_conf - > ToString ( )
MS_LOG ( DEBUG ) < < " Analysis node end, func graph: " < < fg - > ToString ( ) < < " , node_conf: " < < node_conf - > ToString ( )
< < " , abstract: " < < ret_base - > ToString ( ) ;
< < " , abstract: " < < ret_base - > ToString ( ) ;
}
}
MS_EXCEPTION_IF_NULL ( ret_base ) ;
MS_EXCEPTION_IF_NULL ( ret_base ) ;
MS_LOG ( DEBUG ) < < " BaseFuncGraph " < < fg - > ToString ( ) < < " E val end, evaluated abstract: " < < ret_base - > ToString ( ) ;
MS_LOG ( DEBUG ) < < " BaseFuncGraph " < < fg - > ToString ( ) < < " e val end, evaluated abstract: " < < ret_base - > ToString ( ) ;
return ret_base;
return std: : make_shared < EvalResult > ( ret_base, nullptr ) ;
}
}
AbstractBasePtrList FuncGraphEvaluator : : NormalizeArgs ( const AbstractBasePtrList & args_spec_list ) const {
AbstractBasePtrList FuncGraphEvaluator : : NormalizeArgs ( const AbstractBasePtrList & args_spec_list ) const {
@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
return cloned_func_graph ;
return cloned_func_graph ;
}
}
AbstractBasePtr Evaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
EvalResultPtr Evaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list , AnfNodeConfigPtr out_conf ) {
AnfNodeConfigPtr out_conf ) {
const std : : string & evaluator_name = ToString ( ) ;
const std : : string & evaluator_name = ToString ( ) ;
AbstractBasePtrList args_spec_list ;
AbstractBasePtrList args_spec_list ;
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
} ) ;
args_spec_list = NormalizeArgs ( args_spec_list ) ;
args_spec_list = NormalizeArgs ( args_spec_list ) ;
args_spec_list = BroadenUndeterminedArgs ( args_spec_list ) ;
args_spec_list = BroadenUndeterminedArgs ( args_spec_list ) ;
@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
auto iter = cache_ - > find ( args_spec_list ) ;
auto iter = cache_ - > find ( args_spec_list ) ;
if ( iter = = cache_ - > end ( ) ) {
if ( iter = = cache_ - > end ( ) ) {
MS_LOG ( DEBUG ) < < evaluator_name < < " cache miss, call Eval(). " ;
MS_LOG ( DEBUG ) < < evaluator_name < < " cache miss, call Eval(). " ;
AbstractBase Ptr ret = Eval ( engine , args_spec_list ) ;
EvalResult Ptr ret = Eval ( engine , args_spec_list ) ;
if ( ret = = nullptr ) {
if ( ret - > abstract ( ) = = nullptr ) {
EvalFailLogging ( shared_from_base < Evaluator > ( ) , args_spec_list , out_conf ) ;
EvalFailLogging ( shared_from_base < Evaluator > ( ) , args_spec_list , out_conf ) ;
MS_LOG ( EXCEPTION ) < < " Evaluator " < < evaluator_name < < " result is nullptr. " ;
MS_LOG ( EXCEPTION ) < < " Evaluator " < < evaluator_name < < " result is nullptr. " ;
}
}
MS_EXCEPTION_IF_NULL ( ret ) ;
MS_LOG ( DEBUG ) < < evaluator_name < < " set cache. return: " < < ret - > abstract ( ) - > ToString ( ) < < " . " ;
MS_LOG ( DEBUG ) < < evaluator_name < < " set cache. return: " < < ret - > ToString ( ) < < " . " ;
( * cache_ ) [ args_spec_list ] = ret ;
( * cache_ ) [ args_spec_list ] = ret ;
trace : : TraceGraphEvalLeave ( shared_from_base < Evaluator > ( ) ) ;
trace : : TraceGraphEvalLeave ( shared_from_base < Evaluator > ( ) ) ;
return ret ;
return ret ;
} else {
} else {
MS_EXCEPTION_IF_NULL ( iter - > second ) ;
MS_EXCEPTION_IF_NULL ( iter - > second ) ;
MS_LOG ( DEBUG ) < < evaluator_name < < " cache hit. return: " < < iter - > second - > ToString ( ) < < " . " ;
MS_EXCEPTION_IF_NULL ( iter - > second - > abstract ( ) ) ;
MS_LOG ( DEBUG ) < < evaluator_name < < " cache hit. return: " < < iter - > second - > abstract ( ) - > ToString ( ) < < " . " ;
trace : : TraceGraphEvalLeave ( shared_from_base < Evaluator > ( ) ) ;
trace : : TraceGraphEvalLeave ( shared_from_base < Evaluator > ( ) ) ;
return iter - > second ;
return iter - > second ;
}
}
}
}
AbstractBase Ptr TrivialPrimEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
EvalResult Ptr TrivialPrimEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
AnfNodeConfigPtr ) {
AnfNodeConfigPtr ) {
AbstractBasePtrList args_spec_list ;
AbstractBasePtrList args_spec_list ;
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
} ) ;
AbstractBase Ptr ret = EvalPrim ( engine , args_spec_list ) ;
EvalResult Ptr ret = EvalPrim ( engine , args_spec_list ) ;
return ret ;
return ret ;
}
}
AbstractBase Ptr TransitionPrimEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
EvalResult Ptr TransitionPrimEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
AnfNodeConfigPtr out_conf ) {
AnfNodeConfigPtr out_conf ) {
AbstractBasePtrList args_spec_list ;
AbstractBasePtrList args_spec_list ;
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
} ) ;
if ( args_conf_list . size ( ) = = 0 ) {
if ( args_conf_list . size ( ) = = 0 ) {
MS_LOG ( EXCEPTION ) < < " Size should greater than 0 " ;
MS_LOG ( EXCEPTION ) < < " Size should greater than 0 " ;
}
}
AbstractBase Ptr ret = EvalPrim ( engine , args_spec_list , args_conf_list [ 0 ] , out_conf ) ;
EvalResult Ptr ret = EvalPrim ( engine , args_spec_list , args_conf_list [ 0 ] , out_conf ) ;
// No need to cache.
// No need to cache.
return ret ;
return ret ;
}
}
AbstractBase Ptr SymbolicPrimEvaluator : : Run ( AnalysisEnginePtr , const ConfigPtrList & args_conf_list , AnfNodeConfigPtr ) {
EvalResult Ptr SymbolicPrimEvaluator : : Run ( AnalysisEnginePtr , const ConfigPtrList & args_conf_list , AnfNodeConfigPtr ) {
AbstractBase Ptr ret = EvalPrim ( args_conf_list ) ;
EvalResult Ptr ret = EvalPrim ( args_conf_list ) ;
return ret ;
return ret ;
}
}
AbstractBase Ptr TrackedEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
EvalResult Ptr TrackedEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
AnfNodeConfigPtr out_conf ) {
AnfNodeConfigPtr out_conf ) {
AbstractBasePtrList args_spec_list ;
AbstractBasePtrList args_spec_list ;
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
} ) ;
AbstractBase Ptr ret = sub_evaluator_ - > Run ( engine , args_conf_list , out_conf ) ;
EvalResult Ptr ret = sub_evaluator_ - > Run ( engine , args_conf_list , out_conf ) ;
// Don't lookup from cache, as different out_conf with same node but different context
// Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive.
// may add different entry to anfnode_config_map_, like getattr primitive.
( * cache_ ) [ args_spec_list ] = ret ;
( * cache_ ) [ args_spec_list ] = ret ;
return ret ;
return ret ;
}
}
AbstractBase Ptr PartialAppEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
EvalResult Ptr PartialAppEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list ,
AnfNodeConfigPtr out_conf ) {
AnfNodeConfigPtr out_conf ) {
AbstractBasePtrList args_spec_list ;
AbstractBasePtrList args_spec_list ;
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
} ) ;
MS_EXCEPTION_IF_NULL ( cache_ ) ;
MS_EXCEPTION_IF_NULL ( cache_ ) ;
auto iter = cache_ - > find ( args_spec_list ) ;
auto iter = cache_ - > find ( args_spec_list ) ;
@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP
( void ) std : : transform ( args_spec_list . begin ( ) , args_spec_list . end ( ) , std : : back_inserter ( partial_args_conf_list ) ,
( void ) std : : transform ( args_spec_list . begin ( ) , args_spec_list . end ( ) , std : : back_inserter ( partial_args_conf_list ) ,
[ ] ( const AbstractBasePtr & arg ) - > ConfigPtr { return std : : make_shared < VirtualConfig > ( arg ) ; } ) ;
[ ] ( const AbstractBasePtr & arg ) - > ConfigPtr { return std : : make_shared < VirtualConfig > ( arg ) ; } ) ;
AbstractBasePtr ret = evaluator_ - > Run ( engine , partial_args_conf_list , out_conf ) ;
EvalResultPtr ret = evaluator_ - > Run ( engine , partial_args_conf_list , out_conf ) ;
( * cache_ ) [ args_spec_list ] = ret ;
( * cache_ ) [ args_spec_list ] = ret ;
return ret ;
return ret ;
}
}
AbstractBase Ptr JEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list , AnfNodeConfigPtr ) {
EvalResult Ptr JEvaluator : : Run ( AnalysisEnginePtr engine , const ConfigPtrList & args_conf_list , AnfNodeConfigPtr ) {
AbstractBasePtrList args_spec_list ;
AbstractBasePtrList args_spec_list ;
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
( void ) std : : transform ( args_conf_list . begin ( ) , args_conf_list . end ( ) , std : : back_inserter ( args_spec_list ) ,
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
[ ] ( const ConfigPtr & conf ) - > AbstractBasePtr {
MS_EXCEPTION_IF_NULL ( conf ) ;
MS_EXCEPTION_IF_NULL ( conf ) ;
return conf - > GetEvaluatedValue ( ) ;
return conf - > GetEvaluatedValue ( ) - > abstract ( ) ;
} ) ;
} ) ;
MS_EXCEPTION_IF_NULL ( cache_ ) ;
MS_EXCEPTION_IF_NULL ( cache_ ) ;
auto iter = cache_ - > find ( args_spec_list ) ;
auto iter = cache_ - > find ( args_spec_list ) ;
@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
}
}
// Call the original evaluator, get the result: y = f(x)
// Call the original evaluator, get the result: y = f(x)
AbstractBase Ptr result = evaluator_ - > Run ( engine , args_conf_list , nullptr ) ;
EvalResult Ptr result = evaluator_ - > Run ( engine , args_conf_list , nullptr ) ;
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList bparams ;
AbstractBasePtrList bparams ;
@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
args_spec_list . begin ( ) , args_spec_list . end ( ) , std : : back_inserter ( bparams ) ,
args_spec_list . begin ( ) , args_spec_list . end ( ) , std : : back_inserter ( bparams ) ,
[ ] ( const AbstractBasePtr & arg_spec ) - > AbstractBasePtr { return SensitivityTransform ( arg_spec ) ; } ) ;
[ ] ( const AbstractBasePtr & arg_spec ) - > AbstractBasePtr { return SensitivityTransform ( arg_spec ) ; } ) ;
AbstractBasePtr bparams_final = std : : make_shared < AbstractTuple > ( bparams ) ;
AbstractBasePtr bparams_final = std : : make_shared < AbstractTuple > ( bparams ) ;
AbstractFunctionPtr bprop = std : : make_shared < VirtualAbstractClosure > ( SensitivityTransform ( result ) , bparams_final ) ;
AbstractFunctionPtr bprop =
std : : make_shared < VirtualAbstractClosure > ( SensitivityTransform ( result - > abstract ( ) ) , bparams_final ) ;
// J(f)(J(x)) return a tuple (y, bprop_f)
// J(f)(J(x)) return a tuple (y, bprop_f)
AbstractBasePtrList jargs = { result , bprop } ;
AbstractBasePtrList jargs = { result - > abstract ( ) , bprop } ;
AbstractBasePtr jtuple = std : : make_shared < AbstractTuple > ( jargs ) ;
AbstractBasePtr jtuple = std : : make_shared < AbstractTuple > ( jargs ) ;
( * cache_ ) [ args_spec_list ] = jtuple ;
auto infer_reuslt = std : : make_shared < EvalResult > ( jtuple , std : : make_shared < AttrValueMap > ( ) ) ;
return jtuple ;
( * cache_ ) [ args_spec_list ] = infer_reuslt ;
return infer_reuslt ;
}
}
AbstractBase Ptr VirtualEvaluator : : Eval ( AnalysisEnginePtr , const AbstractBasePtrList & args_spec_list ) {
EvalResult Ptr VirtualEvaluator : : Eval ( AnalysisEnginePtr , const AbstractBasePtrList & args_spec_list ) {
if ( args_spec_list . size ( ) ! = args_spec_list_ . size ( ) ) {
if ( args_spec_list . size ( ) ! = args_spec_list_ . size ( ) ) {
MS_LOG ( EXCEPTION ) < < " Arguments mismatch, parameters no: " < < args_spec_list_ . size ( )
MS_LOG ( EXCEPTION ) < < " Arguments mismatch, parameters no: " < < args_spec_list_ . size ( )
< < " , arguments no: " < < args_spec_list . size ( ) ;
< < " , arguments no: " < < args_spec_list . size ( ) ;
@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL
MS_EXCEPTION_IF_NULL ( args_spec_list [ i ] ) ;
MS_EXCEPTION_IF_NULL ( args_spec_list [ i ] ) ;
( void ) args_spec_list [ i ] - > Join ( args_spec_list_ [ i ] ) ;
( void ) args_spec_list [ i ] - > Join ( args_spec_list_ [ i ] ) ;
}
}
return output_;
return std: : make_shared < EvalResult > ( output_, std : : make_shared < AttrValueMap > ( ) ) ;
}
}
} // namespace abstract
} // namespace abstract
} // namespace mindspore
} // namespace mindspore