mindspore/tests/st/quantization/lenet_quant/test_lenet_quant.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train and infer lenet quantization network
"""
import os
import pytest
from mindspore import context
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion

device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
def train_lenet():
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = nonquant_cfg
    ds_train = create_dataset(os.path.join(data_path, "train"),
                              cfg.batch_size)
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training Lenet ==============")
    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=True)
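

# Stage 2: load the float32 checkpoint into the fusion network, convert it to a
# quantization aware network with convert_quant_network, and fine-tune it.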
def train_lenet_quant():
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)

    # load the non-quantized checkpoint into the fusion network
    param_dict = load_checkpoint(ckpt_path)
    load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
                                          symmetric=[False, False])

    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # callbacks and monitors
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant", config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=True)
    print("============== End Training ==============")
def eval_quant():
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
    ckpt_path = './ckpt_lenet_quant-10_937.ckpt'

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
                                          per_channel=[True, False])

    # define loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load quantization aware network checkpoint
    param_dict = load_checkpoint(ckpt_path)
    not_load_param = load_param_into_net(network, param_dict)
    if not_load_param:
        raise ValueError("Load param into net failed!")

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
    assert acc['Accuracy'] > 0.98


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lenet_quant():
    train_lenet()
    train_lenet_quant()
    eval_quant()
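

# Possible ways to run this file (assuming a GPU card and MNIST under data_path;
# the pytest markers are MindSpore test-suite conventions):
#   pytest -v test_lenet_quant.py    # full pipeline: FP32 train -> QAT train -> eval
#   python test_lenet_quant.py       # runs only the QAT training stage (see below)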
if __name__ == "__main__":
train_lenet_quant()