Fix bug for 'save mutiple method' (#30218)

* Fix bug for 'save mutiple method'

* To pass coverage.

* edit code to pass coverage.

* edit code to pass coverage.

* add unittest for coverage.

* change for coverage.

* edit for coverage.
revert-31562-mean
WeiXin 5 years ago committed by GitHub
parent 66dc4ac77b
commit edafb5465a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -600,9 +600,13 @@ def _construct_program_holders(model_path, model_filename=None):
model_file_path = os.path.join(model_path, model_filename)
elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
model_name):
func_name = filename[len(model_name) + 1:-len(
INFER_MODEL_SUFFIX)]
model_file_path = os.path.join(model_path, filename)
parsing_names = filename[len(model_name):-len(
INFER_MODEL_SUFFIX) + 1].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1]
model_file_path = os.path.join(model_path, filename)
else:
continue
else:
continue
program_holder_dict[func_name] = _ProgramHolder(
@ -636,10 +640,14 @@ def _construct_params_and_buffers(model_path,
model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
#Load every file that meets the requirements in the directory model_path.
for file_name in os.listdir(model_path):
if file_name.endswith(INFER_PARAMS_SUFFIX) and file_name.startswith(
model_name) and file_name != params_filename:
func_name = file_name[len(model_name) + 1:-len(
INFER_PARAMS_SUFFIX)]
if file_name.startswith(model_name) and file_name.endswith(
INFER_PARAMS_SUFFIX):
parsing_names = file_name[len(model_name):-len(
INFER_PARAMS_SUFFIX) + 1].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1]
else:
continue
else:
continue
var_info_path = os.path.join(model_path, var_info_filename)

@ -864,6 +864,18 @@ class TestJitSaveLoadMultiMethods(unittest.TestCase):
paddle.jit.save(
layer, model_path, input_spec=[InputSpec(shape=[None, 784])])
def test_parse_name(self):
model_path_inference = "jit_save_load_parse_name/model"
IMAGE_SIZE = 224
layer = LinearNet(IMAGE_SIZE, 1)
inps = paddle.randn([1, IMAGE_SIZE])
layer(inps)
paddle.jit.save(layer, model_path_inference)
paddle.jit.save(layer, model_path_inference + '_v2')
load_net = paddle.jit.load(model_path_inference)
self.assertFalse(hasattr(load_net, 'v2'))
class LayerSaved(paddle.nn.Layer):
def __init__(self, in_size, out_size):

Loading…
Cancel
Save