fix jit.save input_spec type change problem (#25683)

* fix jit.save input type change error

* add unittes
fix_copy_if_different
Chen Weihang 5 years ago committed by GitHub
parent 364cc53618
commit e8caffbb4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -653,8 +653,9 @@ def save(layer, model_path, input_spec=None, configs=None):
"""
def get_inout_spec(all_vars, target_vars, return_name=False):
valid_vars = [var for var in all_vars if isinstance(var, Variable)]
result_list = []
valid_var_dict = {}
valid_vars = [var for var in all_vars if isinstance(var, Variable)]
for var in valid_vars:
valid_var_dict[var.name] = var
if target_vars:
@ -663,13 +664,13 @@ def save(layer, model_path, input_spec=None, configs=None):
if var.name not in valid_var_dict:
raise RuntimeError(
"The variable to feed/fetch are not exist.")
target_vars[i] = valid_var_dict[var.name]
result_list.append(valid_var_dict[var.name])
else:
target_vars = valid_vars
result_list = valid_vars
if return_name:
target_vars = [var.name for var in target_vars]
result_list = [var.name for var in target_vars]
return target_vars
return result_list
# 1. input check
prog_translator = ProgramTranslator()

@ -114,8 +114,11 @@ class TestJitSaveLoad(unittest.TestCase):
def train_and_save_model(self):
layer = LinearNet(784, 1)
example_inputs, layer, _ = train(layer)
orig_input_types = [type(x) for x in example_inputs]
fluid.dygraph.jit.save(
layer=layer, model_path=self.model_path, input_spec=example_inputs)
new_input_types = [type(x) for x in example_inputs]
self.assertEqual(orig_input_types, new_input_types)
return layer
def test_save(self):

Loading…
Cancel
Save