diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index ee53345f31..cdaa826c82 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -256,6 +256,7 @@ bool SaveDataItem2File(const std::vector &items, const std::strin if (!print.SerializeToOstream(output)) { MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail."; ret_end_thread = true; + break; } print.Clear(); } diff --git a/mindspore/context.py b/mindspore/context.py index b5be6c3213..98dbfb327a 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -17,6 +17,7 @@ The context of mindspore, used to configure the current execution environment, including execution mode, execution backend and other feature switches. """ import os +import time import threading from collections import namedtuple from types import FunctionType @@ -55,12 +56,20 @@ def _make_directory(path): os.makedirs(path) real_path = path except PermissionError as e: - logger.error( - f"No write permission on the directory `{path}, error = {e}") + logger.error(f"No write permission on the directory `{path}, error = {e}") raise ValueError(f"No write permission on the directory `{path}`.") return real_path +def _get_print_file_name(file_name): + """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds).""" + time_second = str(int(time.time())) + file_name = file_name + "." + time_second + if os.path.exists(file_name): + ValueError("This file {} already exists.".format(file_name)) + return file_name + + class _ThreadLocalInfo(threading.local): """ Thread local Info used for store thread local attributes. @@ -381,8 +390,20 @@ class _Context: return None @print_file_path.setter - def print_file_path(self, file): - self._context_handle.set_print_file_path(file) + def print_file_path(self, file_path): + """Add timestamp suffix to file name. Sets print file path.""" + print_file_path = os.path.realpath(file_path) + if os.path.isdir(print_file_path): + raise IOError("Print_file_path should be file path, but got {}.".format(file_path)) + + if os.path.exists(print_file_path): + _path, _file_name = os.path.split(print_file_path) + path = _make_directory(_path) + file_name = _get_print_file_name(_file_name) + full_file_name = os.path.join(path, file_name) + else: + full_file_name = print_file_path + self._context_handle.set_print_file_path(full_file_name) def check_input_format(x): @@ -575,7 +596,8 @@ def set_context(**kwargs): max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU. The format is "xxGB". Default: "1024GB". print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to - a file by default, and turn off printing to the screen. + a file by default, and turn off printing to the screen. If the file already exists, add a timestamp + suffix to the file. enable_sparse (bool): Whether to enable sparse feature. Default: False. Raises: diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index d74bee2706..3812698419 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -302,7 +302,7 @@ def _save_graph(network, file_name): if graph_proto: with open(file_name, "wb") as f: f.write(graph_proto) - os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) + os.chmod(file_name, stat.S_IRUSR) def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): @@ -462,19 +462,18 @@ def parse_print(print_file_name): List, element of list is Tensor. Raises: - ValueError: Print file is incorrect. + ValueError: The print file may be empty, please make sure enter the correct file name. """ - if not os.path.realpath(print_file_name): - raise ValueError("Please input the correct print file name.") + print_file_path = os.path.realpath(print_file_name) - if os.path.getsize(print_file_name) == 0: + if os.path.getsize(print_file_path) == 0: raise ValueError("The print file may be empty, please make sure enter the correct file name.") logger.info("Execute load print process.") print_list = Print() try: - with open(print_file_name, "rb") as f: + with open(print_file_path, "rb") as f: pb_content = f.read() print_list.ParseFromString(pb_content) except BaseException as e: diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py index 66dc0a4f58..e2d4e31412 100644 --- a/tests/ut/python/pynative_mode/test_context.py +++ b/tests/ut/python/pynative_mode/test_context.py @@ -118,6 +118,12 @@ def test_variable_memory_max_size(): context.set_context(variable_memory_max_size="3GB") +def test_print_file_path(): + """test_print_file_path""" + with pytest.raises(IOError): + context.set_context(print_file_path="./") + + def test_set_context(): """ test_set_context """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index 035ea87845..7f85695a19 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -34,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load _exec_save_checkpoint, export, _save_graph from ..ut_filter import non_graph_engine -context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb") +context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") class Net(nn.Cell): @@ -374,10 +374,13 @@ def test_print(): def teardown_module(): - files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb'] + files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] for item in files: file_name = './' + item if not os.path.exists(file_name): continue os.chmod(file_name, stat.S_IWRITE) os.remove(file_name) + import shutil + if os.path.exists('./print'): + shutil.rmtree('./print')