@@ -25,7 +25,7 @@ from mindspore import nn
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint
 from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.quant import quant
 import mindspore.dataset.engine as de
@@ -33,8 +33,9 @@ import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.utils import Monitor, CrossEntropyWithLabelSmooth
-from src.config import config_ascend_quant, config_ascend, config_gpu_quant, config_gpu
+from src.config import config_ascend_quant, config_gpu_quant
 from src.mobilenetV2 import mobilenetV2
+from src.utils import _load_param_into_net

 random.seed(1)
 np.random.seed(1)
@@ -44,7 +45,6 @@ parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
 parser.add_argument('--device_target', type=str, default=None, help='Run device target')
-parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
 args_opt = parser.parse_args()

 if args_opt.device_target == "Ascend":
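
A note on the dropped flag: argparse's type=bool is a known footgun, because bool() applied to any non-empty string is True, so the flag could never actually be disabled from the command line. The snippet below is illustrative only, not code from this repository:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--quantization_aware', type=bool, default=False)

    # bool('False') is True, so the flag is on no matter what value is passed
    args = parser.parse_args(['--quantization_aware', 'False'])
    print(args.quantization_aware)  # True

Hard-coding the quant configs (next hunks) removes that ambiguity entirely.
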
@@ -69,7 +69,7 @@ else:


 def train_on_ascend():
-    config = config_ascend_quant if args_opt.quantization_aware else config_ascend
+    config = config_ascend_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))
     print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
@@ -101,14 +101,12 @@ def train_on_ascend():
     # load pre trained ckpt
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
-
+        _load_param_into_net(network, param_dict)
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-        network = quant.convert_quant_network(network,
-                                              bn_fold=True,
-                                              per_channel=[True, False],
-                                              symmetric=[True, False])
+    network = quant.convert_quant_network(network,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])

     # get learning rate
     lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
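
_load_param_into_net is defined in src/utils.py, which this diff does not show. As a rough sketch, such a wrapper usually filters the FP32 checkpoint against the network before delegating to the standard loader; the filtering rule below is an assumption, not the repository's actual implementation:

    from mindspore.train.serialization import load_param_into_net

    def _load_param_into_net(network, param_dict):
        # keep only checkpoint entries that name a parameter present in the
        # network, dropping keys the quantized model no longer has (assumed)
        net_keys = {param.name for param in network.get_parameters()}
        filtered = {k: v for k, v in param_dict.items() if k in net_keys}
        load_param_into_net(network, filtered)
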
@@ -141,7 +139,7 @@ def train_on_ascend():


 def train_on_gpu():
-    config = config_gpu_quant if args_opt.quantization_aware else config_gpu
+    config = config_gpu_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))

@@ -165,14 +163,15 @@ def train_on_gpu():
     # resume
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
+        _load_param_into_net(network, param_dict)

     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-        network = quant.convert_quant_network(network,
-                                              bn_fold=True,
-                                              per_channel=[True, False],
-                                              symmetric=[True, True])
+    network = quant.convert_quant_network(network,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, True],
+                                          freeze_bn=1000000,
+                                          quant_delay=step_size * 2)

     # get learning rate
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
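
The two extra arguments on the GPU path are step counters: quant_delay defers fake-quantization, so step_size * 2 gives the network two epochs of ordinary FP32 fine-tuning before quantization takes effect, and freeze_bn=1000000 pushes the batch-norm freeze point past the length of any realistic run (both knobs mirror the TensorFlow QAT parameters of the same names). A sketch of where step_size comes from; the create_dataset arguments are assumptions based on the imports above, not the script's exact call:

    # step_size = batches per epoch, so quant_delay below equals two epochs
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             config=config,
                             device_target=args_opt.device_target,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()
    quant_delay = step_size * 2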