@@ -19,6 +19,8 @@ train and infer lenet quantization network
import os
import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@@ -30,6 +32,7 @@ from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion
import numpy as np

device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
@@ -122,6 +125,19 @@ def eval_quant():
print("============== {} ==============".format(acc))
    assert acc['Accuracy'] > 0.98

def export_lenet():
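    """Convert the LeNet5 fusion network to a quantization aware network and export it to MindIR."""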
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
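    # note: `quant` is assumed to be imported in an unchanged part of this file (MindSpore's
    # quantization aware training helpers); only the modified hunks appear in this diff.
    # per_channel / symmetric appear to take one flag per tensor kind (weights and activations).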
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
                                          per_channel=[True, False], symmetric=[True, False])

    # export network
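    # a single dummy sample in NCHW layout ([batch, channels, height, width]); it only
    # supplies the exporter with the expected input shape and dtype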
    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
    quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
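    # this should produce a MindIR model file (typically lenet_quant.mindir) in the
    # current working directory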

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -130,6 +146,7 @@ def test_lenet_quant():
    train_lenet()
    train_lenet_quant()
    eval_quant()
    export_lenet()

if __name__ == "__main__":