[Dy2Stat] Add test for dygraph seq2seq model. (#25054)
* The arg of append() can be not Tensor temporarily. * Add Seq2Seq as ProgramTranslator Unit Test. * set dtype of vocab_size_tensor to int64 to pass Windows-CI.fix-sync_batch_norm-hang-in-fleet
parent
8fc31d501b
commit
db601f70cc
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,135 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
SEED = 2020
|
||||
|
||||
|
||||
def build_fake_sentence(seed):
|
||||
random = np.random.RandomState(seed)
|
||||
sentence_len = random.randint(5, 15)
|
||||
token_ids = [random.randint(0, 1000) for _ in range(sentence_len - 1)]
|
||||
return token_ids
|
||||
|
||||
|
||||
def get_data_iter(batch_size, mode='train', cache_num=20):
|
||||
|
||||
self_random = np.random.RandomState(SEED)
|
||||
|
||||
def to_pad_np(data, source=False):
|
||||
max_len = 0
|
||||
bs = min(batch_size, len(data))
|
||||
for ele in data:
|
||||
if len(ele) > max_len:
|
||||
max_len = len(ele)
|
||||
|
||||
ids = np.ones((bs, max_len), dtype='int64') * 2
|
||||
mask = np.zeros((bs), dtype='int32')
|
||||
|
||||
for i, ele in enumerate(data):
|
||||
ids[i, :len(ele)] = ele
|
||||
if not source:
|
||||
mask[i] = len(ele) - 1
|
||||
else:
|
||||
mask[i] = len(ele)
|
||||
|
||||
return ids, mask
|
||||
|
||||
b_src = []
|
||||
|
||||
if mode != "train":
|
||||
cache_num = 1
|
||||
data_len = 1000
|
||||
for j in range(data_len):
|
||||
if len(b_src) == batch_size * cache_num:
|
||||
if mode == 'infer':
|
||||
new_cache = b_src
|
||||
else:
|
||||
new_cache = sorted(b_src, key=lambda k: len(k[0]))
|
||||
|
||||
for i in range(cache_num):
|
||||
batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
|
||||
src_cache = [w[0] for w in batch_data]
|
||||
tar_cache = [w[1] for w in batch_data]
|
||||
src_ids, src_mask = to_pad_np(src_cache, source=True)
|
||||
tar_ids, tar_mask = to_pad_np(tar_cache)
|
||||
yield (src_ids, src_mask, tar_ids, tar_mask)
|
||||
|
||||
b_src = []
|
||||
src_seed = self_random.randint(0, data_len)
|
||||
tar_seed = self_random.randint(0, data_len)
|
||||
src_data = build_fake_sentence(src_seed)
|
||||
tar_data = build_fake_sentence(tar_seed)
|
||||
b_src.append((src_data, tar_data))
|
||||
|
||||
if len(b_src) == batch_size * cache_num or mode == 'infer':
|
||||
if mode == 'infer':
|
||||
new_cache = b_src
|
||||
else:
|
||||
new_cache = sorted(b_src, key=lambda k: len(k[0]))
|
||||
|
||||
for i in range(cache_num):
|
||||
batch_end = min(len(new_cache), (i + 1) * batch_size)
|
||||
batch_data = new_cache[i * batch_size:batch_end]
|
||||
src_cache = [w[0] for w in batch_data]
|
||||
tar_cache = [w[1] for w in batch_data]
|
||||
src_ids, src_mask = to_pad_np(src_cache, source=True)
|
||||
tar_ids, tar_mask = to_pad_np(tar_cache)
|
||||
yield (src_ids, src_mask, tar_ids, tar_mask)
|
||||
|
||||
|
||||
class Seq2SeqModelHyperParams(object):
|
||||
# Whether use attention model
|
||||
attention = False
|
||||
|
||||
# learning rate for optimizer
|
||||
learning_rate = 0.01
|
||||
|
||||
# layers number of encoder and decoder
|
||||
num_layers = 2
|
||||
|
||||
# hidden size of encoder and decoder
|
||||
hidden_size = 8
|
||||
|
||||
src_vocab_size = 1000
|
||||
tar_vocab_size = 1000
|
||||
batch_size = 8
|
||||
max_epoch = 12
|
||||
|
||||
# max length for source and target sentence
|
||||
max_len = 30
|
||||
|
||||
# drop probability
|
||||
dropout = 0.0
|
||||
|
||||
# init scale for parameter
|
||||
init_scale = 0.1
|
||||
|
||||
# max grad norm for global norm clip
|
||||
max_grad_norm = 5.0
|
||||
|
||||
# model path for model to save
|
||||
model_path = "dy2stat/model/seq2seq"
|
||||
|
||||
# reload model to inference
|
||||
reload_model = "model/epoch_0.pdparams"
|
||||
|
||||
beam_size = 10
|
||||
|
||||
max_seq_len = 3
|
@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.clip import GradientClipByGlobalNorm
|
||||
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
|
||||
|
||||
from seq2seq_dygraph_model import BaseModel
|
||||
from seq2seq_utils import Seq2SeqModelHyperParams as args
|
||||
from seq2seq_utils import get_data_iter
|
||||
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
|
||||
)
|
||||
program_translator = ProgramTranslator()
|
||||
STEP_NUM = 10
|
||||
PRINT_STEP = 2
|
||||
|
||||
|
||||
def prepare_input(batch):
|
||||
src_ids, src_mask, tar_ids, tar_mask = batch
|
||||
src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
|
||||
in_tar = tar_ids[:, :-1]
|
||||
label_tar = tar_ids[:, 1:]
|
||||
|
||||
in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
|
||||
label_tar = label_tar.reshape((label_tar.shape[0], label_tar.shape[1], 1))
|
||||
inputs = [src_ids, in_tar, label_tar, src_mask, tar_mask]
|
||||
return inputs, np.sum(tar_mask)
|
||||
|
||||
|
||||
def train():
|
||||
with fluid.dygraph.guard(place):
|
||||
fluid.default_startup_program().random_seed = 2020
|
||||
fluid.default_main_program().random_seed = 2020
|
||||
|
||||
model = BaseModel(
|
||||
args.hidden_size,
|
||||
args.src_vocab_size,
|
||||
args.tar_vocab_size,
|
||||
args.batch_size,
|
||||
num_layers=args.num_layers,
|
||||
init_scale=args.init_scale,
|
||||
dropout=args.dropout)
|
||||
|
||||
gloabl_norm_clip = GradientClipByGlobalNorm(args.max_grad_norm)
|
||||
optimizer = fluid.optimizer.SGD(args.learning_rate,
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=gloabl_norm_clip)
|
||||
|
||||
model.train()
|
||||
train_data_iter = get_data_iter(args.batch_size)
|
||||
|
||||
batch_times = []
|
||||
for batch_id, batch in enumerate(train_data_iter):
|
||||
total_loss = 0
|
||||
word_count = 0.0
|
||||
batch_start_time = time.time()
|
||||
input_data_feed, word_num = prepare_input(batch)
|
||||
input_data_feed = [
|
||||
fluid.dygraph.to_variable(np_inp) for np_inp in input_data_feed
|
||||
]
|
||||
word_count += word_num
|
||||
loss = model(input_data_feed)
|
||||
loss.backward()
|
||||
optimizer.minimize(loss)
|
||||
model.clear_gradients()
|
||||
total_loss += loss * args.batch_size
|
||||
batch_end_time = time.time()
|
||||
batch_time = batch_end_time - batch_start_time
|
||||
batch_times.append(batch_time)
|
||||
if batch_id % PRINT_STEP == 0:
|
||||
print(
|
||||
"Batch:[%d]; Time: %.5f s; loss: %.5f; total_loss: %.5f; word num: %.5f; ppl: %.5f"
|
||||
% (batch_id, batch_time, loss.numpy(), total_loss.numpy(),
|
||||
word_count, np.exp(total_loss.numpy() / word_count)))
|
||||
if batch_id + 1 >= STEP_NUM:
|
||||
break
|
||||
model_dir = os.path.join(args.model_path)
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
fluid.save_dygraph(model.state_dict(), model_dir)
|
||||
return loss.numpy()
|
||||
|
||||
|
||||
def infer():
|
||||
with fluid.dygraph.guard(place):
|
||||
model = BaseModel(
|
||||
args.hidden_size,
|
||||
args.src_vocab_size,
|
||||
args.tar_vocab_size,
|
||||
args.batch_size,
|
||||
beam_size=args.beam_size,
|
||||
num_layers=args.num_layers,
|
||||
init_scale=args.init_scale,
|
||||
dropout=0.0,
|
||||
mode='beam_search')
|
||||
state_dict, _ = fluid.dygraph.load_dygraph(args.model_path)
|
||||
model.set_dict(state_dict)
|
||||
model.eval()
|
||||
train_data_iter = get_data_iter(args.batch_size, mode='infer')
|
||||
batch_times = []
|
||||
for batch_id, batch in enumerate(train_data_iter):
|
||||
batch_start_time = time.time()
|
||||
input_data_feed, word_num = prepare_input(batch)
|
||||
input_data_feed = [
|
||||
fluid.dygraph.to_variable(np_inp) for np_inp in input_data_feed
|
||||
]
|
||||
outputs = model.beam_search(input_data_feed)
|
||||
batch_end_time = time.time()
|
||||
batch_time = batch_end_time - batch_start_time
|
||||
batch_times.append(batch_time)
|
||||
if batch_id > STEP_NUM:
|
||||
break
|
||||
|
||||
return outputs.numpy()
|
||||
|
||||
|
||||
class TestSeq2seq(unittest.TestCase):
|
||||
def run_dygraph(self, mode="train"):
|
||||
program_translator.enable(False)
|
||||
if mode == "train":
|
||||
return train()
|
||||
else:
|
||||
return infer()
|
||||
|
||||
def run_static(self, mode="train"):
|
||||
program_translator.enable(True)
|
||||
if mode == "train":
|
||||
return train()
|
||||
else:
|
||||
return infer()
|
||||
|
||||
def _test_train(self):
|
||||
dygraph_loss = self.run_dygraph(mode="train")
|
||||
static_loss = self.run_static(mode="train")
|
||||
result = np.allclose(dygraph_loss, static_loss)
|
||||
self.assertTrue(
|
||||
result,
|
||||
msg="\ndygraph_loss = {} \nstatic_loss = {}".format(dygraph_loss,
|
||||
static_loss))
|
||||
|
||||
def _test_predict(self):
|
||||
pred_dygraph = self.run_dygraph(mode="test")
|
||||
pred_static = self.run_static(mode="test")
|
||||
result = np.allclose(pred_static, pred_dygraph)
|
||||
self.assertTrue(
|
||||
result,
|
||||
msg="\npred_dygraph = {} \npred_static = {}".format(pred_dygraph,
|
||||
pred_static))
|
||||
|
||||
def test_check_result(self):
|
||||
self._test_train()
|
||||
self._test_predict()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue