access_control

pull/7247/head
bai-yangfan 5 years ago
parent 86729985df
commit 763a7bd3aa

@ -19,6 +19,8 @@ train and infer lenet quantization network
import os
import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@ -30,6 +32,7 @@ from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion
import numpy as np
device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
@ -122,6 +125,19 @@ def eval_quant():
print("============== {} ==============".format(acc))
assert acc['Accuracy'] > 0.98
def export_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
cfg = quant_cfg
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
per_channel=[True, False], symmetric=[True, False])
# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -130,6 +146,7 @@ def test_lenet_quant():
train_lenet()
train_lenet_quant()
eval_quant()
export_lenet()
if __name__ == "__main__":

Loading…
Cancel
Save