From ae1e6bf176512249266603f61845fad7924594c8 Mon Sep 17 00:00:00 2001
From: caifubi
Date: Tue, 2 Mar 2021 10:19:55 +0800
Subject: [PATCH] fix ci pynative case

---
 .../pynative/test_pynative_resnet50_ascend.py | 79 ++++++++++++-------
 1 file changed, 50 insertions(+), 29 deletions(-)

diff --git a/tests/st/pynative/test_pynative_resnet50_ascend.py b/tests/st/pynative/test_pynative_resnet50_ascend.py
index 783605cda9..496718567f 100644
--- a/tests/st/pynative/test_pynative_resnet50_ascend.py
+++ b/tests/st/pynative/test_pynative_resnet50_ascend.py
@@ -32,7 +32,36 @@ from mindspore.nn import Cell
 from mindspore.ops import operations as P
 from mindspore.ops import composite as CP
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.nn.wrap.cell_wrapper import WithLossCell
+from mindspore.train.callback import LossMonitor, Callback
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.model import Model
+
+class MyTimeMonitor(Callback):
+    def __init__(self, data_size):
+        super(MyTimeMonitor, self).__init__()
+        self.data_size = data_size
+        self.total = 0
+
+    def epoch_begin(self, run_context):
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        epoch_msseconds = (time.time()-self.epoch_time) * 1000
+        per_step_mssconds = epoch_msseconds / self.data_size
+        print("epoch time:{0}, per step time:{1}".format(epoch_msseconds, per_step_mssconds), flush=True)
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        step_msseconds = (time.time() - self.step_time) * 1000
+        if step_msseconds < 265:
+            self.total = self.total + 1
+        print(f"step time:{step_msseconds}", flush=True)
+
+    def good_step(self):
+        return self.total
 
 random.seed(1)
 np.random.seed(1)
@@ -303,12 +332,12 @@ def resnet50(batch_size, num_classes):
     return ResNet(ResidualBlock, num_classes, batch_size)
 
 
-def create_dataset(repeat_num=1, training=True, batch_size=32):
+def create_dataset(repeat_num=1, training=True, batch_size=32, num_samples=1600):
     data_home = "/home/workspace/mindspore_dataset"
     data_dir = data_home + "/cifar-10-batches-bin"
     if not training:
         data_dir = data_home + "/cifar-10-verify-bin"
-    data_set = ds.Cifar10Dataset(data_dir)
+    data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples)
 
     resize_height = 224
     resize_width = 224
@@ -385,33 +414,25 @@ def test_pynative_resnet50():
 
     batch_size = 32
     num_classes = 10
+    loss_scale = 128
+    total_step = 50
     net = resnet50(batch_size, num_classes)
-    criterion = CrossEntropyLoss()
     optimizer = Momentum(learning_rate=0.01, momentum=0.9,
                          params=filter(lambda x: x.requires_grad, net.get_parameters()))
+    data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size, num_samples=total_step * batch_size)
+
+    # define callbacks
+    time_cb = MyTimeMonitor(data_size=data_set.get_dataset_size())
+    loss_cb = LossMonitor()
+    cb = [time_cb, loss_cb]
+
+    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
+    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
+                  amp_level="O2", keep_batchnorm_fp32=False)
+
+    # train model
+    model.train(1, data_set, callbacks=cb,
+                sink_size=data_set.get_dataset_size(), dataset_sink_mode=True)
 
-    net_with_criterion = WithLossCell(net, criterion)
-    net_with_criterion.set_grad()
-    train_network = GradWrap(net_with_criterion)
-    train_network.set_train()
-
-    step = 0
-    max_step = 21
-    exceed_num = 0
-    data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size)
-    for element in data_set.create_dict_iterator(num_epochs=1):
-        step = step + 1
-        if step > max_step:
-            break
-        start_time = time.time()
-        input_data = element["image"]
-        input_label = element["label"]
-        loss_output = net_with_criterion(input_data, input_label)
-        grads = train_network(input_data, input_label)
-        optimizer(grads)
-        end_time = time.time()
-        cost_time = end_time - start_time
-        print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
-        if step > 1 and cost_time > 0.25:
-            exceed_num = exceed_num + 1
-    assert exceed_num < 20
+    assert time_cb.good_step() > 10