|
|
|
@ -135,6 +135,11 @@ class FunctionSpec(object):
|
|
|
|
|
|
|
|
|
|
input_with_spec = pack_sequence_as(args, input_with_spec)
|
|
|
|
|
|
|
|
|
|
# If without specificing name in input_spec, add default name
|
|
|
|
|
# according to argument name from decorated function.
|
|
|
|
|
input_with_spec = replace_spec_empty_name(self._arg_names,
|
|
|
|
|
input_with_spec)
|
|
|
|
|
|
|
|
|
|
return input_with_spec
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
@ -309,3 +314,61 @@ def convert_to_input_spec(inputs, input_spec):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
|
|
|
|
|
type_name(input_spec))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_spec_empty_name(args_name, input_with_spec):
|
|
|
|
|
"""
|
|
|
|
|
Adds default name according to argument name from decorated function
|
|
|
|
|
if without specificing InputSpec.name
|
|
|
|
|
|
|
|
|
|
The naming rule are as followed:
|
|
|
|
|
1. If InputSpec.name is not None, do nothing.
|
|
|
|
|
2. If each argument `x` corresponds to an InputSpec, using the argument name like `x`
|
|
|
|
|
3. If the arguments `inputs` corresponds to a list(InputSpec), using name like `inputs_0`, `inputs_1`
|
|
|
|
|
4. If the arguments `input_dic` corresponds to a dict(InputSpec), using key as name.
|
|
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
|
|
|
|
|
# case 1: foo(x, y)
|
|
|
|
|
foo = to_static(foo, input_spec=[InputSpec([None, 10]), InputSpec([None])])
|
|
|
|
|
print([in_var.name for in_var in foo.inputs]) # [x, y]
|
|
|
|
|
|
|
|
|
|
# case 2: foo(inputs) where inputs is a list
|
|
|
|
|
foo = to_static(foo, input_spec=[[InputSpec([None, 10]), InputSpec([None])]])
|
|
|
|
|
print([in_var.name for in_var in foo.inputs]) # [inputs_0, inputs_1]
|
|
|
|
|
|
|
|
|
|
# case 3: foo(inputs) where inputs is a dict
|
|
|
|
|
foo = to_static(foo, input_spec=[{'x': InputSpec([None, 10]), 'y': InputSpec([None])}])
|
|
|
|
|
print([in_var.name for in_var in foo.inputs]) # [x, y]
|
|
|
|
|
"""
|
|
|
|
|
input_with_spec = list(input_with_spec)
|
|
|
|
|
candidate_arg_names = args_name[:len(input_with_spec)]
|
|
|
|
|
|
|
|
|
|
for i, arg_name in enumerate(candidate_arg_names):
|
|
|
|
|
input_spec = input_with_spec[i]
|
|
|
|
|
input_with_spec[i] = _replace_spec_name(arg_name, input_spec)
|
|
|
|
|
|
|
|
|
|
return input_with_spec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_spec_name(name, input_spec):
|
|
|
|
|
"""
|
|
|
|
|
Replaces InputSpec.name with given `name` while not specificing it.
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(input_spec, paddle.static.InputSpec):
|
|
|
|
|
if input_spec.name is None:
|
|
|
|
|
input_spec.name = name
|
|
|
|
|
return input_spec
|
|
|
|
|
elif isinstance(input_spec, (list, tuple)):
|
|
|
|
|
processed_specs = []
|
|
|
|
|
for i, spec in enumerate(input_spec):
|
|
|
|
|
new_name = "{}_{}".format(name, i)
|
|
|
|
|
processed_specs.append(_replace_spec_name(new_name, spec))
|
|
|
|
|
return processed_specs
|
|
|
|
|
elif isinstance(input_spec, dict):
|
|
|
|
|
processed_specs = {}
|
|
|
|
|
for key, spec in six.iteritems(input_spec):
|
|
|
|
|
processed_specs[key] = _replace_spec_name(key, spec)
|
|
|
|
|
return processed_specs
|
|
|
|
|
else:
|
|
|
|
|
return input_spec
|
|
|
|
|