@@ -76,26 +76,6 @@ def get_config(version='base', batch_size=1):
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
             compute_type=mstype.float16)
-    elif version == 'large_mixed':
-        bert_config = BertConfig(
-            batch_size=batch_size,
-            seq_length=128,
-            vocab_size=21136,
-            hidden_size=1024,
-            num_hidden_layers=24,
-            num_attention_heads=16,
-            intermediate_size=4096,
-            hidden_act="gelu",
-            hidden_dropout_prob=0.0,
-            attention_probs_dropout_prob=0.0,
-            max_position_embeddings=512,
-            type_vocab_size=2,
-            initializer_range=0.02,
-            use_relative_positions=True,
-            input_mask_from_dataset=True,
-            token_type_ids_from_dataset=True,
-            dtype=mstype.float32,
-            compute_type=mstype.float32)
     else:
         bert_config = BertConfig(batch_size=batch_size)
     return bert_config
@@ -136,8 +116,8 @@ class ModelCallback(Callback):
     def step_end(self, run_context):
         cb_params = run_context.original_args()
         self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0])
-        self.overflow_list.append(cb_params.net_outputs[1])
-        self.lossscale_list.append(cb_params.net_outputs[2])
+        self.overflow_list.append(cb_params.net_outputs[1].asnumpy())
+        self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
         print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
 
 @pytest.mark.level0
@@ -157,7 +137,7 @@ def test_bert_tdt():
     netwithloss = BertNetworkWithLoss(config, True)
     optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
     scale_window = 3
-    scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
+    scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
     netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=scale_manager.get_update_cell())
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
@@ -182,22 +162,21 @@ def test_bert_tdt():
         param.default_input = weight_variable(value.asnumpy().shape)
     model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)
 
-    # assertion occurs while the loss_scale value is wrong
-    count = 0
-    for i in range(len(callback.overflow_list)):
-        if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0:
-            count = 0
-            assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(0.5, mstype.float32)
-        if callback.overflow_list[i] == Tensor(False, mstype.bool_):
-            count = count + 1
-            if count == scale_window:
-                count = 0
-                assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32)
-    # assertion occurs while the loss value is wrong
+    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_value = [12.1918125, 11.966035, 11.972114, 11.982671, 11.976399, 12.616986, 12.180658, 12.850562, 12.415608, 12.640145]
+    expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
     print("loss value: {}".format(loss_value))
-    assert np.allclose(loss_value, expect_value, 0.00001, 0.00001)
+    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
+
+    overflow = np.array(callback.overflow_list)
+    expect_overflow = [True, True, False, False, False, True, False, False, False, True]
+    print("overflow: {}".format(overflow))
+    assert (overflow == expect_overflow).all()
+
+    loss_scale = np.array(callback.lossscale_list)
+    expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
+    print("loss scale: {}".format(loss_scale))
+    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
 
 if __name__ == '__main__':
     test_bert_tdt()
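
For reference, the new expect_loss_scale sequence follows the dynamic loss-scaling rule that the removed assertion loop verified step by step: halve the scale when a step overflows, and double it once scale_window consecutive steps finish without overflow. Below is a minimal sketch of that rule in plain Python (no MindSpore dependency; simulate_loss_scale is an illustrative helper, not part of the test, and it assumes the value recorded at each step is the scale after that step's update):

# Sketch of the dynamic loss-scale update rule encoded in expect_loss_scale.
def simulate_loss_scale(overflow_flags, init_scale=2 ** 16, factor=2, scale_window=3):
    scale = float(init_scale)
    good_steps = 0
    history = []
    for overflow in overflow_flags:
        if overflow:
            scale = max(scale / factor, 1.0)  # halve on overflow
            good_steps = 0
        else:
            good_steps += 1
            if good_steps == scale_window:    # double after a clean window
                scale *= factor
                good_steps = 0
        history.append(scale)
    return history

# Feeding in expect_overflow reproduces expect_loss_scale:
# simulate_loss_scale([True, True, False, False, False, True, False, False, False, True])
# -> [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]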