From 8637acdcb8bd1323f90e61c94a5121571b926c49 Mon Sep 17 00:00:00 2001 From: ougongchang Date: Fri, 12 Mar 2021 15:08:40 +0800 Subject: [PATCH] Get a string path when the summary path is a list If device target is GPU, the device number may be set to 1. --- mindspore/train/callback/_summary_collector.py | 16 +++++++++++++--- mindspore/train/summary/_summary_adapter.py | 3 ++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 987b6af2c2..365bab8d98 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -20,6 +20,7 @@ import json from json.decoder import JSONDecodeError from importlib import import_module +from collections.abc import Iterable import numpy as np @@ -842,12 +843,21 @@ class SummaryCollector(Callback): dataset_file_set = (dataset_package.MindDataset, dataset_package.ManifestDataset) dataset_files_set = (dataset_package.TFRecordDataset, dataset_package.TextFileDataset) + dataset_path = '' + if isinstance(output_dataset, dataset_file_set): - return output_dataset.dataset_file + dataset_path = output_dataset.dataset_file if isinstance(output_dataset, dataset_dir_set): - return output_dataset.dataset_dir + dataset_path = output_dataset.dataset_dir if isinstance(output_dataset, dataset_files_set): - return output_dataset.dataset_files[0] + dataset_path = output_dataset.dataset_files[0] + + if dataset_path: + if isinstance(dataset_path, str): + return dataset_path + if isinstance(dataset_path, Iterable): + return list(dataset_path)[0] + return self._get_dataset_path(output_dataset.children[0]) @staticmethod diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py index bd71e8ed60..3620c191e8 100644 --- a/mindspore/train/summary/_summary_adapter.py +++ b/mindspore/train/summary/_summary_adapter.py @@ -22,6 +22,7 @@ from PIL import Image from mindspore import log as logger from mindspore import context from mindspore.communication.management import get_rank +from mindspore.communication.management import GlobalComm from ..._checkparam import Validator from ..anf_ir_pb2 import DataType, ModelProto @@ -57,7 +58,7 @@ def get_event_file_name(prefix, suffix, time_second): device_num = context.get_auto_parallel_context('device_num') device_id = context.get_context('device_id') - if device_num > 1: + if device_num > 1 or GlobalComm.WORLD_COMM_GROUP == 'nccl_world_group': # Notice: # In GPU distribute training scene, get_context('device_id') will not work, # so we use get_rank instead of get_context.