From 763a7bd3aa5090e61a7f1dfd9bcf8b71095600ca Mon Sep 17 00:00:00 2001 From: bai-yangfan Date: Tue, 13 Oct 2020 17:14:55 +0800 Subject: [PATCH] access_control --- .../lenet_quant/test_lenet_quant.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/st/quantization/lenet_quant/test_lenet_quant.py b/tests/st/quantization/lenet_quant/test_lenet_quant.py index b801cdcfca..2dca807d44 100644 --- a/tests/st/quantization/lenet_quant/test_lenet_quant.py +++ b/tests/st/quantization/lenet_quant/test_lenet_quant.py @@ -19,6 +19,8 @@ train and infer lenet quantization network import os import pytest from mindspore import context +from mindspore import Tensor +from mindspore.common import dtype as mstype import mindspore.nn as nn from mindspore.nn.metrics import Accuracy from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor @@ -30,6 +32,7 @@ from dataset import create_dataset from config import nonquant_cfg, quant_cfg from lenet import LeNet5 from lenet_fusion import LeNet5 as LeNet5Fusion +import numpy as np device_target = 'GPU' data_path = "/home/workspace/mindspore_dataset/mnist" @@ -122,6 +125,19 @@ def eval_quant(): print("============== {} ==============".format(acc)) assert acc['Accuracy'] > 0.98 +def export_lenet(): + context.set_context(mode=context.GRAPH_MODE, device_target=device_target) + cfg = quant_cfg + # define fusion network + network = LeNet5Fusion(cfg.num_classes) + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, + per_channel=[True, False], symmetric=[True, False]) + + # export network + inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) + quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR') + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -130,6 +146,7 @@ def test_lenet_quant(): train_lenet() train_lenet_quant() eval_quant() + export_lenet() if __name__ == "__main__":