@ -24,7 +24,7 @@ from paddle.fluid import core, executor
from paddle . fluid . data import data
from paddle . fluid . dygraph . dygraph_to_static import convert_to_static
__all__ = [ ' AutoTrace r' ]
__all__ = [ ' ProgramTranslato r' ]
class FunctionCache ( object ) :
@ -33,36 +33,32 @@ class FunctionCache(object):
"""
def __init__ ( self ) :
self . _ cache _funcs = dict ( )
self . _ func_to_transformer = dict ( )
self . _ dy code_to_st ati c_func = dict ( )
self . _ static_ func_to_transformer = dict ( )
def __call__ ( self , func ) :
static_func = self . _get_or_cache_func ( func )
return static_func
def _get_or_cache_func ( self , func ) :
cache_key = self . hash_key ( func )
static_func = self . _cache_funcs . get ( cache_key , None )
def get_or_cache_func ( self , func ) :
code = self . _get_dedent_code_string ( func )
static_func = self . _dycode_to_static_func . get ( code , None )
if static_func is None :
static_func , dygraph_to_static = convert_to_static ( func )
self . _cache_funcs [ cache_key ] = static_func
self . _func_to_transformer [ static_func ] = dygraph_to_static
static_func , dygraph_to_static_transformer = convert_to_static ( func )
self . _dycode_to_static_func [ code ] = static_func
self . _static_func_to_transformer [
static_func ] = dygraph_to_static_transformer
return static_func
def transformer( self , func ) :
return self . _ func_to_transformer. get ( func , None )
def get_ transformer( self , func ) :
return self . _ static_ func_to_transformer. get ( func , None )
def hash_key ( self , func ) :
def _get_dedent_code_string ( self , func ) :
raw_code = inspect . getsource ( func )
code = textwrap . dedent ( raw_code )
return hash ( code )
dedent_code = textwrap . dedent ( raw_code )
return dedent_code
def exist ( self , func ) :
return self . _cache_funcs . get ( self . hash_key ( func ) , None ) is not None
return self . _dycode_to_static_func . get (
self . _get_dedent_code_string ( func ) , None ) is not None
def synchronized ( func ) :
@ -97,9 +93,10 @@ class ProgramCache(object):
# sub class in `forward()`.
self . _in_build_process = True
def __call__ ( self , dyfunc , * args , * * kwargs ) :
def build_program_and_return_output ( self , dyfunc , * args , * * kwargs ) :
"""
Executes the main_program with specialized inputs .
Executes the main_program with specialized inputs so that the program
is built . This method also return outputs of program as fetch_list
"""
# Transfroms dygraph function into static functions and caches them.
static_func = self . _transform_or_cache_layers ( dyfunc )
@ -124,7 +121,7 @@ class ProgramCache(object):
"""
Transforms dygraph function into static function .
"""
static_func = self . _func_cache ( dyfunc )
static_func = self . _func_cache . get_or_cache_func ( dyfunc )
# self._forward_func is entry function of Net or Model.
# It can be called for multiple times, but layers from these functions
# call stack will be added into self._program only once.
@ -181,7 +178,7 @@ class ProgramCache(object):
Returns name and index of input args from ` forward ( args ) `
that need to be replaced with ` fluid . data ` .
"""
transformer = self . _func_cache . transformer( func )
transformer = self . _func_cache . get_ transformer( func )
feed_name_to_idx = transformer . get_feed_name_to_idx ( )
return feed_name_to_idx
@ -206,7 +203,7 @@ class ProgramCache(object):
return self . _in_build_process
class AutoTrace r( object ) :
class ProgramTranslato r( object ) :
_instance = None
@ -214,32 +211,32 @@ class AutoTracer(object):
def __new__ ( cls , * args , * * kwargs ) :
if cls . _instance is None :
cls . _instance = object . __new__ ( cls , * args , * * kwargs )
cls . _instance . _ _ initialized = False
cls . _instance . _ initialized = False
return cls . _instance
@classmethod
def get_instance ( cls ) :
if cls . _instance is None :
raise ValueError ( " Func Program hasn\' t been created! " )
raise ValueError ( " ProgramTranslator hasn\' t been created! " )
return cls . _instance
@classmethod
def reset ( cls ) :
if cls . _instance is not None :
cls . _instance . _ _ initialized = False
cls . _instance . _ initialized = False
cls . _instance . __init__ ( )
def __init__ ( self , exe = None , place = None ) :
# To make sure that calls __init__ only once.
if self . _ _ initialized:
if self . _ initialized:
return
self . _ _ initialized = True
self . _ initialized = True
self . _place = core . CPUPlace ( ) if place is None else place
if exe is None :
self . _exe = executor . Executor ( self . _place )
else :
self . _exe = exe
self . _ cached_ program = ProgramCache ( )
self . _ program_cache = ProgramCache ( )
self . _optimizer = None
self . _already_minimized = False
# Once main_program is changed, should run startup_program.
@ -251,7 +248,7 @@ class AutoTracer(object):
"""
feed_dict , fetch_list = self . _prepare ( args )
main_program = self . _ cached_ program. program
main_program = self . _ program_cache . program
outputs = self . _exe . run ( main_program ,
feed = feed_dict ,
fetch_list = fetch_list )
@ -266,7 +263,7 @@ class AutoTracer(object):
# Updates batch_data for feed_dict
feed_dict = self . _update_batch_data ( args )
fetch_list = self . _ cached_ program. outputs
fetch_list = self . _ program_cache . outputs
# Adds optimizer if needed.
if self . _optimizer and not self . _already_minimized :
@ -284,16 +281,16 @@ class AutoTracer(object):
In some models and unittest , program will be switched frequently by ` program_guard ` .
If does , the cached program and other properties are not available and should be reset .
"""
if self . _ cached_ program. program :
if self . _ cached_ program. program != framework . default_main_program ( ) :
AutoTrace r. reset ( )
if self . _ program_cache . program :
if self . _ program_cache . program != framework . default_main_program ( ) :
ProgramTranslato r. reset ( )
def _update_batch_data ( self , args ) :
"""
Updates cached batch data while training program .
"""
feed_name_to_idx = self . _ cached_ program. feed_name_to_idx
feed_vars = self . _ cached_ program. inputs
feed_name_to_idx = self . _ program_cache . feed_name_to_idx
feed_vars = self . _ program_cache . inputs
feed_dict = { }
for feed_var in feed_vars :
idx = feed_name_to_idx [ feed_var . name ]
@ -318,7 +315,7 @@ class AutoTracer(object):
"""
Supports to set or update the optimizer used to minimize loss .
"""
main_program = self . _ cached_ program. program
main_program = self . _ program_cache . program
all_vars = main_program . block ( 0 ) . vars
loss_var = all_vars . get ( self . _loss_name , None )
@ -333,13 +330,13 @@ class AutoTracer(object):
# Avoids to set optimizer repeatedly.
self . _already_minimized = True
def get_ cached_ program( self ) :
def get_ program_cache ( self ) :
"""
Returns the ProgramCache instance .
"""
self . _check_cache_valid ( )
return self . _ cached_ program
return self . _ program_cache
@property
def program ( self ) :
return self . _ cached_ program. program
return self . _ program_cache . program