!13225 Get a string path when the summary path is a list

From: @ouwenchang
Reviewed-by: @yelihua,@lixiaohui33
Signed-off-by: @lixiaohui33
pull/13225/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 4c9ad93e13

@ -20,6 +20,7 @@ import json
from json.decoder import JSONDecodeError
from importlib import import_module
from collections.abc import Iterable
import numpy as np
@ -842,12 +843,21 @@ class SummaryCollector(Callback):
dataset_file_set = (dataset_package.MindDataset, dataset_package.ManifestDataset)
dataset_files_set = (dataset_package.TFRecordDataset, dataset_package.TextFileDataset)
dataset_path = ''
if isinstance(output_dataset, dataset_file_set):
return output_dataset.dataset_file
dataset_path = output_dataset.dataset_file
if isinstance(output_dataset, dataset_dir_set):
return output_dataset.dataset_dir
dataset_path = output_dataset.dataset_dir
if isinstance(output_dataset, dataset_files_set):
return output_dataset.dataset_files[0]
dataset_path = output_dataset.dataset_files[0]
if dataset_path:
if isinstance(dataset_path, str):
return dataset_path
if isinstance(dataset_path, Iterable):
return list(dataset_path)[0]
return self._get_dataset_path(output_dataset.children[0])
@staticmethod

@ -22,6 +22,7 @@ from PIL import Image
from mindspore import log as logger
from mindspore import context
from mindspore.communication.management import get_rank
from mindspore.communication.management import GlobalComm
from ..._checkparam import Validator
from ..anf_ir_pb2 import DataType, ModelProto
@ -57,7 +58,7 @@ def get_event_file_name(prefix, suffix, time_second):
device_num = context.get_auto_parallel_context('device_num')
device_id = context.get_context('device_id')
if device_num > 1:
if device_num > 1 or GlobalComm.WORLD_COMM_GROUP == 'nccl_world_group':
# Notice:
# In GPU distribute training scene, get_context('device_id') will not work,
# so we use get_rank instead of get_context.

Loading…
Cancel
Save