@@ -26,7 +26,7 @@ from mindspore import context
 from mindspore import log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
-from mindspore.nn.optim import Momentum
+from mindspore.nn.optim import Lamb
 from mindspore.train.callback import Callback
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager
 from mindspore.train.model import Model
@@ -73,7 +73,7 @@ def get_config(version='base', batch_size=1):
             max_position_embeddings=512,
             type_vocab_size=2,
             initializer_range=0.02,
-            use_relative_positions=True,
+            use_relative_positions=False,
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
@@ -138,7 +138,9 @@ def test_bert_tdt():
     batch_size = int(os.getenv('BATCH_SIZE', '16'))
     config = get_config(version=version, batch_size=batch_size)
     netwithloss = BertNetworkWithLoss(config, True)
-    optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
+    optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(),
+                     start_learning_rate=5e-5, end_learning_rate=1e-9,
+                     power=10.0, warmup_steps=0, weight_decay=0.01)
     scale_window = 3
     scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
     netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
@@ -169,10 +171,10 @@ def test_bert_tdt():

     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [12.191826, 11.966009, 11.972208, 11.98216, 11.973932, 12.611078, 12.17554, 12.840299,
-                         12.403329, 12.621632]
+    expect_loss_value = [12.207201, 11.980862, 11.984737, 11.879344, 11.832838, 12.411388,
+                         12.009449, 12.621273, 12.223175, 12.427313]
     print("loss value: {}".format(loss_value))
-    assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
+    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

     overflow = np.array(callback.overflow_list)
     expect_overflow = [True, True, False, False, False, True, False, False, False, True]
@@ -182,7 +184,7 @@ def test_bert_tdt():
     loss_scale = np.array(callback.lossscale_list)
     expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
     print("loss scale: {}".format(loss_scale))
-    assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
+    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)


 if __name__ == '__main__':
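Note on the new expectations: with DynamicLossScaleManager(2 ** 16, 2, scale_window=3) the scale should be halved on each overflowing step and doubled after three consecutive clean steps, which would make expect_loss_scale follow directly from expect_overflow. A minimal standalone sketch of that behavior (plain Python, not the MindSpore API; simulate_loss_scale is hypothetical and it assumes the callback records the scale after the manager's update on each step):

def simulate_loss_scale(overflows, init_scale=2 ** 16, factor=2, scale_window=3):
    # Replays the assumed dynamic loss-scale schedule over a list of
    # per-step overflow flags and returns the scale seen after each step.
    scale, good_steps, history = init_scale, 0, []
    for overflow in overflows:
        if overflow:
            scale = max(scale // factor, 1)   # halve on overflow
            good_steps = 0
        else:
            good_steps += 1
            if good_steps >= scale_window:    # grow after 3 clean steps
                scale *= factor
                good_steps = 0
        history.append(float(scale))
    return history

expect_overflow = [True, True, False, False, False, True, False, False, False, True]
assert simulate_loss_scale(expect_overflow) == \
    [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]

Under these assumptions the replay reproduces expect_loss_scale exactly, which is why the patch can tighten the loss-scale comparison to np.allclose(..., 0, 0): the schedule is deterministic given the overflow pattern, so no tolerance is needed.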