|
|
|
@ -25,6 +25,7 @@ from mindspore import log as logger
|
|
|
|
|
from mindspore.common.api import _executor
|
|
|
|
|
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
|
|
|
|
|
from mindspore.train.anf_ir_pb2 import ModelProto as anf_model
|
|
|
|
|
from mindspore.train.checkpoint_pb2 import Checkpoint
|
|
|
|
|
|
|
|
|
|
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
|
|
|
|
|
|
|
|
|
@ -208,13 +209,14 @@ def check_value_type(arg_name, arg_value, valid_types):
|
|
|
|
|
f'but got {type(arg_value).__name__}.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_proto(file_name, proto_format="MINDIR"):
|
|
|
|
|
def read_proto(file_name, proto_format="MINDIR", display_data=False):
|
|
|
|
|
"""
|
|
|
|
|
Read protobuf file.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
file_name (str): File name.
|
|
|
|
|
proto_format (str): Proto format.
|
|
|
|
|
proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR.
|
|
|
|
|
display_data (bool): Whether display data. Default: False.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Object, proto object.
|
|
|
|
@ -222,8 +224,10 @@ def read_proto(file_name, proto_format="MINDIR"):
|
|
|
|
|
|
|
|
|
|
if proto_format == "MINDIR":
|
|
|
|
|
model = mindir_model()
|
|
|
|
|
elif model_format == "ANF":
|
|
|
|
|
elif proto_format == "ANF":
|
|
|
|
|
model = anf_model()
|
|
|
|
|
elif proto_format == "CKPT":
|
|
|
|
|
model = Checkpoint()
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Unsupported proto format.")
|
|
|
|
|
|
|
|
|
@ -234,4 +238,13 @@ def read_proto(file_name, proto_format="MINDIR"):
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name)
|
|
|
|
|
raise ValueError(e.__str__())
|
|
|
|
|
|
|
|
|
|
if proto_format == "MINDIR" and not display_data:
|
|
|
|
|
for param_proto in model.graph.parameter:
|
|
|
|
|
param_proto.raw_data = b'\0'
|
|
|
|
|
|
|
|
|
|
if proto_format == "CKPT" and not display_data:
|
|
|
|
|
for element in model.value:
|
|
|
|
|
element.tensor.tensor_content = b'\0'
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|