|
|
|
@ -113,14 +113,17 @@ def load_persistables(vardict, dirname, filename=None):
|
|
|
|
|
def _save_var_to_file(stat_dict, file_dir, file_name):
|
|
|
|
|
save_block = default_main_program().global_block()
|
|
|
|
|
save_var_map = {}
|
|
|
|
|
for each_var in stat_dict.items():
|
|
|
|
|
for var_key, each_var in stat_dict.items():
|
|
|
|
|
save_var_map[each_var.name] = each_var
|
|
|
|
|
if file_name is None:
|
|
|
|
|
save_block.append_op(
|
|
|
|
|
type='save',
|
|
|
|
|
inputs={'X': [each_var]},
|
|
|
|
|
outputs={},
|
|
|
|
|
attrs={'file_path': os.path.join(file_dir, each_var.name)})
|
|
|
|
|
attrs={
|
|
|
|
|
'file_path': os.path.join(file_dir,
|
|
|
|
|
os.path.normpath(each_var.name))
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if file_name is not None:
|
|
|
|
|
save_var_list = []
|
|
|
|
@ -131,14 +134,16 @@ def _save_var_to_file(stat_dict, file_dir, file_name):
|
|
|
|
|
type='save_combine',
|
|
|
|
|
inputs={'X': save_var_list},
|
|
|
|
|
outputs={},
|
|
|
|
|
attrs={'file_path': os.path.join(file_dir, file_name)})
|
|
|
|
|
attrs={
|
|
|
|
|
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_var_from_file(stat_dict, file_dir, file_name):
|
|
|
|
|
load_block = default_main_program().global_block()
|
|
|
|
|
load_var_map = {}
|
|
|
|
|
|
|
|
|
|
for each_var in stat_dict.items():
|
|
|
|
|
for var_key, each_var in stat_dict.items():
|
|
|
|
|
assert isinstance(each_var, Variable)
|
|
|
|
|
if each_var.type == core.VarDesc.VarType.RAW:
|
|
|
|
|
continue
|
|
|
|
@ -148,7 +153,10 @@ def _load_var_from_file(stat_dict, file_dir, file_name):
|
|
|
|
|
type='load',
|
|
|
|
|
inputs={},
|
|
|
|
|
outputs={'Out': [new_var]},
|
|
|
|
|
attrs={'file_path': os.path.join(file_dir, each_var.name)})
|
|
|
|
|
attrs={
|
|
|
|
|
'file_path': os.path.join(file_dir,
|
|
|
|
|
os.path.normpath(each_var.name))
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
load_var_map[new_var.name] = new_var
|
|
|
|
|
|
|
|
|
@ -161,7 +169,9 @@ def _load_var_from_file(stat_dict, file_dir, file_name):
|
|
|
|
|
type='load_combine',
|
|
|
|
|
inputs={},
|
|
|
|
|
outputs={"Out": load_var_list},
|
|
|
|
|
attrs={'file_path': os.path.join(file_dir, file_name)})
|
|
|
|
|
attrs={
|
|
|
|
|
'file_path': os.path.join(file_dir, os.path.normpath(file_name))
|
|
|
|
|
})
|
|
|
|
|
for res_var in load_var_list:
|
|
|
|
|
load_var_map[res_var.name] = res_var
|
|
|
|
|
|
|
|
|
@ -175,5 +185,5 @@ def _clone_var_in_block_(block, var):
|
|
|
|
|
shape=var.shape,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
type=var.type,
|
|
|
|
|
lod_level=var.lod_level,
|
|
|
|
|
lod_level=0,
|
|
|
|
|
persistable=True)
|
|
|
|
|