|
|
|
@ -445,7 +445,7 @@ def _fill_param_into_net(net, parameter_list):
|
|
|
|
|
load_param_into_net(net, parameter_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def export(net, *inputs, file_name, file_format='GEIR'):
|
|
|
|
|
def export(net, *inputs, file_name, file_format='AIR'):
|
|
|
|
|
"""
|
|
|
|
|
Exports MindSpore predict model to file in specified format.
|
|
|
|
|
|
|
|
|
@ -453,11 +453,12 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|
|
|
|
net (Cell): MindSpore network.
|
|
|
|
|
inputs (Tensor): Inputs of the `net`.
|
|
|
|
|
file_name (str): File name of model to export.
|
|
|
|
|
file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'MINDIR' format for exported model.
|
|
|
|
|
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
|
|
|
|
|
|
|
|
|
|
- GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
|
|
|
|
Ascend model.
|
|
|
|
|
- AIR: Ascend Intermidiate Representation. An intermidiate representation format of Ascend model.
|
|
|
|
|
Recommended suffix for output file is '.air'.
|
|
|
|
|
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
|
|
|
|
|
Recommended suffix for output file is '.onnx'.
|
|
|
|
|
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
|
|
|
|
for MindSpore models.
|
|
|
|
|
Recommended suffix for output file is '.mindir'.
|
|
|
|
@ -465,7 +466,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|
|
|
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
|
|
|
check_input_data(*inputs, data_class=Tensor)
|
|
|
|
|
|
|
|
|
|
supported_formats = ['GEIR', 'ONNX', 'MINDIR']
|
|
|
|
|
if file_format == 'GEIR':
|
|
|
|
|
logger.warning(f"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead.")
|
|
|
|
|
file_format = 'AIR'
|
|
|
|
|
|
|
|
|
|
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
|
|
|
|
if file_format not in supported_formats:
|
|
|
|
|
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
|
|
|
|
|
# switch network mode to infer when it is training
|
|
|
|
@ -474,13 +479,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
|
|
|
|
|
net.set_train(mode=False)
|
|
|
|
|
# export model
|
|
|
|
|
net.init_parameters_data()
|
|
|
|
|
if file_format == 'GEIR':
|
|
|
|
|
phase_name = 'export.geir'
|
|
|
|
|
if file_format == 'AIR':
|
|
|
|
|
phase_name = 'export.air'
|
|
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
|
|
|
|
_executor.export(file_name, graph_id)
|
|
|
|
|
elif file_format == 'ONNX': # file_format is 'ONNX'
|
|
|
|
|
# NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline,
|
|
|
|
|
# do not change it to other values.
|
|
|
|
|
phase_name = 'export.onnx'
|
|
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
|
|
|
onnx_stream = _executor._get_func_graph_proto(graph_id)
|
|
|
|
|