modify read proto

pull/12435/head
changzherui 4 years ago
parent 28cbab85ed
commit 0d6db9a0a4

@ -25,6 +25,7 @@ from mindspore import log as logger
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
from mindspore.train.checkpoint_pb2 import Checkpoint
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
@ -208,13 +209,14 @@ def check_value_type(arg_name, arg_value, valid_types):
f'but got {type(arg_value).__name__}.') f'but got {type(arg_value).__name__}.')
def read_proto(file_name, proto_format="MINDIR"): def read_proto(file_name, proto_format="MINDIR", display_data=False):
""" """
Read protobuf file. Read protobuf file.
Args: Args:
file_name (str): File name. file_name (str): File name.
proto_format (str): Proto format. proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
display_data (bool): Whether display data. Default: False.
Returns: Returns:
Object, proto object. Object, proto object.
@ -222,8 +224,10 @@ def read_proto(file_name, proto_format="MINDIR"):
if proto_format == "MINDIR": if proto_format == "MINDIR":
model = mindir_model() model = mindir_model()
elif model_format == "ANF": elif proto_format == "ANF":
model = anf_model() model = anf_model()
elif proto_format == "CKPT":
model = Checkpoint()
else: else:
raise ValueError("Unsupported proto format.") raise ValueError("Unsupported proto format.")
@ -234,4 +238,13 @@ def read_proto(file_name, proto_format="MINDIR"):
except BaseException as e: except BaseException as e:
logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
raise ValueError(e.__str__()) raise ValueError(e.__str__())
if proto_format == "MINDIR" and not display_data:
for param_proto in model.graph.parameter:
param_proto.raw_data = b'\0'
if proto_format == "CKPT" and not display_data:
for element in model.value:
element.tensor.tensor_content = b'\0'
return model return model

Loading…
Cancel
Save