|
|
|
@ -20,7 +20,9 @@ from paddle.fluid.imperative.nn import EMBEDDING
|
|
|
|
|
import paddle.fluid.framework as framework
|
|
|
|
|
from paddle.fluid.optimizer import SGDOptimizer
|
|
|
|
|
from paddle.fluid.imperative.base import to_variable
|
|
|
|
|
from test_imperative_base import new_program_scope
|
|
|
|
|
import numpy as np
|
|
|
|
|
import six
|
|
|
|
|
from paddle.fluid.backward import append_backward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -36,8 +38,8 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
|
|
|
|
|
self._num_layers = num_layers
|
|
|
|
|
self._init_scale = init_scale
|
|
|
|
|
self._dropout = dropout
|
|
|
|
|
self.input = None
|
|
|
|
|
self.num_steps = num_steps
|
|
|
|
|
self._input = None
|
|
|
|
|
self._num_steps = num_steps
|
|
|
|
|
|
|
|
|
|
def _build_once(self, input_embedding, init_hidden=None, init_cell=None):
|
|
|
|
|
self.weight_1_arr = []
|
|
|
|
@ -75,58 +77,49 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
|
|
|
|
|
|
|
|
|
|
def forward(self, input_embedding, init_hidden=None, init_cell=None):
|
|
|
|
|
res = []
|
|
|
|
|
for index in range(self.num_steps):
|
|
|
|
|
self.input = fluid.layers.slice(
|
|
|
|
|
for index in range(self._num_steps):
|
|
|
|
|
self._input = fluid.layers.slice(
|
|
|
|
|
input_embedding, axes=[1], starts=[index], ends=[index + 1])
|
|
|
|
|
self.input = fluid.layers.reshape(
|
|
|
|
|
self.input, shape=[-1, self._hidden_size])
|
|
|
|
|
self._input = fluid.layers.reshape(
|
|
|
|
|
self._input, shape=[-1, self._hidden_size])
|
|
|
|
|
for k in range(self._num_layers):
|
|
|
|
|
pre_hidden = self.hidden_array[k]
|
|
|
|
|
print("pre_hidden shape is:{}".format(pre_hidden.shape))
|
|
|
|
|
print("input shape is:{}".format(self.input.shape))
|
|
|
|
|
pre_cell = self.cell_array[k]
|
|
|
|
|
weight_1 = self.weight_1_arr[k]
|
|
|
|
|
bias = self.bias_arr[k]
|
|
|
|
|
|
|
|
|
|
nn = fluid.layers.concat([self.input, pre_hidden], 1)
|
|
|
|
|
nn = fluid.layers.concat([self._input, pre_hidden], 1)
|
|
|
|
|
gate_input = fluid.layers.matmul(x=nn, y=weight_1)
|
|
|
|
|
|
|
|
|
|
gate_input = fluid.layers.elementwise_add(gate_input, bias)
|
|
|
|
|
print("gate_input shape is: {}".format(gate_input.shape))
|
|
|
|
|
print("gate_input value is :{}".format(gate_input._numpy()))
|
|
|
|
|
print("gate_input desc is :{}".format(gate_input))
|
|
|
|
|
# i, j, f, o = fluid.layers.split(gate_input, num_or_sections=4, dim=-1)
|
|
|
|
|
# #
|
|
|
|
|
# # c = pre_cell * fluid.layers.sigmoid(f) + fluid.layers.sigmoid(
|
|
|
|
|
# # i) * fluid.layers.tanh(j)
|
|
|
|
|
# # m = fluid.layers.tanh(c) * fluid.layers.sigmoid(o)
|
|
|
|
|
# #
|
|
|
|
|
# # self.hidden_array[k] = m
|
|
|
|
|
# # self.cell_array[k] = c
|
|
|
|
|
# # self.input = m
|
|
|
|
|
# #
|
|
|
|
|
# # if self.dropout is not None and self.dropout > 0.0:
|
|
|
|
|
# # self.input = fluid.layers.dropout(
|
|
|
|
|
# # self.input,
|
|
|
|
|
# # dropout_prob=self.dropout,
|
|
|
|
|
# # dropout_implementation='upscale_in_train')
|
|
|
|
|
# #
|
|
|
|
|
# # res.append(
|
|
|
|
|
# # fluid.layers.reshape(
|
|
|
|
|
# # input, shape=[1, -1, self._hidden_size]))
|
|
|
|
|
# # real_res = fluid.layers.concat(res, 0)
|
|
|
|
|
# # real_res = fluid.layers.transpose(x=real_res, perm=[1, 0, 2])
|
|
|
|
|
# # last_hidden = fluid.layers.concat(self.hidden_array, 1)
|
|
|
|
|
# # last_hidden = fluid.layers.reshape(
|
|
|
|
|
# # last_hidden, shape=[-1, self._num_layers, self._hidden_size])
|
|
|
|
|
# # last_hidden = fluid.layers.transpose(x=last_hidden, perm=[1, 0, 2])
|
|
|
|
|
# # last_cell = fluid.layers.concat(self.cell_array, 1)
|
|
|
|
|
# # last_cell = fluid.layers.reshape(
|
|
|
|
|
# # last_cell, shape=[-1, self._num_layers, self._hidden_size])
|
|
|
|
|
# # last_cell = fluid.layers.transpose(x=last_cell, perm=[1, 0, 2])
|
|
|
|
|
# #
|
|
|
|
|
# return real_res, last_hidden, last_cell
|
|
|
|
|
return [1], [2], [3]
|
|
|
|
|
i, j, f, o = fluid.layers.split(
|
|
|
|
|
gate_input, num_or_sections=4, dim=-1)
|
|
|
|
|
c = pre_cell * fluid.layers.sigmoid(f) + fluid.layers.sigmoid(
|
|
|
|
|
i) * fluid.layers.tanh(j)
|
|
|
|
|
m = fluid.layers.tanh(c) * fluid.layers.sigmoid(o)
|
|
|
|
|
self.hidden_array[k] = m
|
|
|
|
|
self.cell_array[k] = c
|
|
|
|
|
self._input = m
|
|
|
|
|
|
|
|
|
|
if self._dropout is not None and self._dropout > 0.0:
|
|
|
|
|
self._input = fluid.layers.dropout(
|
|
|
|
|
self._input,
|
|
|
|
|
dropout_prob=self._dropout,
|
|
|
|
|
dropout_implementation='upscale_in_train')
|
|
|
|
|
res.append(
|
|
|
|
|
fluid.layers.reshape(
|
|
|
|
|
self._input, shape=[1, -1, self._hidden_size]))
|
|
|
|
|
real_res = fluid.layers.concat(res, 0)
|
|
|
|
|
real_res = fluid.layers.transpose(x=real_res, perm=[1, 0, 2])
|
|
|
|
|
last_hidden = fluid.layers.concat(self.hidden_array, 1)
|
|
|
|
|
last_hidden = fluid.layers.reshape(
|
|
|
|
|
last_hidden, shape=[-1, self._num_layers, self._hidden_size])
|
|
|
|
|
last_hidden = fluid.layers.transpose(x=last_hidden, perm=[1, 0, 2])
|
|
|
|
|
last_cell = fluid.layers.concat(self.cell_array, 1)
|
|
|
|
|
last_cell = fluid.layers.reshape(
|
|
|
|
|
last_cell, shape=[-1, self._num_layers, self._hidden_size])
|
|
|
|
|
last_cell = fluid.layers.transpose(x=last_cell, perm=[1, 0, 2])
|
|
|
|
|
return real_res, last_hidden, last_cell
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PtbModel(fluid.imperative.Layer):
|
|
|
|
@ -189,12 +182,11 @@ class PtbModel(fluid.imperative.Layer):
|
|
|
|
|
x_emb,
|
|
|
|
|
dropout_prob=self.drop_out,
|
|
|
|
|
dropout_implementation='upscale_in_train')
|
|
|
|
|
print("init_c is {}".format(init_c))
|
|
|
|
|
rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(x_emb, init_h,
|
|
|
|
|
init_c)
|
|
|
|
|
rnn_out = fluid.layers.reshape(
|
|
|
|
|
rnn_out, shape=[-1, self.num_steps, self.hidden_size])
|
|
|
|
|
projection = fluid.layers.reshape(rnn_out, self.softmax_weight)
|
|
|
|
|
projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
|
|
|
|
|
projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
|
|
|
|
|
projection = fluid.layers.reshape(
|
|
|
|
|
projection, shape=[-1, self.vocab_size])
|
|
|
|
@ -232,7 +224,8 @@ class TestImperativePtbRnn(unittest.TestCase):
|
|
|
|
|
init_scale=init_scale)
|
|
|
|
|
|
|
|
|
|
sgd = SGDOptimizer(learning_rate=1e-3)
|
|
|
|
|
print("q")
|
|
|
|
|
dy_param_updated = dict()
|
|
|
|
|
dy_param_init = dict()
|
|
|
|
|
for i in range(2):
|
|
|
|
|
x_data = np.arange(12).reshape(4, 3).astype('int64')
|
|
|
|
|
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
|
|
|
|
@ -248,17 +241,95 @@ class TestImperativePtbRnn(unittest.TestCase):
|
|
|
|
|
init_cell = to_variable(init_cell_data)
|
|
|
|
|
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
|
|
|
|
|
init_cell)
|
|
|
|
|
dy_param_init = dict()
|
|
|
|
|
if i == 0:
|
|
|
|
|
for param in fluid.default_main_program().global_block(
|
|
|
|
|
).all_parameters():
|
|
|
|
|
dy_param_init[param.name] = param._numpy()
|
|
|
|
|
dy_loss._backward()
|
|
|
|
|
sgd.minimize(dy_loss)
|
|
|
|
|
dy_param_updated = dict()
|
|
|
|
|
for param in fluid.default_main_program().global_block(
|
|
|
|
|
).all_parameters():
|
|
|
|
|
dy_param_updated[param.name] = param._numpy()
|
|
|
|
|
# print("dy_loss is {}".format(dy_loss._numpy()))
|
|
|
|
|
# print("last_hidden is {}".format(last_hidden._numpy()))
|
|
|
|
|
# print("last_cell is {}".format(last_cell._numpy()))
|
|
|
|
|
|
|
|
|
|
with new_program_scope():
|
|
|
|
|
fluid.default_startup_program().random_seed = seed
|
|
|
|
|
fluid.default_main_program().random_seed = seed
|
|
|
|
|
# TODO: marsyang1993 Change seed to
|
|
|
|
|
ptb_model = PtbModel(
|
|
|
|
|
hidden_size=hidden_size,
|
|
|
|
|
vocab_size=vocab_size,
|
|
|
|
|
num_layers=num_layers,
|
|
|
|
|
num_steps=num_steps,
|
|
|
|
|
init_scale=init_scale)
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
|
sgd = SGDOptimizer(learning_rate=1e-3)
|
|
|
|
|
x = fluid.layers.data(name="x", shape=[-1, 3, 1], dtype='int64')
|
|
|
|
|
y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
|
|
|
|
|
init_hidden = fluid.layers.data(
|
|
|
|
|
name="init_hidden", shape=[1], dtype='float32')
|
|
|
|
|
init_cell = fluid.layers.data(
|
|
|
|
|
name="init_cell", shape=[1], dtype='float32')
|
|
|
|
|
|
|
|
|
|
static_loss, static_last_hidden, static_last_cell = ptb_model(
|
|
|
|
|
x, y, init_hidden, init_cell)
|
|
|
|
|
sgd.minimize(static_loss)
|
|
|
|
|
static_param_updated = dict()
|
|
|
|
|
static_param_init = dict()
|
|
|
|
|
static_param_name_list = list()
|
|
|
|
|
for param in fluid.default_startup_program().global_block(
|
|
|
|
|
).all_parameters():
|
|
|
|
|
static_param_name_list.append(param.name)
|
|
|
|
|
|
|
|
|
|
out = exe.run(framework.default_startup_program(),
|
|
|
|
|
fetch_list=static_param_name_list)
|
|
|
|
|
for i in range(len(static_param_name_list)):
|
|
|
|
|
static_param_init[static_param_name_list[i]] = out[i]
|
|
|
|
|
|
|
|
|
|
for i in range(2):
|
|
|
|
|
x_data = np.arange(12).reshape(4, 3).astype('int64')
|
|
|
|
|
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
|
|
|
|
|
x_data = x_data.reshape((-1, num_steps, 1))
|
|
|
|
|
y_data = y_data.reshape((-1, 1))
|
|
|
|
|
init_hidden_data = np.zeros(
|
|
|
|
|
(num_layers, batch_size, hidden_size), dtype='float32')
|
|
|
|
|
init_cell_data = np.zeros(
|
|
|
|
|
(num_layers, batch_size, hidden_size), dtype='float32')
|
|
|
|
|
fetch_list = [static_loss, static_last_hidden, static_last_cell]
|
|
|
|
|
fetch_list.extend(static_param_name_list)
|
|
|
|
|
out = exe.run(fluid.default_main_program(),
|
|
|
|
|
feed={
|
|
|
|
|
"x": x_data,
|
|
|
|
|
"y": y_data,
|
|
|
|
|
"init_hidden": init_hidden_data,
|
|
|
|
|
"init_cell": init_cell_data
|
|
|
|
|
},
|
|
|
|
|
fetch_list=fetch_list)
|
|
|
|
|
static_loss_value = out[0]
|
|
|
|
|
static_last_cell_value = out[1]
|
|
|
|
|
static_last_hidden_value = out[2]
|
|
|
|
|
# print("static_loss is {}".format(out[0]))
|
|
|
|
|
# print("last_hidden is {}".format(out[1]))
|
|
|
|
|
# print("last_cell is {}".format(out[2]))
|
|
|
|
|
for i in range(3, len(out)):
|
|
|
|
|
static_param_updated[static_param_name_list[i - 3]] = out[i]
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_loss_value.all(), dy_loss._numpy().all()))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_last_cell_value.all(),
|
|
|
|
|
last_cell._numpy().all()))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_last_hidden_value.all(),
|
|
|
|
|
last_hidden._numpy().all()))
|
|
|
|
|
for key, value in six.iteritems(static_param_init):
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(value.all(), dy_param_init[key].all()))
|
|
|
|
|
for key, value in six.iteritems(static_param_updated):
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(value.all(), dy_param_updated[key].all()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|