revert-31068-fix_conv3d_windows
WeiXin 4 years ago committed by GitHub
parent 9fec1618d2
commit c0fb03a0dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2180,6 +2180,7 @@ def load_program_state(model_path, var_list=None):
with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
para_dict = _pack_loaded_dict(para_dict)
opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name):
@ -2231,6 +2232,7 @@ def set_program_state(program, state_dict):
static.set_program_state(prog, program_state)
"""
state_dict = _pack_loaded_dict(state_dict)
parameter_list = list(filter(is_persistable, program.list_vars()))
used_para_list = {}

@ -1365,6 +1365,25 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase):
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
# set var to zero
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertTrue(np.sum(np.abs(new_t)) == 0)
program_state = fluid.load_program_state(path)
fluid.set_program_state(prog, program_state)
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestProgramStateOldSaveSingleModel(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):

Binary file not shown.
Loading…
Cancel
Save