|
|
@ -765,7 +765,7 @@ def load_vars(executor,
|
|
|
|
var_temp = paddle.fluid.global_scope().find_var(each_var.name)
|
|
|
|
var_temp = paddle.fluid.global_scope().find_var(each_var.name)
|
|
|
|
assert var_temp != None, "can't not find var: " + each_var.name
|
|
|
|
assert var_temp != None, "can't not find var: " + each_var.name
|
|
|
|
new_shape = (np.array(var_temp.get_tensor())).shape
|
|
|
|
new_shape = (np.array(var_temp.get_tensor())).shape
|
|
|
|
assert each_var.name in orig_para_shape, earch_var.name + "MUST in var list"
|
|
|
|
assert each_var.name in orig_para_shape, each_var.name + "MUST in var list"
|
|
|
|
orig_shape = orig_para_shape.get(each_var.name)
|
|
|
|
orig_shape = orig_para_shape.get(each_var.name)
|
|
|
|
if new_shape != orig_shape:
|
|
|
|
if new_shape != orig_shape:
|
|
|
|
raise RuntimeError(
|
|
|
|
raise RuntimeError(
|
|
|
@ -1541,14 +1541,14 @@ def save(program, model_path):
|
|
|
|
parameter_list = list(filter(is_parameter, program.list_vars()))
|
|
|
|
parameter_list = list(filter(is_parameter, program.list_vars()))
|
|
|
|
param_dict = {p.name: get_tensor(p) for p in parameter_list}
|
|
|
|
param_dict = {p.name: get_tensor(p) for p in parameter_list}
|
|
|
|
with open(model_path + ".pdparams", 'wb') as f:
|
|
|
|
with open(model_path + ".pdparams", 'wb') as f:
|
|
|
|
pickle.dump(param_dict, f)
|
|
|
|
pickle.dump(param_dict, f, protocol=2)
|
|
|
|
|
|
|
|
|
|
|
|
optimizer_var_list = list(
|
|
|
|
optimizer_var_list = list(
|
|
|
|
filter(is_belong_to_optimizer, program.list_vars()))
|
|
|
|
filter(is_belong_to_optimizer, program.list_vars()))
|
|
|
|
|
|
|
|
|
|
|
|
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
|
|
|
|
opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
|
|
|
|
with open(model_path + ".pdopt", 'wb') as f:
|
|
|
|
with open(model_path + ".pdopt", 'wb') as f:
|
|
|
|
pickle.dump(opt_dict, f)
|
|
|
|
pickle.dump(opt_dict, f, protocol=2)
|
|
|
|
|
|
|
|
|
|
|
|
main_program = program.clone()
|
|
|
|
main_program = program.clone()
|
|
|
|
program.desc.flush()
|
|
|
|
program.desc.flush()
|
|
|
@ -1695,7 +1695,8 @@ def load(program, model_path, executor=None, var_list=None):
|
|
|
|
global_scope(),
|
|
|
|
global_scope(),
|
|
|
|
executor._default_executor)
|
|
|
|
executor._default_executor)
|
|
|
|
with open(parameter_file_name, 'rb') as f:
|
|
|
|
with open(parameter_file_name, 'rb') as f:
|
|
|
|
load_dict = pickle.load(f)
|
|
|
|
load_dict = pickle.load(f) if six.PY2 else pickle.load(
|
|
|
|
|
|
|
|
f, encoding='bytes')
|
|
|
|
for v in parameter_list:
|
|
|
|
for v in parameter_list:
|
|
|
|
assert v.name in load_dict, \
|
|
|
|
assert v.name in load_dict, \
|
|
|
|
"Can not find [{}] in model file [{}]".format(
|
|
|
|
"Can not find [{}] in model file [{}]".format(
|
|
|
@ -1715,7 +1716,8 @@ def load(program, model_path, executor=None, var_list=None):
|
|
|
|
optimizer_var_list, global_scope(), executor._default_executor)
|
|
|
|
optimizer_var_list, global_scope(), executor._default_executor)
|
|
|
|
|
|
|
|
|
|
|
|
with open(opt_file_name, 'rb') as f:
|
|
|
|
with open(opt_file_name, 'rb') as f:
|
|
|
|
load_dict = pickle.load(f)
|
|
|
|
load_dict = pickle.load(f) if six.PY2 else pickle.load(
|
|
|
|
|
|
|
|
f, encoding='bytes')
|
|
|
|
for v in optimizer_var_list:
|
|
|
|
for v in optimizer_var_list:
|
|
|
|
assert v.name in load_dict, \
|
|
|
|
assert v.name in load_dict, \
|
|
|
|
"Can not find [{}] in model file [{}]".format(
|
|
|
|
"Can not find [{}] in model file [{}]".format(
|
|
|
@ -1839,12 +1841,14 @@ def load_program_state(model_path, var_list=None):
|
|
|
|
"Parameter file [{}] not exits".format(parameter_file_name)
|
|
|
|
"Parameter file [{}] not exits".format(parameter_file_name)
|
|
|
|
|
|
|
|
|
|
|
|
with open(parameter_file_name, 'rb') as f:
|
|
|
|
with open(parameter_file_name, 'rb') as f:
|
|
|
|
para_dict = pickle.load(f)
|
|
|
|
para_dict = pickle.load(f) if six.PY2 else pickle.load(
|
|
|
|
|
|
|
|
f, encoding='bytes')
|
|
|
|
|
|
|
|
|
|
|
|
opt_file_name = model_prefix + ".pdopt"
|
|
|
|
opt_file_name = model_prefix + ".pdopt"
|
|
|
|
if os.path.exists(opt_file_name):
|
|
|
|
if os.path.exists(opt_file_name):
|
|
|
|
with open(opt_file_name, 'rb') as f:
|
|
|
|
with open(opt_file_name, 'rb') as f:
|
|
|
|
opti_dict = pickle.load(f)
|
|
|
|
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
|
|
|
|
|
|
|
|
f, encoding='bytes')
|
|
|
|
|
|
|
|
|
|
|
|
para_dict.update(opti_dict)
|
|
|
|
para_dict.update(opti_dict)
|
|
|
|
|
|
|
|
|
|
|
|