!4683 split ci cases

Merge pull request !4683 from yoonlee666/cisplit
pull/4683/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit ff6cf58bfd

@ -17,12 +17,8 @@
import os
import time
import numpy as np
import pytest
from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
from src.bert_model import BertConfig
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
@ -35,6 +31,10 @@ from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.model import Model
import mindspore.nn.learning_rate_schedule as lr_schedules
from model_zoo.official.nlp.bert.src.bert_for_pre_training import BertNetworkWithLoss
from model_zoo.official.nlp.bert.src.bert_for_pre_training import BertTrainOneStepWithLossScaleCell
from model_zoo.official.nlp.bert.src.bert_model import BertConfig
_current_dir = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
@ -177,74 +177,6 @@ class TimeMonitor(Callback):
self.epoch_mseconds_list.append(epoch_mseconds)
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bert_percision():
"""test bert percision"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
ds, new_repeat_count, _ = me_de_train_dataset()
version = os.getenv('VERSION', 'large')
batch_size = 16
config = get_config(version=version, batch_size=batch_size)
netwithloss = BertNetworkWithLoss(config, True)
lr = BertLearningRate(decay_steps=ds.get_dataset_size()*new_repeat_count,
learning_rate=5e-5, end_learning_rate=1e-9,
power=10.0, warmup_steps=0)
decay_filter = lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()
no_decay_filter = lambda x: 'layernorm' in x.name.lower() or 'bias' in x.name.lower()
decay_params = list(filter(decay_filter, netwithloss.trainable_params()))
other_params = list(filter(no_decay_filter, netwithloss.trainable_params()))
group_params = [{'params': decay_params, 'weight_decay': 0.01},
{'params': other_params},
{'order_params': netwithloss.trainable_params()}]
optimizer = Lamb(group_params, lr)
scale_window = 3
scale_manager = DynamicLossScaleManager(2 ** 16, 2, scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
scale_update_cell=scale_manager.get_update_cell())
netwithgrads.set_train(True)
model = Model(netwithgrads)
callback = ModelCallback()
params = netwithloss.trainable_params()
for param in params:
value = param.default_input
name = param.name
if isinstance(value, Tensor):
if name.split('.')[-1] in ['weight']:
if name.split('.')[-3] in ['cls2']:
logger.info("***************** BERT param name is 1 {}".format(name))
param.default_input = weight_variable(value.asnumpy().shape)
else:
logger.info("***************** BERT param name is 2 {}".format(name))
tempshape = value.asnumpy().shape
shape = (tempshape[1], tempshape[0])
weight_value = weight_variable(shape).asnumpy()
param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
else:
logger.info("***************** BERT param name is 3 {}".format(name))
param.default_input = weight_variable(value.asnumpy().shape)
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=False)
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
assert np.allclose(loss_value[0], 12.206575, 0, 0.000001)
expect_loss_value = [12.206575, 11.865044, 11.828129, 11.826707, 11.82108, 12.407423, 12.005459,
12.621225, 12.222903, 12.427446]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
overflow = np.array(callback.overflow_list)
expect_overflow = [False, False, False, True, False, False, False, True, False, False]
print("overflow: {}".format(overflow))
assert (overflow == expect_overflow).all()
loss_scale = np.array(callback.lossscale_list)
expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0]
print("loss scale: {}".format(loss_scale))
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@ -317,15 +249,14 @@ def test_bert_performance():
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
expect_epoch_mseconds = 1600
expect_epoch_mseconds = 1400
print("epoch mseconds: {}".format(epoch_mseconds))
assert epoch_mseconds <= expect_epoch_mseconds + 5
per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
expect_per_step_mseconds = 16
expect_per_step_mseconds = 14
print("per step mseconds: {}".format(per_step_mseconds))
assert per_step_mseconds <= expect_per_step_mseconds + 1
if __name__ == '__main__':
test_bert_percision()
test_bert_performance()
Loading…
Cancel
Save