From f0d18d2ce0d4959d8c68babb195354340d43d007 Mon Sep 17 00:00:00 2001 From: bai-yangfan Date: Mon, 19 Oct 2020 11:39:33 +0800 Subject: [PATCH] mode_export_modelzoo --- mindspore/train/serialization.py | 3 +-- model_zoo/official/cv/lenet_quant/export.py | 4 ++-- model_zoo/official/cv/mobilenetv2_quant/export.py | 4 ++-- tests/st/quantization/lenet_quant/test_lenet_quant.py | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index acf05a256b..f63716056d 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -29,10 +29,9 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common.api import _executor from mindspore.common import dtype as mstype -from mindspore._checkparam import check_input_data +from mindspore._checkparam import check_input_data, Validator from mindspore.train.quant import quant import mindspore.context as context -from .._checkparam import Validator __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"] diff --git a/model_zoo/official/cv/lenet_quant/export.py b/model_zoo/official/cv/lenet_quant/export.py index 380541587c..c250edc6ab 100644 --- a/model_zoo/official/cv/lenet_quant/export.py +++ b/model_zoo/official/cv/lenet_quant/export.py @@ -23,7 +23,7 @@ import mindspore from mindspore import Tensor from mindspore import context from mindspore.train.quant import quant -from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from src.config import mnist_cfg as cfg from src.lenet_fusion import LeNet5 as LeNet5Fusion @@ -52,4 +52,4 @@ if __name__ == "__main__": # export network inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32) - quant.export(network, inputs, file_name="lenet_quant", file_format='AIR') + export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO') diff --git a/model_zoo/official/cv/mobilenetv2_quant/export.py b/model_zoo/official/cv/mobilenetv2_quant/export.py index 6f4ad6aac1..83d8ff3dad 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/export.py +++ b/model_zoo/official/cv/mobilenetv2_quant/export.py @@ -20,7 +20,7 @@ import numpy as np import mindspore from mindspore import Tensor from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.quant import quant from src.mobilenetV2 import mobilenetV2 @@ -50,5 +50,5 @@ if __name__ == '__main__': # export network print("============== Starting export ==============") inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32) - quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR') + export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO') print("============== End export ==============") diff --git a/tests/st/quantization/lenet_quant/test_lenet_quant.py b/tests/st/quantization/lenet_quant/test_lenet_quant.py index 2dca807d44..4dffcfccbe 100644 --- a/tests/st/quantization/lenet_quant/test_lenet_quant.py +++ b/tests/st/quantization/lenet_quant/test_lenet_quant.py @@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype import mindspore.nn as nn from mindspore.nn.metrics import Accuracy from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train import Model from mindspore.train.quant import quant from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net @@ -136,7 +136,7 @@ def export_lenet(): # export network inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) - quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR') + export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO') @pytest.mark.level0