|
|
|
@ -478,9 +478,9 @@ def export(net, *inputs, file_name, file_format='AIR'):
|
|
|
|
|
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
|
|
|
|
if file_format not in supported_formats:
|
|
|
|
|
raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
|
|
|
|
|
# switch network mode to infer when it is training
|
|
|
|
|
is_training = net.training
|
|
|
|
|
if is_training:
|
|
|
|
|
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
|
|
|
|
|
is_dump_onnx_in_training = net.training and file_format == 'ONNX'
|
|
|
|
|
if is_dump_onnx_in_training:
|
|
|
|
|
net.set_train(mode=False)
|
|
|
|
|
# export model
|
|
|
|
|
net.init_parameters_data()
|
|
|
|
@ -503,7 +503,7 @@ def export(net, *inputs, file_name, file_format='AIR'):
|
|
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
|
|
|
f.write(onnx_stream)
|
|
|
|
|
# restore network training mode
|
|
|
|
|
if is_training:
|
|
|
|
|
if is_dump_onnx_in_training:
|
|
|
|
|
net.set_train(mode=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|