|
|
|
@ -18,6 +18,7 @@ import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
|
|
|
|
|
from paddle.fluid.dygraph.io import VARIABLE_FILENAME
|
|
|
|
|
|
|
|
|
|
from bert_dygraph_model import PretrainModelLayer
|
|
|
|
|
from bert_utils import get_bert_config, get_feed_data_reader
|
|
|
|
@ -28,9 +29,11 @@ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
|
|
|
|
|
# Value assigned to random_seed of the default main and startup programs in train().
SEED = 2020
|
|
|
|
|
# train() leaves its loop (after saving the model) once step_idx equals this value.
STEP_NUM = 10
|
|
|
|
|
# NOTE(review): presumably the logging interval in steps; its use is not visible in this excerpt.
PRINT_STEP = 2
|
|
|
|
|
# Path passed to fluid.dygraph.jit.save / jit.load and fluid.io.load_inference_model.
MODEL_SAVE_PATH = "./bert.inference.model"
|
|
|
|
|
# Path passed to fluid.dygraph.save_dygraph / load_dygraph for the dygraph state dict.
DY_STATE_DICT_SAVE_PATH = "./bert.dygraph"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(bert_config, data_reader):
|
|
|
|
|
def train(bert_config, data_reader, to_static):
|
|
|
|
|
with fluid.dygraph.guard(place):
|
|
|
|
|
fluid.default_main_program().random_seed = SEED
|
|
|
|
|
fluid.default_startup_program().random_seed = SEED
|
|
|
|
@ -79,18 +82,74 @@ def train(bert_config, data_reader):
|
|
|
|
|
|
|
|
|
|
step_idx += 1
|
|
|
|
|
if step_idx == STEP_NUM:
|
|
|
|
|
if to_static:
|
|
|
|
|
fluid.dygraph.jit.save(bert, MODEL_SAVE_PATH)
|
|
|
|
|
else:
|
|
|
|
|
fluid.dygraph.save_dygraph(bert.state_dict(),
|
|
|
|
|
DY_STATE_DICT_SAVE_PATH)
|
|
|
|
|
break
|
|
|
|
|
return loss, ppl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_dygraph(bert_config, data_reader):
    """Train with the program translator disabled.

    Args:
        bert_config: model configuration forwarded to ``train``.
        data_reader: data reader forwarded to ``train``.

    Returns:
        Whatever ``train`` returns (``loss, ppl`` in the visible code).
    """
    program_translator.enable(False)
    # to_static=False makes train() save the state dict with save_dygraph
    # rather than exporting an inference model with jit.save.
    return train(bert_config, data_reader, False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_static(bert_config, data_reader):
    """Train with the program translator enabled.

    Args:
        bert_config: model configuration forwarded to ``train``.
        data_reader: data reader forwarded to ``train``.

    Returns:
        Whatever ``train`` returns (``loss, ppl`` in the visible code).
    """
    program_translator.enable(True)
    # to_static=True makes train() export the model with jit.save to
    # MODEL_SAVE_PATH once the final step is reached.
    return train(bert_config, data_reader, True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_static(data):
    """Run the model saved at MODEL_SAVE_PATH through a static executor.

    Args:
        data: sequence of inputs, paired positionally with the feed target
            names returned by ``load_inference_model``.

    Returns:
        The result of ``Executor.run`` for the loaded fetch targets.
    """
    executor = fluid.Executor(place)
    # Restore the inference program along with its feed names and fetch targets.
    program, feed_names, fetch_vars = fluid.io.load_inference_model(
        MODEL_SAVE_PATH, executor=executor, params_filename=VARIABLE_FILENAME)
    feed_dict = dict(zip(feed_names, data))
    return executor.run(program, feed=feed_dict, fetch_list=fetch_vars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_dygraph(bert_config, data):
    """Run a freshly built dygraph model loaded from DY_STATE_DICT_SAVE_PATH.

    Args:
        bert_config: configuration passed to ``PretrainModelLayer``.
        data: sequence of exactly seven inputs, in the order src_ids,
            pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels.

    Returns:
        A list with ``.numpy()`` of every output of the model call.
    """
    program_translator.enable(False)
    with fluid.dygraph.guard(place):
        model = PretrainModelLayer(
            config=bert_config, weight_sharing=False, use_fp16=False)
        state_dict, _ = fluid.dygraph.load_dygraph(DY_STATE_DICT_SAVE_PATH)
        model.set_dict(state_dict)
        model.eval()

        # Convert every input first, then unpack into the seven named tensors.
        tensors = [fluid.dygraph.to_variable(item) for item in data]
        (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
         labels) = tensors
        call_kwargs = {
            "src_ids": src_ids,
            "position_ids": pos_ids,
            "sentence_ids": sent_ids,
            "input_mask": input_mask,
            "mask_label": mask_label,
            "mask_pos": mask_pos,
            "labels": labels,
        }
        outputs = model(**call_kwargs)
        return [out.numpy() for out in outputs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_dygraph_jit(data):
    """Run the model restored by ``fluid.dygraph.jit.load`` on one batch.

    Args:
        data: sequence of exactly seven inputs, in the order src_ids,
            pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels.

    Returns:
        A list with ``.numpy()`` of every output of the loaded model.
    """
    with fluid.dygraph.guard(place):
        loaded_model = fluid.dygraph.jit.load(MODEL_SAVE_PATH)
        loaded_model.eval()

        # Unpack explicitly so a batch of the wrong length fails here.
        (src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos,
         labels) = data
        outputs = loaded_model(src_ids, pos_ids, sent_ids, input_mask,
                               mask_label, mask_pos, labels)
        return [out.numpy() for out in outputs]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBert(unittest.TestCase):
|
|
|
|
@ -104,14 +163,36 @@ class TestBert(unittest.TestCase):
|
|
|
|
|
dygraph_loss, dygraph_ppl = train_dygraph(self.bert_config,
|
|
|
|
|
self.data_reader)
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_loss, static_loss),
|
|
|
|
|
msg="static_loss: {} \n static_loss: {}".format(static_loss,
|
|
|
|
|
dygraph_loss))
|
|
|
|
|
np.allclose(static_loss, dygraph_loss),
|
|
|
|
|
msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
|
|
|
|
|
dygraph_loss))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_ppl, dygraph_ppl),
|
|
|
|
|
msg="static_ppl: {} \n dygraph_ppl: {}".format(static_ppl,
|
|
|
|
|
dygraph_ppl))
|
|
|
|
|
|
|
|
|
|
self.verify_predict()
|
|
|
|
|
|
|
|
|
|
def verify_predict(self):
    """Check that the three prediction paths agree on the first batch.

    Takes the first batch from ``self.data_reader.data_generator()``, runs
    ``predict_dygraph``, ``predict_static`` and ``predict_dygraph_jit`` on
    it, and asserts that each dygraph output is close to the static one.
    """

    def assert_close_to_static(static_out, other_out, template):
        # On failure the message lists only the mismatching elements.
        is_close = np.allclose(static_out, other_out)
        mismatch = ~np.isclose(static_out, other_out)
        self.assertTrue(
            is_close,
            template.format(other_out[mismatch], static_out[mismatch]))

    for batch in self.data_reader.data_generator()():
        dygraph_outputs = predict_dygraph(self.bert_config, batch)
        static_outputs = predict_static(batch)
        jit_outputs = predict_dygraph_jit(batch)

        triples = zip(dygraph_outputs, static_outputs, jit_outputs)
        for dy_out, st_out, jit_out in triples:
            assert_close_to_static(
                st_out, dy_out, "dygraph_res: {},\n static_res: {}")
            assert_close_to_static(
                st_out, jit_out, "dygraph_jit_res: {},\n static_res: {}")
        # Only the first batch is verified.
        break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the unittest runner when this file is executed as a script.
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|