|
|
|
@ -29,8 +29,7 @@ from mindspore.common.api import _executor
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore._checkparam import check_input_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"]
|
|
|
|
|
|
|
|
|
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
|
|
|
|
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
|
|
|
|
@ -513,6 +512,13 @@ def parse_print(print_file_name):
|
|
|
|
|
tensor_list.append(Tensor(param_value, ms_type))
|
|
|
|
|
# Scale type
|
|
|
|
|
else:
|
|
|
|
|
data_type_ = data_type.lower()
|
|
|
|
|
if 'float' in data_type_:
|
|
|
|
|
param_data = float(param_data[0])
|
|
|
|
|
elif 'int' in data_type_:
|
|
|
|
|
param_data = int(param_data[0])
|
|
|
|
|
elif 'bool' in data_type_:
|
|
|
|
|
param_data = bool(param_data[0])
|
|
|
|
|
tensor_list.append(Tensor(param_data, ms_type))
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|