Format file path (#17280)

The parameter dirpath will be passed directly to c++ operater. The file address format will be different under win and UNIX.
revert-17304-fix_default_paddle_version
lujun 6 years ago committed by GitHub
parent 5d6a1fcf16
commit a88a1faa48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,8 +17,6 @@ from __future__ import print_function
import os import os
import errno import errno
import warnings import warnings
import time
import shutil
import six import six
import logging import logging
from functools import reduce from functools import reduce
@ -168,6 +166,7 @@ def save_vars(executor,
# var_a, var_b and var_c will be saved. And they are going to be # var_a, var_b and var_c will be saved. And they are going to be
# saved in the same file named 'var_file' in the path "./my_paddle_model". # saved in the same file named 'var_file' in the path "./my_paddle_model".
""" """
save_dirname = os.path.normpath(dirname)
if vars is None: if vars is None:
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
@ -177,7 +176,7 @@ def save_vars(executor,
save_vars( save_vars(
executor, executor,
main_program=main_program, main_program=main_program,
dirname=dirname, dirname=save_dirname,
vars=list(filter(predicate, main_program.list_vars())), vars=list(filter(predicate, main_program.list_vars())),
filename=filename) filename=filename)
else: else:
@ -200,7 +199,9 @@ def save_vars(executor,
type='save', type='save',
inputs={'X': [new_var]}, inputs={'X': [new_var]},
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={
'file_path': os.path.join(save_dirname, new_var.name)
})
else: else:
save_var_map[new_var.name] = new_var save_var_map[new_var.name] = new_var
@ -213,7 +214,7 @@ def save_vars(executor,
type='save_combine', type='save_combine',
inputs={'X': save_var_list}, inputs={'X': save_var_list},
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(save_dirname, filename)})
executor.run(save_program) executor.run(save_program)
@ -567,6 +568,7 @@ def load_vars(executor,
# var_a, var_b and var_c will be loaded. And they are supposed to haven # var_a, var_b and var_c will be loaded. And they are supposed to haven
# been saved in the same file named 'var_file' in the path "./my_paddle_model". # been saved in the same file named 'var_file' in the path "./my_paddle_model".
""" """
load_dirname = os.path.normpath(dirname)
if vars is None: if vars is None:
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
@ -575,7 +577,7 @@ def load_vars(executor,
load_vars( load_vars(
executor, executor,
dirname=dirname, dirname=load_dirname,
main_program=main_program, main_program=main_program,
vars=list(filter(predicate, main_program.list_vars())), vars=list(filter(predicate, main_program.list_vars())),
filename=filename) filename=filename)
@ -599,7 +601,9 @@ def load_vars(executor,
type='load', type='load',
inputs={}, inputs={},
outputs={'Out': [new_var]}, outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={
'file_path': os.path.join(load_dirname, new_var.name)
})
else: else:
load_var_map[new_var.name] = new_var load_var_map[new_var.name] = new_var
@ -612,7 +616,7 @@ def load_vars(executor,
type='load_combine', type='load_combine',
inputs={}, inputs={},
outputs={"Out": load_var_list}, outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(load_dirname, filename)})
executor.run(load_prog) executor.run(load_prog)
@ -985,8 +989,10 @@ def save_inference_model(dirname,
target_var_name_list = [var.name for var in target_vars] target_var_name_list = [var.name for var in target_vars]
# when a pserver and a trainer running on the same machine, mkdir may conflict # when a pserver and a trainer running on the same machine, mkdir may conflict
save_dirname = dirname
try: try:
os.makedirs(dirname) save_dirname = os.path.normpath(dirname)
os.makedirs(save_dirname)
except OSError as e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
@ -995,7 +1001,7 @@ def save_inference_model(dirname,
model_basename = os.path.basename(model_filename) model_basename = os.path.basename(model_filename)
else: else:
model_basename = "__model__" model_basename = "__model__"
model_basename = os.path.join(dirname, model_basename) model_basename = os.path.join(save_dirname, model_basename)
# When export_for_deployment is true, we modify the program online so that # When export_for_deployment is true, we modify the program online so that
# it can only be loaded for inference directly. If it's false, the whole # it can only be loaded for inference directly. If it's false, the whole
@ -1038,7 +1044,7 @@ def save_inference_model(dirname,
if params_filename is not None: if params_filename is not None:
params_filename = os.path.basename(params_filename) params_filename = os.path.basename(params_filename)
save_persistables(executor, dirname, main_program, params_filename) save_persistables(executor, save_dirname, main_program, params_filename)
return target_var_name_list return target_var_name_list
@ -1102,14 +1108,15 @@ def load_inference_model(dirname,
# program to get the inference result. # program to get the inference result.
""" """
if not os.path.isdir(dirname): load_dirname = os.path.normpath(dirname)
if not os.path.isdir(load_dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
if model_filename is not None: if model_filename is not None:
model_filename = os.path.basename(model_filename) model_filename = os.path.basename(model_filename)
else: else:
model_filename = "__model__" model_filename = "__model__"
model_filename = os.path.join(dirname, model_filename) model_filename = os.path.join(load_dirname, model_filename)
if params_filename is not None: if params_filename is not None:
params_filename = os.path.basename(params_filename) params_filename = os.path.basename(params_filename)
@ -1122,7 +1129,7 @@ def load_inference_model(dirname,
raise ValueError("Unsupported program version: %d\n" % raise ValueError("Unsupported program version: %d\n" %
program._version()) program._version())
# Binary data also need versioning. # Binary data also need versioning.
load_persistables(executor, dirname, program, params_filename) load_persistables(executor, load_dirname, program, params_filename)
if pserver_endpoints: if pserver_endpoints:
program = _endpoints_replacement(program, pserver_endpoints) program = _endpoints_replacement(program, pserver_endpoints)

Loading…
Cancel
Save