@@ -117,7 +117,7 @@ class ModelHyperParams(object):
     # to process after each sub-layer
     postprocess_cmd = "da"  # dropout + residual connection
     # random seed used in dropout for CE.
-    dropout_seed = 1
+    dropout_seed = None
     # the flag indicating whether to share embedding and softmax weights.
     # vocabularies in source and target should be same for weight sharing.
     weight_sharing = True
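Side note, not from this diff: the two comments above `weight_sharing` describe tying the target-side word embedding to the pre-softmax projection, which is why source and target vocabularies must match. A minimal numpy sketch of the idea, with all names and sizes illustrative rather than taken from this model:

    import numpy as np

    word_emb = np.random.randn(10000, 512).astype('float32')  # [vocab_size, d_model] embedding table
    dec_out = np.random.randn(8, 512).astype('float32')       # [num_tokens, d_model] decoder output
    # weight sharing: reuse the embedding table as the softmax projection weights
    logits = dec_out.dot(word_emb.T)                           # [num_tokens, vocab_size]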
@@ -167,15 +167,21 @@ def create_data(is_static=False):
         ]
     else:
         enc_inputs = [
-            to_variable(src_word_np), to_variable(src_pos_np),
-            to_variable(src_slf_attn_bias_np)
+            to_variable(
+                src_word_np, name='src_word'), to_variable(
+                    src_pos_np, name='src_pos'), to_variable(
+                        src_slf_attn_bias_np, name='src_slf_attn_bias')
         ]
         dec_inputs = [
-            to_variable(trg_word_np), to_variable(trg_pos_np),
-            to_variable(trg_slf_attn_bias_np), to_variable(trg_src_attn_bias_np)
+            to_variable(
+                trg_word_np, name='trg_word'), to_variable(
+                    trg_pos_np, name='trg_pos'), to_variable(
+                        trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
+            to_variable(
+                trg_src_attn_bias_np, name='trg_src_attn_bias')
         ]
-        label = to_variable(lbl_word_np)
-        weight = to_variable(lbl_weight_np)
+        label = to_variable(lbl_word_np, name='lbl_word')
+        weight = to_variable(lbl_weight_np, name='lbl_weight')
         return enc_inputs, dec_inputs, label, weight


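Side note, not from this diff: the only functional change in `create_data()` above is attaching a `name` to each dygraph variable. A minimal sketch of that call pattern, assuming the Fluid 1.x dygraph API in which `to_variable` accepts a `name` keyword (as the diff itself suggests); the array and name below are illustrative:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.base import to_variable

    with fluid.dygraph.guard():
        # wrap a numpy array as a dygraph Variable with an explicit name
        x = to_variable(np.ones([2, 3]).astype('float32'), name='my_input')
        print(x.name, x.shape)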
@@ -212,7 +218,7 @@ def make_all_inputs(input_fields):
 # The placeholder for batch_size in compile time. Must be -1 currently to be
 # consistent with some ops' infer-shape output in compile time, such as the
 # sequence_expand op used in beamsearch decoder.
-batch_size = 32
+batch_size = -1
 # The placeholder for squence length in compile time.
 seq_len = ModelHyperParams.max_length
 # Here list the data shapes and data types of all inputs.
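Side note, not from this diff: the comment in this hunk refers to static-graph compile time, where a batch dimension of -1 stays symbolic until real data is fed. A small sketch of a static-graph input declared that way, assuming the Fluid 1.x `fluid.layers.data` API; the variable name and the literal `seq_len` value are stand-ins:

    import paddle.fluid as fluid

    seq_len = 256  # stand-in for ModelHyperParams.max_length
    # -1 in the batch position keeps the batch size unresolved at compile time
    src_word = fluid.layers.data(
        name='src_word',
        shape=[-1, seq_len, 1],
        dtype='int64',
        append_batch_size=False)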
@ -306,54 +312,40 @@ sync = False
|
|
|
|
|
# how many batches we use
|
|
|
|
|
batch_num = 5
|
|
|
|
|
|
|
|
|
|
np.random.seed = 1
|
|
|
|
|
np.random.seed = 90
|
|
|
|
|
src_word_np = np.random.randint(
|
|
|
|
|
1,
|
|
|
|
|
ModelHyperParams.src_vocab_size - 1,
|
|
|
|
|
size=(batch_size, seq_len, 1),
|
|
|
|
|
size=(TrainTaskConfig.batch_size, seq_len, 1),
|
|
|
|
|
dtype='int64')
|
|
|
|
|
src_pos_np = np.random.randint(
|
|
|
|
|
1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
|
|
|
|
|
src_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
|
|
|
|
|
seq_len, seq_len).astype('float32')
|
|
|
|
|
1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
|
|
|
|
|
src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
|
|
|
|
|
ModelHyperParams.n_head, seq_len,
|
|
|
|
|
seq_len).astype('float32')
|
|
|
|
|
|
|
|
|
|
trg_word_np = np.random.randint(
|
|
|
|
|
1,
|
|
|
|
|
ModelHyperParams.src_vocab_size - 1,
|
|
|
|
|
size=(batch_size, seq_len, 1),
|
|
|
|
|
size=(TrainTaskConfig.batch_size, seq_len, 1),
|
|
|
|
|
dtype='int64')
|
|
|
|
|
trg_pos_np = np.random.randint(
|
|
|
|
|
1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
|
|
|
|
|
trg_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
|
|
|
|
|
seq_len, seq_len).astype('float32')
|
|
|
|
|
trg_src_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
|
|
|
|
|
seq_len, seq_len).astype('float32')
|
|
|
|
|
1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
|
|
|
|
|
trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
|
|
|
|
|
ModelHyperParams.n_head, seq_len,
|
|
|
|
|
seq_len).astype('float32')
|
|
|
|
|
trg_src_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
|
|
|
|
|
ModelHyperParams.n_head, seq_len,
|
|
|
|
|
seq_len).astype('float32')
|
|
|
|
|
|
|
|
|
|
lbl_word_np = np.random.randint(
|
|
|
|
|
1,
|
|
|
|
|
ModelHyperParams.src_vocab_size - 1,
|
|
|
|
|
size=(batch_size * seq_len, 1),
|
|
|
|
|
size=(TrainTaskConfig.batch_size * seq_len, 1),
|
|
|
|
|
dtype='int64')
|
|
|
|
|
lbl_weight_np = np.random.randn(batch_size * seq_len, 1).astype('float32')
|
|
|
|
|
|
|
|
|
|
# np.random.seed = 1
|
|
|
|
|
# src_word_np = np.arange(0, 10).reshape([batch_size, seq_len, 1]).astype('int64')
|
|
|
|
|
# src_pos_np = np.random.randint(
|
|
|
|
|
# 1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
|
|
|
|
|
# src_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
|
|
|
|
|
# seq_len, seq_len).astype('float32')
|
|
|
|
|
#
|
|
|
|
|
# trg_word_np = np.arange(0, 10).reshape([batch_size, seq_len, 1]).astype('int64')
|
|
|
|
|
# trg_pos_np = np.random.randint(
|
|
|
|
|
# 1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
|
|
|
|
|
# trg_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
|
|
|
|
|
# seq_len, seq_len).astype('float32')
|
|
|
|
|
# trg_src_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
|
|
|
|
|
# seq_len, seq_len).astype('float32')
|
|
|
|
|
#
|
|
|
|
|
# lbl_word_np = np.arange(0, 10).reshape([batch_size * seq_len, 1]).astype('int64')
|
|
|
|
|
# lbl_weight_np = np.random.randn(batch_size * seq_len, 1).astype('float32')
|
|
|
|
|
#
|
|
|
|
|
lbl_weight_np = np.random.randn(TrainTaskConfig.batch_size * seq_len,
|
|
|
|
|
1).astype('float32')
|
|
|
|
|
|
|
|
|
|
pos_inp1 = position_encoding_init(ModelHyperParams.max_length,
|
|
|
|
|
ModelHyperParams.d_model)
|
|
|
|
|
pos_inp2 = position_encoding_init(ModelHyperParams.max_length,
|
|
|
|
@@ -467,7 +459,7 @@ class MultiHeadAttentionLayer(Layer):
             x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
         transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])

-        #scale dot product attention
+        # scale dot product attention
         product = fluid.layers.matmul(
             x=transpose_q,
             y=transpose_k,
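Side note, not from this diff: for context on the "# scale dot product attention" comment fixed above, a minimal runnable sketch of scaled dot-product attention using Fluid ops in dygraph mode; the tensor shapes and the per-head scaling factor are assumptions for illustration, not code from this test:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.base import to_variable

    d_per_head = 64
    with fluid.dygraph.guard():
        # q, k, v: [batch, n_head, seq_len, d_per_head]
        q = to_variable(np.random.randn(2, 8, 10, d_per_head).astype('float32'))
        k = to_variable(np.random.randn(2, 8, 10, d_per_head).astype('float32'))
        v = to_variable(np.random.randn(2, 8, 10, d_per_head).astype('float32'))
        # scaled QK^T, softmax over the last axis, then a weighted sum of values
        product = fluid.layers.matmul(
            x=q, y=k, transpose_y=True, alpha=d_per_head**-0.5)
        weights = fluid.layers.softmax(product)
        out = fluid.layers.matmul(weights, v)  # [2, 8, 10, d_per_head]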
@@ -740,7 +732,7 @@ class DecoderSubLayer(Layer):
         enc_attn_output_pp = self._multihead_attention_layer2(
             pre_process_rlt2, enc_output, enc_output, dec_enc_attn_bias)
         enc_attn_output = self._post_process_layer2(
-            slf_attn_output, enc_attn_output_pp, self._postprocess_cmd,
+            slf_attn_output_pp, enc_attn_output_pp, self._postprocess_cmd,
             self._prepostprcess_dropout)
         pre_process_rlt3 = self._pre_process_layer3(None, enc_attn_output,
                                                     self._preprocess_cmd,
@@ -991,6 +983,7 @@ class TestDygraphTransformer(unittest.TestCase):
                 enc_inputs, dec_inputs, label, weights = create_data()
                 dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                     enc_inputs, dec_inputs, label, weights)
+
                 if i == 0:
                     for param in transformer.parameters():
                         dy_param_init[param.name] = param.numpy()
@@ -998,6 +991,7 @@ class TestDygraphTransformer(unittest.TestCase):
                 dy_avg_cost.backward()
                 optimizer.minimize(dy_avg_cost)
                 transformer.clear_gradients()
+
                 if i == batch_num - 1:
                     for param in transformer.parameters():
                         dy_param_updated[param.name] = param.numpy()
@@ -1044,7 +1038,6 @@ class TestDygraphTransformer(unittest.TestCase):
             static_param_name_list = list()
             static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer(
                 enc_inputs, dec_inputs, label, weights)
-
             optimizer.minimize(static_avg_cost)
             for param in transformer.parameters():
                 static_param_name_list.append(param.name)
@@ -1062,8 +1055,8 @@ class TestDygraphTransformer(unittest.TestCase):
                     static_sum_cost, static_avg_cost, static_predict,
                     static_token_num
                 ]
-                fetch_list.extend(static_param_name_list)

+                fetch_list.extend(static_param_name_list)
                 out = exe.run(fluid.default_main_program(),
                               feed=feed_dict,
                               fetch_list=fetch_list)