|
|
|
@ -146,19 +146,25 @@ class CacheKey(object):
|
|
|
|
|
Cached key for ProgramCache.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
__slots__ = ['function_spec', 'input_with_spec', 'class_instance']
|
|
|
|
|
__slots__ = [
|
|
|
|
|
'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
|
|
|
|
|
'class_instance'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def __init__(self, function_spec, input_with_spec, class_instance):
|
|
|
|
|
def __init__(self, function_spec, input_args_with_spec,
|
|
|
|
|
input_kwargs_with_spec, class_instance):
|
|
|
|
|
"""
|
|
|
|
|
Initializes a cache key.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
functions_spec(FunctionSpec): a FunctionSpec instance of decorated function.
|
|
|
|
|
input_with_spec(list[InputSpec]): actual inputs with some arguments replaced by InputSpec.
|
|
|
|
|
input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec.
|
|
|
|
|
input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec.
|
|
|
|
|
class_instance(object): a instance of class `Layer`.
|
|
|
|
|
"""
|
|
|
|
|
self.function_spec = function_spec
|
|
|
|
|
self.input_with_spec = input_with_spec
|
|
|
|
|
self.input_args_with_spec = input_args_with_spec
|
|
|
|
|
self.input_kwargs_with_spec = input_kwargs_with_spec
|
|
|
|
|
self.class_instance = class_instance
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@ -177,15 +183,18 @@ class CacheKey(object):
|
|
|
|
|
args = args[1:]
|
|
|
|
|
# 2. convert tensor and numpy array into InputSpec
|
|
|
|
|
_args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs)
|
|
|
|
|
input_with_spec = function_spec.args_to_input_spec(_args, _kwargs)
|
|
|
|
|
input_args_with_spec, input_kwargs_with_spec = function_spec.args_to_input_spec(
|
|
|
|
|
_args, _kwargs)
|
|
|
|
|
|
|
|
|
|
# 3. check whether hit the cache or build a new program for the input arguments
|
|
|
|
|
return CacheKey(function_spec, input_with_spec, class_instance)
|
|
|
|
|
return CacheKey(function_spec, input_args_with_spec,
|
|
|
|
|
input_kwargs_with_spec, class_instance)
|
|
|
|
|
|
|
|
|
|
def __hash__(self):
|
|
|
|
|
error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)."
|
|
|
|
|
return hash((id(self.function_spec),
|
|
|
|
|
make_hashable(self.input_with_spec, error_msg),
|
|
|
|
|
make_hashable(self.input_args_with_spec, error_msg),
|
|
|
|
|
make_hashable(self.input_kwargs_with_spec, error_msg),
|
|
|
|
|
self.class_instance))
|
|
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
@ -195,8 +204,9 @@ class CacheKey(object):
|
|
|
|
|
return not self == other
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
return "id(function_spec): {}, input_with_spec: {}, class_instance: {}".format(
|
|
|
|
|
id(self.function_spec), self.input_with_spec, self.class_instance)
|
|
|
|
|
return "id(function_spec): {}, input_args_with_spec: {}, input_kwargs_with_spec: {}, class_instance: {}".format(
|
|
|
|
|
id(self.function_spec), self.input_args_with_spec,
|
|
|
|
|
self.input_kwargs_with_spec, self.class_instance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unwrap_decorators(func):
|
|
|
|
@ -380,11 +390,12 @@ class StaticFunction(object):
|
|
|
|
|
if len(args) != len(self._function_spec.args_name):
|
|
|
|
|
args, kwargs = self._function_spec.unified_args_and_kwargs(args,
|
|
|
|
|
kwargs)
|
|
|
|
|
input_with_spec = self._function_spec.args_to_input_spec(args, kwargs)
|
|
|
|
|
input_args_with_spec, input_kwargs_with_spec = self._function_spec.args_to_input_spec(
|
|
|
|
|
args, kwargs)
|
|
|
|
|
|
|
|
|
|
# 2. generate cache key
|
|
|
|
|
cache_key = CacheKey(self._function_spec, input_with_spec,
|
|
|
|
|
self._class_instance)
|
|
|
|
|
cache_key = CacheKey(self._function_spec, input_args_with_spec,
|
|
|
|
|
input_kwargs_with_spec, self._class_instance)
|
|
|
|
|
|
|
|
|
|
# 3. check whether hit the cache or build a new program for the input arguments
|
|
|
|
|
concrete_program, partial_program_layer = self._program_cache[cache_key]
|
|
|
|
@ -564,7 +575,8 @@ class ConcreteProgram(object):
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
|
def from_func_spec(func_spec, input_spec, class_instance):
|
|
|
|
|
def from_func_spec(func_spec, input_spec, input_kwargs_spec,
|
|
|
|
|
class_instance):
|
|
|
|
|
"""
|
|
|
|
|
Builds the main_program with specialized inputs and returns outputs
|
|
|
|
|
of program as fetch_list.
|
|
|
|
@ -593,6 +605,8 @@ class ConcreteProgram(object):
|
|
|
|
|
# 1. Adds `fluid.data` layers for input if needed
|
|
|
|
|
inputs = func_spec.to_static_inputs_with_spec(input_spec,
|
|
|
|
|
main_program)
|
|
|
|
|
kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec,
|
|
|
|
|
main_program)
|
|
|
|
|
if class_instance:
|
|
|
|
|
inputs = tuple([class_instance] + list(inputs))
|
|
|
|
|
|
|
|
|
@ -605,7 +619,10 @@ class ConcreteProgram(object):
|
|
|
|
|
class_instance, False)), param_guard(
|
|
|
|
|
get_buffers(class_instance, False)):
|
|
|
|
|
try:
|
|
|
|
|
outputs = static_func(*inputs)
|
|
|
|
|
if kwargs:
|
|
|
|
|
outputs = static_func(*inputs, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
outputs = static_func(*inputs)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
|
|
|
|
|
error.attach_error_data(e)
|
|
|
|
@ -653,7 +670,8 @@ class ProgramCache(object):
|
|
|
|
|
def _build_once(self, cache_key):
|
|
|
|
|
concrete_program = ConcreteProgram.from_func_spec(
|
|
|
|
|
func_spec=cache_key.function_spec,
|
|
|
|
|
input_spec=cache_key.input_with_spec,
|
|
|
|
|
input_spec=cache_key.input_args_with_spec,
|
|
|
|
|
input_kwargs_spec=cache_key.input_kwargs_with_spec,
|
|
|
|
|
class_instance=cache_key.class_instance)
|
|
|
|
|
return concrete_program, partial_program_from(concrete_program)
|
|
|
|
|
|
|
|
|
|