From 2ec60543cd6663ae1a254ff5137fa6085703abac Mon Sep 17 00:00:00 2001 From: yepei6 Date: Sat, 27 Mar 2021 18:59:46 +0800 Subject: [PATCH] tensorprint adapt to print scalar --- mindspore/train/serialization.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index dd1076a094..e4eec594aa 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -830,13 +830,10 @@ def parse_print(print_file_name): np_type = tensor_to_np_type[data_type] param_data = np.fromstring(data, np_type) ms_type = tensor_to_ms_type[data_type] - param_dim = [] - for dim in dims: - param_dim.append(dim) - if param_dim: - param_value = param_data.reshape(param_dim) + if dims and dims != [0]: + param_value = param_data.reshape(dims) tensor_list.append(Tensor(param_value, ms_type)) - # Scale type + # Scalar type else: data_type_ = data_type.lower() if 'float' in data_type_: