From 0d6db9a0a43a4f7e0bb31ba867e5b7aa8862b82b Mon Sep 17 00:00:00 2001 From: changzherui Date: Fri, 19 Feb 2021 19:55:02 +0800 Subject: [PATCH] modify read proto --- mindspore/train/_utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 37455489e5..8d42ad0f02 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -25,6 +25,7 @@ from mindspore import log as logger from mindspore.common.api import _executor from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model from mindspore.train.anf_ir_pb2 import ModelProto as anf_model +from mindspore.train.checkpoint_pb2 import Checkpoint from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo @@ -208,13 +209,14 @@ def check_value_type(arg_name, arg_value, valid_types): f'but got {type(arg_value).__name__}.') -def read_proto(file_name, proto_format="MINDIR"): +def read_proto(file_name, proto_format="MINDIR", display_data=False): """ Read protobuf file. Args: file_name (str): File name. - proto_format (str): Proto format. + proto_format (str): Proto format {MINDIR, ANF, CKPT}. Default: MINDIR. + display_data (bool): Whether display data. Default: False. Returns: Object, proto object. @@ -222,8 +224,10 @@ def read_proto(file_name, proto_format="MINDIR"): if proto_format == "MINDIR": model = mindir_model() - elif model_format == "ANF": + elif proto_format == "ANF": model = anf_model() + elif proto_format == "CKPT": + model = Checkpoint() else: raise ValueError("Unsupported proto format.") @@ -234,4 +238,13 @@ def read_proto(file_name, proto_format="MINDIR"): except BaseException as e: logger.error("Failed to read the file `%s`, please check the correct of the file.", file_name) raise ValueError(e.__str__()) + + if proto_format == "MINDIR" and not display_data: + for param_proto in model.graph.parameter: + param_proto.raw_data = b'\0' + + if proto_format == "CKPT" and not display_data: + for element in model.value: + element.tensor.tensor_content = b'\0' + return model