!7444 mode_export_modelzoo

Merge pull request !7444 from baiyangfan/mode_export_modelzoo
pull/7444/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8d4a127bd4

@ -29,10 +29,9 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data from mindspore._checkparam import check_input_data, Validator
from mindspore.train.quant import quant from mindspore.train.quant import quant
import mindspore.context as context import mindspore.context as context
from .._checkparam import Validator
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"] "build_searched_strategy", "merge_sliced_parameter"]

@ -23,7 +23,7 @@ import mindspore
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.train.quant import quant from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion from src.lenet_fusion import LeNet5 as LeNet5Fusion
@ -52,4 +52,4 @@ if __name__ == "__main__":
# export network # export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32) inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='AIR') export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')

@ -20,7 +20,7 @@ import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.train.quant import quant from mindspore.train.quant import quant
from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2
@ -50,5 +50,5 @@ if __name__ == '__main__':
# export network # export network
print("============== Starting export ==============") print("============== Starting export ==============")
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32) inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR') export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO')
print("============== End export ==============") print("============== End export ==============")

@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.quant import quant from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
@ -136,7 +136,7 @@ def export_lenet():
# export network # export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR') export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')
@pytest.mark.level0 @pytest.mark.level0

Loading…
Cancel
Save