@@ -36,6 +36,7 @@ import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 from paddle.fluid import core
 from test_dist_base import TestDistRunnerBase, runtime_main
+import paddle.compat as cpt
 from paddle.compat import long_type

 import hashlib
@@ -315,8 +316,9 @@ def pad_batch_data(insts,
     """
     return_list = []
     max_len = max(len(inst) for inst in insts)
-    num_token = reduce(lambda x, y: x + y,
-                       [len(inst) for inst in insts]) if return_num_token else 0
+    num_token = six.moves.reduce(
+        lambda x, y: x + y,
+        [len(inst) for inst in insts]) if return_num_token else 0
     # Any token included in dict can be used to pad, since the paddings' loss
     # will be masked out by weights and make no effect on parameter gradients.
     inst_data = np.array(
@@ -328,7 +330,7 @@ def pad_batch_data(insts,
         return_list += [inst_weight.astype("float32").reshape([-1, 1])]
     else:  # position data
         inst_pos = np.array([
-            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
+            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
             for inst in insts
         ])
         return_list += [inst_pos.astype("int64").reshape([-1, 1])]
@@ -385,10 +387,11 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
         return_num_token=True)

     data_input_dict = dict(
-        zip(data_input_names, [
-            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
-            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
-        ]))
+        list(
+            zip(data_input_names, [
+                src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
+                trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+            ])))
     return data_input_dict, np.asarray([num_token], dtype="float32")


@@ -561,7 +564,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                         np.log(TrainTaskConfig.label_smooth_eps / (
                             ModelHyperParams.trg_vocab_size - 1) + 1e-20))
     init = False
-    for pass_id in xrange(TrainTaskConfig.pass_num):
+    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
             if batch_id >= 5:
@@ -587,11 +590,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                     ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                     ModelHyperParams.d_model)
                 total_num_token += num_token
-                feed_kv_pairs = data_input_dict.items()
+                feed_kv_pairs = list(data_input_dict.items())
                 if TrainTaskConfig.local:
-                    feed_kv_pairs += {
+                    feed_kv_pairs += list({
                         lr_scheduler.learning_rate.name: lr_rate
-                    }.items()
+                    }.items())
                 feed_list.append(dict(feed_kv_pairs))

             if not init:
@@ -873,6 +876,7 @@ class DataReader(object):

             f = tarfile.open(fpaths[0], "r")
             for line in f.extractfile(tar_fname):
+                line = cpt.to_text(line)
                 fields = line.strip("\n").split(self._field_delimiter)
                 if (not self._only_src and len(fields) == 2) or (
                         self._only_src and len(fields) == 1):
@@ -882,8 +886,9 @@ class DataReader(object):
                 if not os.path.isfile(fpath):
                     raise IOError("Invalid file: %s" % fpath)

-                with open(fpath, "r") as f:
+                with open(fpath, "rb") as f:
                     for line in f:
+                        line = cpt.to_text(line)
                         fields = line.strip("\n").split(self._field_delimiter)
                         if (not self._only_src and len(fields) == 2) or (
                                 self._only_src and len(fields) == 1):
@@ -892,8 +897,9 @@ class DataReader(object):
     @staticmethod
     def load_dict(dict_path, reverse=False):
         word_dict = {}
-        with open(dict_path, "r") as fdict:
+        with open(dict_path, "rb") as fdict:
             for idx, line in enumerate(fdict):
+                line = cpt.to_text(line)
                 if reverse:
                     word_dict[idx] = line.strip("\n")
                 else:
@@ -1034,7 +1040,7 @@ def multi_head_attention(queries,
         # size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=list(map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])))

     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
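
The hunks above all apply the same Python 2/3 compatibility idioms: `reduce` and `xrange` are reached through `six.moves`, the lazy iterators returned by `map`/`zip`/`dict.items` on Python 3 are materialized with `list()`, and data files are opened in binary mode with each line converted to text (`cpt.to_text` in `paddle.compat`). The snippet below is a minimal standalone sketch of those idioms, not part of the patch; it assumes only the `six` package and approximates `cpt.to_text` with a plain UTF-8 decode.

```python
import tempfile

import six

# reduce() and xrange() are not builtins on Python 3; six.moves maps them to
# functools.reduce and range there, and to the builtins on Python 2.
total = six.moves.reduce(lambda x, y: x + y, [1, 2, 3])
assert total == 6
assert [i for i in six.moves.xrange(3)] == [0, 1, 2]

# map()/zip()/dict.items() return lazy iterators on Python 3, so they are
# wrapped in list() before being indexed, reused, or concatenated.
pairs = list(zip(["a", "b"], [1, 2]))
feed = dict(pairs)
kv = list(feed.items()) + list({"lr": 0.001}.items())
assert dict(kv)["lr"] == 0.001

# Files opened "rb" yield bytes on Python 3; decoding each line mirrors what
# cpt.to_text() does for the data reader.
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
    f.write(b"hello\tworld\n")
with open(f.name, "rb") as fin:
    for line in fin:
        line = line.decode("utf-8") if isinstance(line, bytes) else line
        fields = line.strip("\n").split("\t")
assert fields == ["hello", "world"]
```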