|
|
@ -75,7 +75,7 @@ def save_persistables(vardict, dirname, filename=None):
|
|
|
|
_save_var_to_file(vardict, dirname, filename)
|
|
|
|
_save_var_to_file(vardict, dirname, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_persistables(vardict, dirname, filename=None):
|
|
|
|
def load_persistables(dirname):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This function trys to load persistable variables from the folder
|
|
|
|
This function trys to load persistable variables from the folder
|
|
|
|
`dirname` or the file `filename`.
|
|
|
|
`dirname` or the file `filename`.
|
|
|
@ -86,11 +86,7 @@ def load_persistables(vardict, dirname, filename=None):
|
|
|
|
the file name.
|
|
|
|
the file name.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
vardict(dict of Parameters): The parameters will be loaded.
|
|
|
|
|
|
|
|
dirname(str): The directory path.
|
|
|
|
dirname(str): The directory path.
|
|
|
|
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
|
|
|
|
|
|
|
|
saved in differnet files, set it to None.
|
|
|
|
|
|
|
|
Default: None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
dict: The parameter-dict resumed from file
|
|
|
|
dict: The parameter-dict resumed from file
|
|
|
@ -104,10 +100,7 @@ def load_persistables(vardict, dirname, filename=None):
|
|
|
|
param_1 = param_dict['PtbModel_0.w_1']
|
|
|
|
param_1 = param_dict['PtbModel_0.w_1']
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if isinstance(vardict, collections.OrderedDict):
|
|
|
|
return _load_var_from_file(dirname)
|
|
|
|
return _load_var_from_file(vardict, dirname, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_var_to_file(stat_dict, file_dir, file_name):
|
|
|
|
def _save_var_to_file(stat_dict, file_dir, file_name):
|
|
|
@ -139,41 +132,37 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_var_from_file(stat_dict, file_dir, file_name):
|
|
|
|
def _load_var_from_file(file_dir):
|
|
|
|
load_block = default_main_program().global_block()
|
|
|
|
|
|
|
|
load_var_map = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for var_key, each_var in stat_dict.items():
|
|
|
|
def walk_filename(file_dir):
|
|
|
|
assert isinstance(each_var, Variable)
|
|
|
|
var_name_list = []
|
|
|
|
if each_var.type == core.VarDesc.VarType.RAW:
|
|
|
|
if os.path.exists(file_dir) and os.path.exists(os.path.join(file_dir)):
|
|
|
|
continue
|
|
|
|
base_path = os.path.join(file_dir)
|
|
|
|
new_var = _clone_var_in_block_(load_block, each_var)
|
|
|
|
for dirpath, dirnames, filenames in os.walk(os.path.join(file_dir)):
|
|
|
|
if file_name is None:
|
|
|
|
pt = dirpath.replace(base_path, "", 1)[1:]
|
|
|
|
load_block.append_op(
|
|
|
|
for fth_name in filenames:
|
|
|
|
type='load',
|
|
|
|
if fth_name[0] != '.':
|
|
|
|
inputs={},
|
|
|
|
var_name_list.append(os.path.join(pt, fth_name))
|
|
|
|
outputs={'Out': [new_var]},
|
|
|
|
|
|
|
|
attrs={
|
|
|
|
|
|
|
|
'file_path': os.path.join(file_dir,
|
|
|
|
|
|
|
|
os.path.normpath(each_var.name))
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
load_var_map[new_var.name] = new_var
|
|
|
|
return var_name_list
|
|
|
|
|
|
|
|
|
|
|
|
if file_name is not None:
|
|
|
|
load_block = default_main_program().global_block()
|
|
|
|
load_var_list = []
|
|
|
|
load_var_map = {}
|
|
|
|
for name in sorted(load_var_map.keys()):
|
|
|
|
|
|
|
|
load_var_list.append(load_var_map[name])
|
|
|
|
file_var_list = walk_filename(file_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for var_name in file_var_list:
|
|
|
|
|
|
|
|
new_var = Variable(block=load_block, name=var_name)
|
|
|
|
load_block.append_op(
|
|
|
|
load_block.append_op(
|
|
|
|
type='load_combine',
|
|
|
|
type='load',
|
|
|
|
inputs={},
|
|
|
|
inputs={},
|
|
|
|
outputs={"Out": load_var_list},
|
|
|
|
outputs={'Out': [new_var]},
|
|
|
|
attrs={
|
|
|
|
attrs={
|
|
|
|
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
|
|
|
|
'file_path': os.path.join(file_dir,
|
|
|
|
|
|
|
|
os.path.normpath(new_var.name))
|
|
|
|
})
|
|
|
|
})
|
|
|
|
for res_var in load_var_list:
|
|
|
|
|
|
|
|
load_var_map[res_var.name] = res_var
|
|
|
|
load_var_map[new_var.name] = new_var
|
|
|
|
|
|
|
|
|
|
|
|
return load_var_map
|
|
|
|
return load_var_map
|
|
|
|
|
|
|
|
|
|
|
|