|
|
@ -29,6 +29,7 @@ trainer_count = 1
|
|
|
|
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
|
|
|
|
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
|
|
|
|
)
|
|
|
|
)
|
|
|
|
SEED = 10
|
|
|
|
SEED = 10
|
|
|
|
|
|
|
|
step_num = 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_static(args, batch_generator):
|
|
|
|
def train_static(args, batch_generator):
|
|
|
@ -117,7 +118,7 @@ def train_static(args, batch_generator):
|
|
|
|
batch_id += 1
|
|
|
|
batch_id += 1
|
|
|
|
step_idx += 1
|
|
|
|
step_idx += 1
|
|
|
|
total_batch_num = total_batch_num + 1
|
|
|
|
total_batch_num = total_batch_num + 1
|
|
|
|
if step_idx == 10:
|
|
|
|
if step_idx == step_num:
|
|
|
|
if args.save_dygraph_model_path:
|
|
|
|
if args.save_dygraph_model_path:
|
|
|
|
model_path = os.path.join(args.save_static_model_path,
|
|
|
|
model_path = os.path.join(args.save_static_model_path,
|
|
|
|
"transformer")
|
|
|
|
"transformer")
|
|
|
@ -201,7 +202,7 @@ def train_dygraph(args, batch_generator):
|
|
|
|
avg_batch_time = time.time()
|
|
|
|
avg_batch_time = time.time()
|
|
|
|
batch_id += 1
|
|
|
|
batch_id += 1
|
|
|
|
step_idx += 1
|
|
|
|
step_idx += 1
|
|
|
|
if step_idx == 10:
|
|
|
|
if step_idx == step_num:
|
|
|
|
if args.save_dygraph_model_path:
|
|
|
|
if args.save_dygraph_model_path:
|
|
|
|
model_dir = os.path.join(args.save_dygraph_model_path)
|
|
|
|
model_dir = os.path.join(args.save_dygraph_model_path)
|
|
|
|
if not os.path.exists(model_dir):
|
|
|
|
if not os.path.exists(model_dir):
|
|
|
@ -250,10 +251,11 @@ def predict_dygraph(args, batch_generator):
|
|
|
|
transformer.eval()
|
|
|
|
transformer.eval()
|
|
|
|
|
|
|
|
|
|
|
|
step_idx = 0
|
|
|
|
step_idx = 0
|
|
|
|
|
|
|
|
speed_list = []
|
|
|
|
for input_data in test_loader():
|
|
|
|
for input_data in test_loader():
|
|
|
|
(src_word, src_pos, src_slf_attn_bias, trg_word,
|
|
|
|
(src_word, src_pos, src_slf_attn_bias, trg_word,
|
|
|
|
trg_src_attn_bias) = input_data
|
|
|
|
trg_src_attn_bias) = input_data
|
|
|
|
finished_seq, finished_scores = transformer.beam_search(
|
|
|
|
seq_ids, seq_scores = transformer.beam_search(
|
|
|
|
src_word,
|
|
|
|
src_word,
|
|
|
|
src_pos,
|
|
|
|
src_pos,
|
|
|
|
src_slf_attn_bias,
|
|
|
|
src_slf_attn_bias,
|
|
|
@ -263,12 +265,28 @@ def predict_dygraph(args, batch_generator):
|
|
|
|
eos_id=args.eos_idx,
|
|
|
|
eos_id=args.eos_idx,
|
|
|
|
beam_size=args.beam_size,
|
|
|
|
beam_size=args.beam_size,
|
|
|
|
max_len=args.max_out_len)
|
|
|
|
max_len=args.max_out_len)
|
|
|
|
finished_seq = finished_seq.numpy()
|
|
|
|
seq_ids = seq_ids.numpy()
|
|
|
|
finished_scores = finished_scores.numpy()
|
|
|
|
seq_scores = seq_scores.numpy()
|
|
|
|
|
|
|
|
if step_idx % args.print_step == 0:
|
|
|
|
|
|
|
|
if step_idx == 0:
|
|
|
|
|
|
|
|
logging.info(
|
|
|
|
|
|
|
|
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f"
|
|
|
|
|
|
|
|
% (step_idx, seq_ids[0][0][0], seq_scores[0][0]))
|
|
|
|
|
|
|
|
avg_batch_time = time.time()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
speed = args.print_step / (time.time() - avg_batch_time)
|
|
|
|
|
|
|
|
speed_list.append(speed)
|
|
|
|
|
|
|
|
logging.info(
|
|
|
|
|
|
|
|
"Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
|
|
|
|
|
|
|
|
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
|
|
|
|
|
|
|
|
avg_batch_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
step_idx += 1
|
|
|
|
step_idx += 1
|
|
|
|
if step_idx == 10:
|
|
|
|
if step_idx == step_num:
|
|
|
|
break
|
|
|
|
break
|
|
|
|
return finished_seq
|
|
|
|
logging.info("Dygraph Predict: avg_speed: %.4f step/s" %
|
|
|
|
|
|
|
|
(np.mean(speed_list)))
|
|
|
|
|
|
|
|
return seq_ids, seq_scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_static(args, batch_generator):
|
|
|
|
def predict_static(args, batch_generator):
|
|
|
@ -318,16 +336,34 @@ def predict_static(args, batch_generator):
|
|
|
|
loader.set_batch_generator(batch_generator, places=place)
|
|
|
|
loader.set_batch_generator(batch_generator, places=place)
|
|
|
|
|
|
|
|
|
|
|
|
step_idx = 0
|
|
|
|
step_idx = 0
|
|
|
|
|
|
|
|
speed_list = []
|
|
|
|
for feed_dict in loader:
|
|
|
|
for feed_dict in loader:
|
|
|
|
seq_ids, seq_scores = exe.run(
|
|
|
|
seq_ids, seq_scores = exe.run(
|
|
|
|
test_prog,
|
|
|
|
test_prog,
|
|
|
|
feed=feed_dict,
|
|
|
|
feed=feed_dict,
|
|
|
|
fetch_list=[out_ids.name, out_scores.name],
|
|
|
|
fetch_list=[out_ids.name, out_scores.name],
|
|
|
|
return_numpy=True)
|
|
|
|
return_numpy=True)
|
|
|
|
|
|
|
|
if step_idx % args.print_step == 0:
|
|
|
|
|
|
|
|
if step_idx == 0:
|
|
|
|
|
|
|
|
logging.info(
|
|
|
|
|
|
|
|
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f,"
|
|
|
|
|
|
|
|
% (step_idx, seq_ids[0][0][0], seq_scores[0][0]))
|
|
|
|
|
|
|
|
avg_batch_time = time.time()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
speed = args.print_step / (time.time() - avg_batch_time)
|
|
|
|
|
|
|
|
speed_list.append(speed)
|
|
|
|
|
|
|
|
logging.info(
|
|
|
|
|
|
|
|
"Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f step/s"
|
|
|
|
|
|
|
|
% (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed))
|
|
|
|
|
|
|
|
avg_batch_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
step_idx += 1
|
|
|
|
step_idx += 1
|
|
|
|
if step_idx == 10:
|
|
|
|
if step_idx == step_num:
|
|
|
|
break
|
|
|
|
break
|
|
|
|
return seq_ids
|
|
|
|
logging.info("Static Predict: avg_speed: %.4f step/s" %
|
|
|
|
|
|
|
|
(np.mean(speed_list)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return seq_ids, seq_scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTransformer(unittest.TestCase):
|
|
|
|
class TestTransformer(unittest.TestCase):
|
|
|
@ -344,12 +380,17 @@ class TestTransformer(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
def _test_predict(self):
|
|
|
|
def _test_predict(self):
|
|
|
|
args, batch_generator = self.prepare(mode='test')
|
|
|
|
args, batch_generator = self.prepare(mode='test')
|
|
|
|
static_res = predict_static(args, batch_generator)
|
|
|
|
static_seq_ids, static_scores = predict_static(args, batch_generator)
|
|
|
|
dygraph_res = predict_dygraph(args, batch_generator)
|
|
|
|
dygraph_seq_ids, dygraph_scores = predict_dygraph(args, batch_generator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
|
|
|
np.allclose(static_seq_ids, dygraph_seq_ids),  # compare static vs. dygraph output (was static vs. itself, always True)
|
|
|
|
|
|
|
|
msg="static_seq_ids: {} \n dygraph_seq_ids: {}".format(
|
|
|
|
|
|
|
|
static_seq_ids, dygraph_seq_ids))
|
|
|
|
self.assertTrue(
|
|
|
|
self.assertTrue(
|
|
|
|
np.allclose(static_res, dygraph_res),
|
|
|
|
np.allclose(static_scores, dygraph_scores),
|
|
|
|
msg="static_res: {} \n dygraph_res: {}".format(static_res,
|
|
|
|
msg="static_scores: {} \n dygraph_scores: {}".format(
|
|
|
|
dygraph_res))
|
|
|
|
static_scores, dygraph_scores))
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_result(self):
|
|
|
|
def test_check_result(self):
|
|
|
|
self._test_train()
|
|
|
|
self._test_train()
|
|
|
|