parent 6cf308076d
commit c8fa4a424a
@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate_imagenet"""
import argparse
import os

import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

from src.dataset import create_dataset
from src.inceptionv4 import Inceptionv4
from src.config import config_ascend as config


def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
    args_opt = parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()

    if args.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    net = Inceptionv4(classes=config.num_classes)
    ckpt = load_checkpoint(args.checkpoint_path)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False,
                             repeat_num=1, batch_size=config.batch_size)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    print('='*20, 'Evaluate start', '='*20)
    metrics = model.eval(dataset)
    print("metric: ", metrics)
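The eval script above relies on MindSpore's built-in Top1CategoricalAccuracy and Top5CategoricalAccuracy metrics. As a minimal sketch of what those metrics compute, here is the same top-k check written in plain NumPy; the function name and the toy logits are illustrative only and are not part of the submitted files.

import numpy as np

def topk_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest logits."""
    topk = np.argsort(logits, axis=1)[:, -k:]                # indices of the k largest logits per row
    hits = [labels[i] in topk[i] for i in range(len(labels))]
    return float(np.mean(hits))

logits = np.array([[0.1, 0.7, 0.2],     # predicted class 1
                   [0.4, 0.25, 0.35],   # predicted class 0, runner-up class 2
                   [0.2, 0.3, 0.5]])    # predicted class 2
labels = np.array([1, 2, 2])
print(topk_accuracy(logits, labels, k=1))  # 2/3 correct -> 0.666...
print(topk_accuracy(logits, labels, k=5))  # with only 3 classes every label is in the top 5 -> 1.0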
@@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.config import config_ascend as config
from src.inceptionv4 import Inceptionv4

def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser(description='checkpoint export')
    parser.add_argument('--model_name', type=str, default='inceptionV4.air', help='file name of the exported model')
    parser.add_argument('--format', type=str, default='AIR', help='export file format, e.g. AIR or ONNX')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inceptionv4')
    _args_opt = parser.parse_args()
    return _args_opt


if __name__ == '__main__':

    args_opt = parse_args()

    net = Inceptionv4(classes=config.num_classes)
    param_dict = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 299, 299]), ms.float32)
    export(net, input_arr, file_name=args_opt.model_name, file_format=args_opt.format)
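The dummy input passed to export fixes the exported graph's input shape (NCHW, 1x3x299x299, matching the 299x299 crop used by the dataset pipeline). A minimal sketch of exporting the same checkpoint to both formats named in src/config.py follows; the checkpoint path is hypothetical and the loop is illustrative, not part of the submitted files.

# Illustrative only: export one checkpoint to both file formats from the config.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.config import config_ascend as config
from src.inceptionv4 import Inceptionv4

net = Inceptionv4(classes=config.num_classes)
load_param_into_net(net, load_checkpoint('/path/to/inceptionV4.ckpt'))  # hypothetical checkpoint path
dummy_input = Tensor(np.zeros([1, 3, 299, 299]), ms.float32)            # NCHW, 299x299 like the eval crop
for file_name, file_format in ((config.air_filename, 'AIR'), (config.onnx_filename, 'ONNX')):
    export(net, dummy_input, file_name=file_name, file_format=file_format)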
@@ -0,0 +1,49 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export RANK_TABLE_FILE=$1
DATA_DIR=$2
export RANK_SIZE=8


cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical cores" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $i, device $DEVICE_ID rank_id $RANK_ID"

    env > env.log
    taskset -c $cmdopt python -u ../train.py \
        --device_id $i \
        --dataset_path=$DATA_DIR > log.txt 2>&1 &
    cd ../
done
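The taskset range in the script pins each of the 8 ranks to its own contiguous block of logical cores: rank i gets cores i*avg_core_per_rank through i*avg_core_per_rank + avg_core_per_rank - 1. A small Python sketch of the same arithmetic, assuming a 96-core host for illustration:

# Reproduce the core-binding arithmetic of the launch script (illustrative only).
cores = 96                                   # e.g. output of `nproc`; assumed value
rank_size = 8
avg_core_per_rank = cores // rank_size       # 12, same as `expr $cores / $RANK_SIZE`
core_gap = avg_core_per_rank - 1             # 11
for i in range(rank_size):
    start = i * avg_core_per_rank
    end = start + core_gap
    print(f"rank {i}: taskset -c {start}-{end}")   # rank 0 -> 0-11, rank 1 -> 12-23, ...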
@@ -0,0 +1,28 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=$1
DATA_DIR=$2
CHECKPOINT_PATH=$3
export RANK_SIZE=1

rm -rf evaluation_ascend
mkdir ./evaluation_ascend
cd ./evaluation_ascend || exit
echo "start evaluation for device id $DEVICE_ID"
env > env.log
python ../eval.py --platform=Ascend --dataset_path=$DATA_DIR --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
cd ../
@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export RANK_SIZE=1
export DEVICE_ID=$1
DATA_DIR=$2

rm -rf train_standalone
mkdir ./train_standalone
cd ./train_standalone || exit
echo "start training for device id $DEVICE_ID"
env > env.log
python -u ../train.py \
    --device_id=$1 \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &
cd ../
@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""callback function"""
from mindspore.train.callback import Callback


class EvaluateCallBack(Callback):
    """EvaluateCallBack"""
    def __init__(self, model, eval_dataset, per_print_time=1000):
        super(EvaluateCallBack, self).__init__()
        self.model = model
        self.per_print_time = per_print_time
        self.eval_dataset = eval_dataset

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.per_print_time == 0:
            result = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            print('cur epoch {}, cur_step {}, top1 accuracy {}, top5 accuracy {}.'.format(cb_params.cur_epoch_num,
                                                                                          cb_params.cur_step_num,
                                                                                          result['top_1_accuracy'],
                                                                                          result['top_5_accuracy']))

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
        print('cur epoch {}, cur_step {}, top1 accuracy {}, top5 accuracy {}.'.format(cb_params.cur_epoch_num,
                                                                                      cb_params.cur_step_num,
                                                                                      result['top_1_accuracy'],
                                                                                      result['top_5_accuracy']))
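EvaluateCallBack is not wired up in train.py later in this commit. A minimal sketch of how it could be attached, assuming the callback module lives at src/callback.py and reusing the model, train_dataset, eval_dataset and callbacks names from inception_v4_train; illustrative only, not the author's training setup.

# Hypothetical wiring of EvaluateCallBack into the training loop.
from src.callback import EvaluateCallBack

# `model` must expose the 'top_1_accuracy' and 'top_5_accuracy' metrics the callback reads,
# and `eval_dataset` should be a non-repeated eval pipeline (do_train=False).
eval_cb = EvaluateCallBack(model, eval_dataset, per_print_time=1000)
callbacks.append(eval_cb)   # alongside TimeMonitor / LossMonitor / ModelCheckpoint
# Sink mode is kept off here so that step_end is reached on every step.
model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=False)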
@@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config settings, used by train.py, eval.py and export.py
"""
from easydict import EasyDict as edict

config_ascend = edict({
    'is_save_on_master': False,

    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 1000,
    'work_nums': 8,

    'loss_scale': 1024,
    'smooth_factor': 0.1,
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'amp_level': 'O3',
    'decay': 0.9,
    'epsilon': 1.0,

    'keep_checkpoint_max': 10,
    'save_checkpoint_epochs': 10,

    'lr_init': 0.00004,
    'lr_end': 0.000004,
    'lr_max': 0.4,
    'warmup_epochs': 1,
    'start_epoch': 1,

    'onnx_filename': 'inceptionv4.onnx',
    'air_filename': 'inceptionv4.air'
})
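The EasyDict wrapper only makes the dictionary keys readable as attributes, which is how the other files access these values (config.batch_size, config.lr_max, and so on). A tiny illustration with made-up values:

# Attribute access vs. plain dict access with EasyDict (illustrative only).
from easydict import EasyDict as edict

cfg = edict({'batch_size': 128, 'lr_max': 0.4})
assert cfg.batch_size == cfg['batch_size'] == 128   # both forms read the same entry
cfg.batch_size = 64                                  # an override is visible through both views
assert cfg['batch_size'] == 64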
@@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create train or eval dataset."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from src.config import config_ascend as config


device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    Create a train or eval dataset.

    Args:
        dataset_path (str): The path of dataset.
        do_train (bool): Whether dataset is used for train or eval.
        repeat_num (int): The repeat times of dataset. Default: 1.
        batch_size (int): The batch size of dataset. Default: 32.

    Returns:
        Dataset.
    """

    do_shuffle = bool(do_train)

    if device_num == 1 or not do_train:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
                                   shuffle=do_shuffle, num_shards=device_num, shard_id=device_id)

    image_length = 299
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(image_length),
            C.CenterCrop(image_length)
        ]
    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds
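The last three ops of the pipeline rescale bytes to [0, 1], standardize each channel with the ImageNet statistics, and transpose HWC to CHW. The per-pixel arithmetic is easy to check by hand; a NumPy sketch with a fake 2x2 image, illustrative only:

import numpy as np

# One fake 2x2 RGB image with 8-bit pixel values in HWC layout, as produced by Decode().
img = np.array([[[128, 128, 128], [255, 0, 64]],
                [[0, 255, 32], [16, 32, 48]]], dtype=np.float32)

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

x = img * (1.0 / 255.0) + 0.0      # Rescale(1/255, 0): bytes -> [0, 1]
x = (x - mean) / std               # Normalize: per-channel standardization
x = x.transpose(2, 0, 1)           # HWC2CHW: shape becomes (3, 2, 2), what the network expects
print(x.shape)                     # (3, 2, 2)
print(x[0, 0, 0])                  # (128/255 - 0.485) / 0.229, roughly 0.074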
File diff suppressed because it is too large
@@ -0,0 +1,167 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet"""
import os
import argparse
import math
import numpy as np

from mindspore.communication import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.model import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn import RMSProp
from mindspore import Tensor
from mindspore import context
from mindspore.common import set_seed
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.inceptionv4 import Inceptionv4
from src.dataset import create_dataset, device_num

from src.config import config_ascend as config

os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
set_seed(1)

def generate_cosine_lr(steps_per_epoch, total_epochs,
                       lr_init=config.lr_init,
                       lr_end=config.lr_end,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs):
    """
    Applies cosine decay to generate the learning rate array.

    Args:
        steps_per_epoch(int): number of steps per epoch.
        total_epochs(int): total number of training epochs.
        lr_init(float): initial learning rate.
        lr_end(float): final learning rate.
        lr_max(float): maximum learning rate.
        warmup_epochs(int): number of epochs used for linear warmup.

    Returns:
        np.array, learning rate array.
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_each_step.append(lr)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    current_step = steps_per_epoch * (config.start_epoch - 1)
    learning_rate = learning_rate[current_step:]
    return learning_rate


def inception_v4_train():
    """
    Train InceptionV4 with data parallelism.
    """
    print('epoch_size: {} batch_size: {} class_num {}'.format(config.epoch_size, config.batch_size, config.num_classes))

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=args.device_id)
    context.set_context(enable_graph_kernel=False)
    rank = 0
    if device_num > 1:
        init(backend_name='hccl')
        rank = get_rank()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          all_reduce_fusion_config=[200, 400])

    # create dataset
    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
                                   repeat_num=1, batch_size=config.batch_size)
    train_step_size = train_dataset.get_dataset_size()

    # create model
    net = Inceptionv4(classes=config.num_classes)
    # loss
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # learning rate
    lr = Tensor(generate_cosine_lr(steps_per_epoch=train_step_size, total_epochs=config.epoch_size))

    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            param.set_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]

    opt = RMSProp(group_params, lr, decay=config.decay, epsilon=config.epsilon, weight_decay=config.weight_decay,
                  momentum=config.momentum, loss_scale=config.loss_scale)

    if args.device_id == 0:
        print(lr)
        print(train_step_size)
    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={
        'acc', 'top_1_accuracy', 'top_5_accuracy'}, loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)

    # define callbacks
    performance_cb = TimeMonitor(data_size=train_step_size)
    loss_cb = LossMonitor(per_print_times=train_step_size)
    ckp_save_step = config.save_checkpoint_epochs * train_step_size
    config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
                                 directory='ckpts_rank_' + str(rank), config=config_ck)
    callbacks = [performance_cb, loss_cb]
    if device_num > 1 and config.is_save_on_master:
        if args.device_id == 0:
            callbacks.append(ckpoint_cb)
    else:
        callbacks.append(ckpoint_cb)

    # train model
    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)

def parse_args():
    '''parse_args'''
    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
    arg_parser.add_argument('--resume', type=str, default='', help='resume training from an existing checkpoint')
    args_opt = arg_parser.parse_args()
    return args_opt


if __name__ == '__main__':
    args = parse_args()
    inception_v4_train()
    print('Inceptionv4 training success!')
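The schedule produced by generate_cosine_lr ramps linearly from lr_init to lr_max over the warmup epochs, then follows a half-cosine down toward lr_end. A small standalone check of the boundary values with toy step counts; illustrative only, the constants mirror src/config.py.

import math

# Toy version of the schedule: 2 steps/epoch, 4 epochs, 1 warmup epoch.
lr_init, lr_end, lr_max = 0.00004, 0.000004, 0.4
steps_per_epoch, total_epochs, warmup_epochs = 2, 4, 1
total_steps = steps_per_epoch * total_epochs      # 8
warmup_steps = steps_per_epoch * warmup_epochs    # 2
decay_steps = total_steps - warmup_steps          # 6

lrs = []
for i in range(total_steps):
    if i < warmup_steps:
        lr = lr_init + (lr_max - lr_init) / warmup_steps * (i + 1)        # linear ramp up to lr_max
    else:
        cosine = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
        lr = (lr_max - lr_end) * cosine + lr_end                          # cosine decay toward lr_end
    lrs.append(lr)

print(lrs[warmup_steps - 1])   # 0.4: the last warmup step reaches lr_max
print(lrs[warmup_steps])       # 0.4 again: the cosine term is 1 at the first decay step
print(lrs[-1])                 # smallest value in the toy schedule; approaches lr_end as decay_steps grows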
@@ -0,0 +1 @@
# recommend