fix print file bug

pull/2869/head
jinyaohui 5 years ago
parent c0e454c07b
commit c7f6527e92

@ -256,6 +256,7 @@ bool SaveDataItem2File(const std::vector<tdt::DataItem> &items, const std::strin
if (!print.SerializeToOstream(output)) {
MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail.";
ret_end_thread = true;
break;
}
print.Clear();
}

@ -17,6 +17,7 @@ The context of mindspore, used to configure the current execution environment,
including execution mode, execution backend and other feature switches.
"""
import os
import time
import threading
from collections import namedtuple
from types import FunctionType
@ -55,12 +56,20 @@ def _make_directory(path):
os.makedirs(path)
real_path = path
except PermissionError as e:
logger.error(
f"No write permission on the directory `{path}, error = {e}")
logger.error(f"No write permission on the directory `{path}, error = {e}")
raise ValueError(f"No write permission on the directory `{path}`.")
return real_path
def _get_print_file_name(file_name):
"""Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds)."""
time_second = str(int(time.time()))
file_name = file_name + "." + time_second
if os.path.exists(file_name):
ValueError("This file {} already exists.".format(file_name))
return file_name
class _ThreadLocalInfo(threading.local):
"""
Thread local Info used for store thread local attributes.
@ -381,8 +390,20 @@ class _Context:
return None
@print_file_path.setter
def print_file_path(self, file):
self._context_handle.set_print_file_path(file)
def print_file_path(self, file_path):
"""Add timestamp suffix to file name. Sets print file path."""
print_file_path = os.path.realpath(file_path)
if os.path.isdir(print_file_path):
raise IOError("Print_file_path should be file path, but got {}.".format(file_path))
if os.path.exists(print_file_path):
_path, _file_name = os.path.split(print_file_path)
path = _make_directory(_path)
file_name = _get_print_file_name(_file_name)
full_file_name = os.path.join(path, file_name)
else:
full_file_name = print_file_path
self._context_handle.set_print_file_path(full_file_name)
def check_input_format(x):
@ -575,7 +596,8 @@ def set_context(**kwargs):
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
The format is "xxGB". Default: "1024GB".
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
a file by default, and turn off printing to the screen.
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
suffix to the file.
enable_sparse (bool): Whether to enable sparse feature. Default: False.
Raises:

@ -302,7 +302,7 @@ def _save_graph(network, file_name):
if graph_proto:
with open(file_name, "wb") as f:
f.write(graph_proto)
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
os.chmod(file_name, stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
@ -462,19 +462,18 @@ def parse_print(print_file_name):
List, element of list is Tensor.
Raises:
ValueError: Print file is incorrect.
ValueError: The print file may be empty, please make sure enter the correct file name.
"""
if not os.path.realpath(print_file_name):
raise ValueError("Please input the correct print file name.")
print_file_path = os.path.realpath(print_file_name)
if os.path.getsize(print_file_name) == 0:
if os.path.getsize(print_file_path) == 0:
raise ValueError("The print file may be empty, please make sure enter the correct file name.")
logger.info("Execute load print process.")
print_list = Print()
try:
with open(print_file_name, "rb") as f:
with open(print_file_path, "rb") as f:
pb_content = f.read()
print_list.ParseFromString(pb_content)
except BaseException as e:

@ -118,6 +118,12 @@ def test_variable_memory_max_size():
context.set_context(variable_memory_max_size="3GB")
def test_print_file_path():
"""test_print_file_path"""
with pytest.raises(IOError):
context.set_context(print_file_path="./")
def test_set_context():
""" test_set_context """
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",

@ -34,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
_exec_save_checkpoint, export, _save_graph
from ..ut_filter import non_graph_engine
context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb")
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
class Net(nn.Cell):
@ -374,10 +374,13 @@ def test_print():
def teardown_module():
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb']
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
for item in files:
file_name = './' + item
if not os.path.exists(file_name):
continue
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
import shutil
if os.path.exists('./print'):
shutil.rmtree('./print')

Loading…
Cancel
Save