From 8728fd4a72d7d7aac8784363ebc208d258839237 Mon Sep 17 00:00:00 2001 From: helloiSCSI Date: Mon, 7 Dec 2020 14:53:19 +0800 Subject: [PATCH] fix summary GPU st --- tests/st/summary/test_summary.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/st/summary/test_summary.py b/tests/st/summary/test_summary.py index c11cdb8e2b..6dbc6b96c5 100644 --- a/tests/st/summary/test_summary.py +++ b/tests/st/summary/test_summary.py @@ -143,13 +143,13 @@ class TestSummary: if os.path.exists(cls.base_summary_dir): shutil.rmtree(cls.base_summary_dir) - def _run_network(self, dataset_sink_mode=False, num_samples=2): + def _run_network(self, dataset_sink_mode=False, num_samples=2, **kwargs): lenet = LeNet5() loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9) - model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Loss()}) + model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()}) summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) - summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2) + summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs) ds_train = create_dataset(os.path.join(self.mnist_path, "train"), num_samples=num_samples) model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode) @@ -161,6 +161,7 @@ class TestSummary: @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_summary_with_sink_mode_false(self): """Test summary with sink mode false, and num samples is 64.""" @@ -182,6 +183,7 @@ class TestSummary: @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_summary_with_sink_mode_true(self): """Test summary with sink mode true, and num samples is 64.""" @@ -198,6 +200,20 @@ class TestSummary: for value in Counter(tag_list).values(): assert value == tag_count + @pytest.mark.level0 + @pytest.mark.platform_x86_ascend_training + @pytest.mark.env_onecard + def test_summarycollector_user_defind(self): + """Test SummaryCollector with user defind.""" + summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2, user_defind={'test': 'self test'}) + + tag_list = self._list_summary_tags(summary_dir) + # There will not record input data when dataset sink mode is True + expected_tags = {'conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto', + 'fc2.weight/auto', 'loss/auto', 'histogram', 'image', 'scalar', 'tensor'} + assert set(expected_tags) == set(tag_list) + + @staticmethod def _list_summary_tags(summary_dir): summary_file_path = ''