[Dy2Stat] Add test for dygraph seq2seq model. (#25054)

* The arg of append() may temporarily be a non-Tensor.

* Add Seq2Seq as ProgramTranslator Unit Test. 

* Set dtype of vocab_size_tensor to int64 to pass Windows CI.

@@ -213,27 +213,29 @@ class ListTransformer(gast.NodeTransformer):
         if value_name not in self.list_name_to_updated:
             return False
-        # 3. The arg of append() is one `Tensor`
+        # 3. The number of args of append() is one
+        # Only one argument is supported in Python list.append()
         if len(node.args) != 1:
             return False
-        arg = node.args[0]
-        if isinstance(arg, gast.Name):
-            # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
-            # Need a better way to confirm whether `arg.id` is a Tensor.
-            try:
-                var_type_set = self.scope_var_type_dict[arg.id]
-            except KeyError:
-                return False
-            if NodeVarType.NUMPY_NDARRAY in var_type_set:
-                return False
-            if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
-                return False
-        # else:
-        # TODO: Consider that `arg` may be a gast.Call about a Paddle API,
-        # e.g. list_a.append(fluid.layers.reshape(x))
-        # return True
+
+        # TODO(liym27): The arg of append() should be Tensor. But because the type of arg is often wrong with static analysis,
+        #  the arg is not required to be Tensor here.
+        # 4. The arg of append() is Tensor
+        # arg = node.args[0]
+        # if isinstance(arg, gast.Name):
+        #     # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
+        #     # Need a better way to confirm whether `arg.id` is a Tensor.
+        #     try:
+        #         var_type_set = self.scope_var_type_dict[arg.id]
+        #     except KeyError:
+        #         return False
+        #     if NodeVarType.NUMPY_NDARRAY in var_type_set:
+        #         return False
+        #     if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
+        #         return False
+        #     # TODO: Consider that `arg` may be a gast.Call about a Paddle API, e.g. list_a.append(fluid.layers.reshape(x))
+        # # else:
+        # #     return True
         self.list_name_to_updated[value_name.strip()] = True
         return True
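
For readers unfamiliar with gast, here is a minimal standalone sketch (not Paddle's actual transformer) of the only structural check this hunk keeps: matching `x.append(...)` calls that pass exactly one argument, with no attempt to verify the argument's type.

import gast


class AppendCallFinder(gast.NodeVisitor):
    """Collects names of variables that receive a one-argument .append() call."""

    def __init__(self):
        self.append_targets = []

    def visit_Call(self, node):
        func = node.func
        # Match `<name>.append(arg)` with exactly one positional argument,
        # skipping the type check the hunk above relaxes.
        if (isinstance(func, gast.Attribute) and func.attr == 'append' and
                isinstance(func.value, gast.Name) and len(node.args) == 1):
            self.append_targets.append(func.value.id)
        self.generic_visit(node)


tree = gast.parse("a = []\na.append(x)\nb.append(1, 2)")
finder = AppendCallFinder()
finder.visit(tree)
print(finder.append_targets)  # ['a'] -- b.append(1, 2) fails the arity check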

@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
SEED = 2020

def build_fake_sentence(seed):
    random = np.random.RandomState(seed)
    sentence_len = random.randint(5, 15)
    token_ids = [random.randint(0, 1000) for _ in range(sentence_len - 1)]
    return token_ids

def get_data_iter(batch_size, mode='train', cache_num=20):
    self_random = np.random.RandomState(SEED)

    def to_pad_np(data, source=False):
        max_len = 0
        bs = min(batch_size, len(data))
        for ele in data:
            if len(ele) > max_len:
                max_len = len(ele)
        ids = np.ones((bs, max_len), dtype='int64') * 2
        mask = np.zeros((bs), dtype='int32')
        for i, ele in enumerate(data):
            ids[i, :len(ele)] = ele
            if not source:
                mask[i] = len(ele) - 1
            else:
                mask[i] = len(ele)
        return ids, mask

    b_src = []
    if mode != "train":
        cache_num = 1
    data_len = 1000
    for j in range(data_len):
        if len(b_src) == batch_size * cache_num:
            if mode == 'infer':
                new_cache = b_src
            else:
                new_cache = sorted(b_src, key=lambda k: len(k[0]))
            for i in range(cache_num):
                batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
                src_cache = [w[0] for w in batch_data]
                tar_cache = [w[1] for w in batch_data]
                src_ids, src_mask = to_pad_np(src_cache, source=True)
                tar_ids, tar_mask = to_pad_np(tar_cache)
                yield (src_ids, src_mask, tar_ids, tar_mask)
            b_src = []
        src_seed = self_random.randint(0, data_len)
        tar_seed = self_random.randint(0, data_len)
        src_data = build_fake_sentence(src_seed)
        tar_data = build_fake_sentence(tar_seed)
        b_src.append((src_data, tar_data))
    if len(b_src) == batch_size * cache_num or mode == 'infer':
        if mode == 'infer':
            new_cache = b_src
        else:
            new_cache = sorted(b_src, key=lambda k: len(k[0]))
        for i in range(cache_num):
            batch_end = min(len(new_cache), (i + 1) * batch_size)
            batch_data = new_cache[i * batch_size:batch_end]
            src_cache = [w[0] for w in batch_data]
            tar_cache = [w[1] for w in batch_data]
            src_ids, src_mask = to_pad_np(src_cache, source=True)
            tar_ids, tar_mask = to_pad_np(tar_cache)
            yield (src_ids, src_mask, tar_ids, tar_mask)
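
A quick, illustrative way to sanity-check the iterator above (the `__main__` guard keeps it from running on import; batch size 4 is arbitrary):

if __name__ == '__main__':
    # First padded batch: ids are int64 padded with token id 2; masks hold
    # per-row lengths (source) or lengths minus one (target, label shift).
    src_ids, src_mask, tar_ids, tar_mask = next(get_data_iter(batch_size=4))
    print(src_ids.shape, tar_ids.shape)  # (4, max_src_len), (4, max_tar_len)
    print(src_mask.dtype, tar_mask.dtype)  # int32 int32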

class Seq2SeqModelHyperParams(object):
    # whether to use the attention model
    attention = False
    # learning rate for the optimizer
    learning_rate = 0.01
    # number of layers in encoder and decoder
    num_layers = 2
    # hidden size of encoder and decoder
    hidden_size = 8
    src_vocab_size = 1000
    tar_vocab_size = 1000
    batch_size = 8
    max_epoch = 12
    # max length for source and target sentences
    max_len = 30
    # dropout probability
    dropout = 0.0
    # init scale for parameters
    init_scale = 0.1
    # max grad norm for global norm clipping
    max_grad_norm = 5.0
    # path to save the model to
    model_path = "dy2stat/model/seq2seq"
    # model to reload for inference
    reload_model = "model/epoch_0.pdparams"
    beam_size = 10
    max_seq_len = 3

@@ -0,0 +1,172 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.clip import GradientClipByGlobalNorm
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from seq2seq_dygraph_model import BaseModel
from seq2seq_utils import Seq2SeqModelHyperParams as args
from seq2seq_utils import get_data_iter
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
program_translator = ProgramTranslator()
STEP_NUM = 10
PRINT_STEP = 2

def prepare_input(batch):
    src_ids, src_mask, tar_ids, tar_mask = batch
    src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
    in_tar = tar_ids[:, :-1]
    label_tar = tar_ids[:, 1:]
    in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
    label_tar = label_tar.reshape((label_tar.shape[0], label_tar.shape[1], 1))
    inputs = [src_ids, in_tar, label_tar, src_mask, tar_mask]
    return inputs, np.sum(tar_mask)
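
As an illustrative aside (the helper name is hypothetical, not part of the test), the shapes prepare_input produces can be checked like this:

def _check_prepare_input():
    # Hypothetical sanity helper, for illustration only.
    batch = next(get_data_iter(args.batch_size))
    inputs, word_num = prepare_input(batch)
    src_ids, in_tar, label_tar, src_mask, tar_mask = inputs
    # in_tar drops the last target token, label_tar drops the first; both
    # share (batch, tar_len - 1), and labels carry a trailing unit axis.
    assert in_tar.shape == label_tar.shape[:2]
    assert label_tar.shape[2:] == (1,)
    assert word_num == np.sum(tar_mask)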

def train():
    with fluid.dygraph.guard(place):
        fluid.default_startup_program().random_seed = 2020
        fluid.default_main_program().random_seed = 2020
        model = BaseModel(
            args.hidden_size,
            args.src_vocab_size,
            args.tar_vocab_size,
            args.batch_size,
            num_layers=args.num_layers,
            init_scale=args.init_scale,
            dropout=args.dropout)
        global_norm_clip = GradientClipByGlobalNorm(args.max_grad_norm)
        optimizer = fluid.optimizer.SGD(args.learning_rate,
                                        parameter_list=model.parameters(),
                                        grad_clip=global_norm_clip)
        model.train()
        train_data_iter = get_data_iter(args.batch_size)
        batch_times = []
        for batch_id, batch in enumerate(train_data_iter):
            total_loss = 0
            word_count = 0.0
            batch_start_time = time.time()
            input_data_feed, word_num = prepare_input(batch)
            input_data_feed = [
                fluid.dygraph.to_variable(np_inp) for np_inp in input_data_feed
            ]
            word_count += word_num
            loss = model(input_data_feed)
            loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()
            total_loss += loss * args.batch_size
            batch_end_time = time.time()
            batch_time = batch_end_time - batch_start_time
            batch_times.append(batch_time)
            if batch_id % PRINT_STEP == 0:
                print(
                    "Batch:[%d]; Time: %.5f s; loss: %.5f; total_loss: %.5f; word num: %.5f; ppl: %.5f"
                    % (batch_id, batch_time, loss.numpy(), total_loss.numpy(),
                       word_count, np.exp(total_loss.numpy() / word_count)))
            if batch_id + 1 >= STEP_NUM:
                break
        model_dir = os.path.join(args.model_path)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        fluid.save_dygraph(model.state_dict(), model_dir)
        return loss.numpy()

def infer():
    with fluid.dygraph.guard(place):
        model = BaseModel(
            args.hidden_size,
            args.src_vocab_size,
            args.tar_vocab_size,
            args.batch_size,
            beam_size=args.beam_size,
            num_layers=args.num_layers,
            init_scale=args.init_scale,
            dropout=0.0,
            mode='beam_search')
        state_dict, _ = fluid.dygraph.load_dygraph(args.model_path)
        model.set_dict(state_dict)
        model.eval()
        train_data_iter = get_data_iter(args.batch_size, mode='infer')
        batch_times = []
        for batch_id, batch in enumerate(train_data_iter):
            batch_start_time = time.time()
            input_data_feed, word_num = prepare_input(batch)
            input_data_feed = [
                fluid.dygraph.to_variable(np_inp) for np_inp in input_data_feed
            ]
            outputs = model.beam_search(input_data_feed)
            batch_end_time = time.time()
            batch_time = batch_end_time - batch_start_time
            batch_times.append(batch_time)
            if batch_id > STEP_NUM:
                break
        return outputs.numpy()

class TestSeq2seq(unittest.TestCase):
    def run_dygraph(self, mode="train"):
        program_translator.enable(False)
        if mode == "train":
            return train()
        else:
            return infer()

    def run_static(self, mode="train"):
        program_translator.enable(True)
        if mode == "train":
            return train()
        else:
            return infer()

    def _test_train(self):
        dygraph_loss = self.run_dygraph(mode="train")
        static_loss = self.run_static(mode="train")
        result = np.allclose(dygraph_loss, static_loss)
        self.assertTrue(
            result,
            msg="\ndygraph_loss = {} \nstatic_loss = {}".format(dygraph_loss,
                                                                static_loss))

    def _test_predict(self):
        pred_dygraph = self.run_dygraph(mode="test")
        pred_static = self.run_static(mode="test")
        result = np.allclose(pred_static, pred_dygraph)
        self.assertTrue(
            result,
            msg="\npred_dygraph = {} \npred_static = {}".format(pred_dygraph,
                                                                pred_static))

    def test_check_result(self):
        self._test_train()
        self._test_predict()


if __name__ == '__main__':
    unittest.main()