|
|
|
@ -131,8 +131,7 @@ class TestImperativeMnist(unittest.TestCase):
|
|
|
|
|
dy_out = avg_loss._numpy()
|
|
|
|
|
|
|
|
|
|
if epoch == 0 and batch_id == 0:
|
|
|
|
|
for param in fluid.default_main_program().global_block(
|
|
|
|
|
).all_parameters():
|
|
|
|
|
for param in mnist.parameters():
|
|
|
|
|
dy_param_init_value[param.name] = param._numpy()
|
|
|
|
|
|
|
|
|
|
avg_loss._backward()
|
|
|
|
@ -142,8 +141,7 @@ class TestImperativeMnist(unittest.TestCase):
|
|
|
|
|
fluid.default_main_program().global_block()._clear_block()
|
|
|
|
|
|
|
|
|
|
dy_param_value = {}
|
|
|
|
|
for param in fluid.default_main_program().global_block(
|
|
|
|
|
).all_parameters():
|
|
|
|
|
for param in mnist.parameters():
|
|
|
|
|
dy_param_value[param.name] = param._numpy()
|
|
|
|
|
|
|
|
|
|
with new_program_scope():
|
|
|
|
@ -169,8 +167,7 @@ class TestImperativeMnist(unittest.TestCase):
|
|
|
|
|
# initialize params and fetch them
|
|
|
|
|
static_param_init_value = {}
|
|
|
|
|
static_param_name_list = []
|
|
|
|
|
for param in fluid.default_startup_program().global_block(
|
|
|
|
|
).all_parameters():
|
|
|
|
|
for param in mnist.parameters():
|
|
|
|
|
static_param_name_list.append(param.name)
|
|
|
|
|
|
|
|
|
|
out = exe.run(fluid.default_startup_program(),
|
|
|
|
@ -204,16 +201,12 @@ class TestImperativeMnist(unittest.TestCase):
|
|
|
|
|
self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
|
|
|
|
|
|
|
|
|
|
for key, value in six.iteritems(static_param_init_value):
|
|
|
|
|
if not np.allclose(value, dy_param_init_value[key]):
|
|
|
|
|
print(key, value, dy_param_value[key])
|
|
|
|
|
# self.assertTrue(np.allclose(value, dy_param_init_value[key]))
|
|
|
|
|
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
|
|
|
|
|
|
|
|
|
|
self.assertTrue(np.allclose(static_out, dy_out))
|
|
|
|
|
|
|
|
|
|
for key, value in six.iteritems(static_param_value):
|
|
|
|
|
if not np.allclose(value, dy_param_value[key], atol=1e-6):
|
|
|
|
|
print(key, value, dy_param_value[key])
|
|
|
|
|
# self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
|
|
|
|
|
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|