@@ -59,7 +59,7 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
                 dtype="float32",
                 default_initializer=fluid.initializer.UniformInitializer(
                     low=-self._init_scale, high=self._init_scale))
-            self.weight_1_arr.append(weight_1)
+            self.weight_1_arr.append(self.add_parameter('w_%d' % i, weight_1))
             bias_1 = self.create_parameter(
                 attr=fluid.ParamAttr(
                     initializer=fluid.initializer.UniformInitializer(
@@ -67,7 +67,7 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
                 shape=[self._hidden_size * 4],
                 dtype="float32",
                 default_initializer=fluid.initializer.Constant(0.0))
-            self.bias_arr.append(bias_1)
+            self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1))
 
     def forward(self, input_embedding, init_hidden=None, init_cell=None):
         self.cell_array = []
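
Note on the add_parameter change in the two hunks above: the weights and biases live in plain Python lists (weight_1_arr / bias_arr), and routing each append through self.add_parameter(...) presumably registers them on the Layer so that ptb_model.parameters(), which the test below iterates over, can actually see them; add_parameter appears to return the parameter it registers, so the list usage stays unchanged. A minimal, self-contained sketch of that registration idea in plain Python (illustration only, not the fluid.imperative implementation):

    class TinyLayer(object):
        """Why list-held parameters still need explicit registration."""

        def __init__(self):
            self._parameters = {}   # what parameters() reports
            self.weight_arr = []    # plain list the forward pass indexes into

        def add_parameter(self, name, param):
            # Register under a name and hand the same object back, so it can
            # still be appended to a list in one expression.
            self._parameters[name] = param
            return param

        def parameters(self):
            return list(self._parameters.values())

    layer = TinyLayer()
    for i in range(2):
        w = object()  # stand-in for a parameter created via create_parameter
        layer.weight_arr.append(layer.add_parameter('w_%d' % i, w))

    assert len(layer.parameters()) == 2  # now visible to the training loop
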
@@ -242,7 +242,7 @@ class TestImperativePtbRnn(unittest.TestCase):
             dy_loss = None
             last_hidden = None
             last_cell = None
-            batch_num = 50
+            batch_num = 200
 
             for i in range(batch_num):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
@@ -264,8 +264,10 @@ class TestImperativePtbRnn(unittest.TestCase):
                         dy_param_init[param.name] = param._numpy()
                 dy_loss._backward()
                 sgd.minimize(dy_loss)
-                for param in ptb_model.parameters():
-                    dy_param_updated[param.name] = param._numpy()
+                ptb_model.clear_gradients()
+                if i == batch_num - 1:
+                    for param in ptb_model.parameters():
+                        dy_param_updated[param.name] = param._numpy()
 
         with new_program_scope():
             fluid.default_startup_program().random_seed = seed
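
Two behavioural changes in the imperative loop above: gradients are now dropped after each step via ptb_model.clear_gradients(), and the dygraph parameter snapshot (dy_param_updated) is taken only on the final batch, matching the i == batch_num - 1 guard added on the static-graph side in the next hunk. A small self-contained sketch of that "snapshot only after the last step" pattern (plain NumPy, names are illustrative):

    import numpy as np

    def sgd_step(params, lr=0.1):
        # Stand-in for one forward/backward/minimize pass on the real model.
        for v in params.values():
            v -= lr

    def run_steps(params, batch_num=200):
        """Keep a parameter snapshot only after the final step; that snapshot
        is what gets compared against the other execution mode."""
        final_snapshot = None
        for i in range(batch_num):
            sgd_step(params)
            if i == batch_num - 1:  # only the last batch is recorded
                final_snapshot = {k: v.copy() for k, v in params.items()}
        return final_snapshot

    params = {'w_0': np.zeros(3), 'b_0': np.zeros(3)}
    snapshot = run_steps(params, batch_num=5)
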
@@ -323,25 +325,28 @@ class TestImperativePtbRnn(unittest.TestCase):
                               },
                               fetch_list=fetch_list)
                 static_loss_value = out[0]
-                static_last_cell_value = out[1]
-                static_last_hidden_value = out[2]
-                for k in range(3, len(out)):
-                    static_param_updated[static_param_name_list[k - 3]] = out[k]
+                static_last_hidden_value = out[1]
+                static_last_cell_value = out[2]
+
+                if i == batch_num - 1:
+                    for k in range(3, len(out)):
+                        static_param_updated[static_param_name_list[k -
+                                                                    3]] = out[k]
 
-        self.assertTrue(
-            np.allclose(static_loss_value.all(), dy_loss._numpy().all()))
-        self.assertTrue(
-            np.allclose(static_last_cell_value.all(),
-                        last_cell._numpy().all()))
-        self.assertTrue(
-            np.allclose(static_last_hidden_value.all(),
-                        last_hidden._numpy().all()))
-        for key, value in six.iteritems(static_param_init):
-            self.assertTrue(
-                np.allclose(value.all(), dy_param_init[key].all()))
-        for key, value in six.iteritems(static_param_updated):
-            self.assertTrue(
-                np.allclose(value.all(), dy_param_updated[key].all()))
+        self.assertTrue(np.allclose(static_loss_value, dy_loss._numpy()))
+        self.assertTrue(np.allclose(static_last_cell_value, last_cell._numpy()))
+        self.assertTrue(
+            np.allclose(static_last_hidden_value, last_hidden._numpy()))
+        for key, value in six.iteritems(static_param_init):
+            # print("static_init name: {}, value {}".format(key, value))
+            # print("dy_init name: {}, value {}".format(key, dy_param_init[key]))
+            self.assertTrue(np.allclose(value, dy_param_init[key], atol=1e-5))
+        for key, value in six.iteritems(static_param_updated):
+            # print("static name: {}, value {}".format(key, value))
+            # print("dy name: {}, value {}".format(key, dy_param_updated[key]))
+            self.assertTrue(
+                np.allclose(
+                    value, dy_param_updated[key], atol=1e-5))
 
 
 if __name__ == '__main__':
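
The assertion rewrite in the last hunk is the substantive test fix: ndarray.all() collapses an array to a single boolean (True when every element is non-zero), so np.allclose(value.all(), dy_param_init[key].all()) ends up comparing two booleans rather than the tensors, and typically passes even when the values differ. Comparing the raw arrays with an explicit atol is the meaningful elementwise check. A self-contained NumPy illustration:

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    b = np.array([9.0, 9.0, 9.0])           # clearly different values

    # .all() reduces each array to a single bool (True here, since no element
    # is zero), so this comparison is vacuous and passes:
    print(np.allclose(a.all(), b.all()))    # True

    # Comparing the arrays themselves with a tolerance is the real check:
    print(np.allclose(a, b, atol=1e-5))     # False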