diff --git a/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py b/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py index 1a9f0708e7..363ef83d80 100644 --- a/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py +++ b/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py @@ -55,10 +55,10 @@ config_ascend_quant = ed({ dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/" -@pytest.mark.level1 +@pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard +@pytest.mark.env_single def test_mobilenetv2_quant(): set_seed(1) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -111,9 +111,12 @@ def test_mobilenetv2_quant(): dataset_sink_mode=False) print("============== End Training ==============") + export_time_used = 700 + train_time = monitor.step_mseconds + print('train_time_used:{}'.format(train_time)) + assert train_time < export_time_used expect_avg_step_loss = 2.32 avg_step_loss = np.mean(np.array(monitor.losses)) - print("average step loss:{}".format(avg_step_loss)) assert avg_step_loss < expect_avg_step_loss diff --git a/tests/st/quantization/mobilenetv2_quant/utils.py b/tests/st/quantization/mobilenetv2_quant/utils.py index 7b3ab0c3b7..60bd08f13a 100644 --- a/tests/st/quantization/mobilenetv2_quant/utils.py +++ b/tests/st/quantization/mobilenetv2_quant/utils.py @@ -45,7 +45,7 @@ class Monitor(Callback): self.lr_init = lr_init self.lr_init_len = len(lr_init) self.step_threshold = step_threshold - self.step_mseconds = 0 + self.step_mseconds = 50000 def epoch_begin(self, run_context): self.losses = [] @@ -66,7 +66,8 @@ class Monitor(Callback): def step_end(self, run_context): cb_params = run_context.original_args() - self.step_mseconds = (time.time() - self.step_time) * 1000 + step_mseconds = (time.time() - self.step_time) * 1000 + self.step_mseconds = min(self.step_mseconds, step_mseconds) step_loss = cb_params.net_outputs if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):