You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
7.0 KiB
166 lines
7.0 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
|
|
"""
|
|
######################## train YOLOv3 example ########################
|
|
train YOLOv3 and get network model files(.ckpt) :
|
|
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train
|
|
|
|
If the mindrecord_dir is empty, it will generate the mindrecord file from image_dir and anno_path.
|
|
Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import pytest
|
|
import numpy as np
|
|
import mindspore.nn as nn
|
|
from mindspore import context, Tensor
|
|
from mindspore.train import Model
|
|
from mindspore.common.initializer import initializer
|
|
from mindspore.train.callback import Callback
|
|
|
|
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
|
from src.dataset import create_yolo_dataset
|
|
from src.config import ConfigYOLOV3ResNet18
|
|
|
|
# Seed NumPy's global random generator with a fixed value at import time
# (presumably so that runs of this test are repeatable -- TODO confirm).
np.random.seed(1)
|
|
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
    """Build a per-step, exponentially decayed learning-rate schedule.

    Args:
        learning_rate: Base learning rate used at step 0.
        start_step: Number of leading entries dropped from the schedule.
        global_step: Number of steps generated before the leading entries
            are dropped.
        decay_step: Divisor applied to the step index to form the exponent.
        decay_rate: Base of the exponential decay factor.
        steps: When true, the exponent is the integer quotient
            ``i // decay_step`` (staircase decay); otherwise it is the true
            quotient ``i / decay_step`` (smooth decay).

    Returns:
        A float32 NumPy array with the rates for steps
        ``start_step`` to ``global_step - 1``.
    """
    def _exponent(step):
        # Integer division yields a staircase, true division a smooth curve.
        return step // decay_step if steps else step / decay_step

    schedule = [learning_rate * (decay_rate ** _exponent(step)) for step in range(global_step)]
    return np.array(schedule).astype(np.float32)[start_step:]
|
|
|
|
|
|
def init_net_param(network, init_value='ones'):
    """Re-initialize the trainable parameters of ``network`` in place.

    Parameters whose name contains ``beta``, ``gamma`` or ``bias``, and
    parameters whose ``data`` is not a ``Tensor``, are left untouched.

    Args:
        network: Object exposing ``trainable_params()``.
        init_value: Initializer specification forwarded to ``initializer``.
    """
    excluded_tags = ('beta', 'gamma', 'bias')
    for param in network.trainable_params():
        if not isinstance(param.data, Tensor):
            continue
        if any(tag in param.name for tag in excluded_tags):
            continue
        param.set_data(initializer(init_value, param.data.shape, param.data.dtype))
|
|
|
|
class ModelCallback(Callback):
    """Callback that collects the network output reported at each step end."""

    def __init__(self):
        super().__init__()
        # One entry per step_end call: the result of net_outputs.asnumpy().
        self.loss_list = []

    def step_end(self, run_context):
        """Record the current network output and print it with the epoch number."""
        params = run_context.original_args()
        outputs = params.net_outputs
        self.loss_list.append(outputs.asnumpy())
        print("epoch: {}, outputs are: {}".format(params.cur_epoch_num, str(outputs)))
|
|
|
|
class TimeMonitor(Callback):
    """Callback that records how long each epoch takes, in milliseconds.

    Args:
        data_size: Divisor used to turn an epoch duration into a per-step
            duration (the caller in this file passes the dataset size).

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        # Duration of every finished epoch, in milliseconds.
        self.epoch_mseconds_list = []
        # Epoch duration divided by data_size, one entry per finished epoch.
        self.per_step_mseconds_list = []
        # Initialised here so the attribute always exists; previously it was
        # first set in epoch_begin, and epoch_end raised AttributeError when
        # it ran without a preceding epoch_begin.
        self.epoch_time = time.time()

    def epoch_begin(self, run_context):
        """Remember the timestamp at which the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Append the elapsed epoch time and the derived per-step time."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self.epoch_mseconds_list.append(epoch_mseconds)
        self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
|
|
|
|
# Directory in which test_yolov3 looks for the MindRecord files
# (yolo.mindrecord0, ...).
DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2017/mindrecord_train/yolov3"
|
|
|
|
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_yolov3():
    """Train YOLOv3-ResNet18 on Ascend and check loss and timing bounds.

    The test reads the MindRecord files under ``DATA_DIR``, trains for
    ``epoch_size`` epochs in graph mode with dataset sinking, and then asserts
    that the first three recorded outputs stay below fixed bounds and that at
    least one of the epochs at index 2, 3 and 4 meets the per-epoch and
    per-step time limits.

    Raises:
        KeyError: If ``DATA_DIR`` is not an existing directory.
        AssertionError: If the first MindRecord file is missing or any
            loss/timing expectation is not met.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    rank = 0
    device_num = 1
    lr_init = 0.001
    epoch_size = 5
    batch_size = 32
    loss_scale = 1024.0
    mindrecord_dir = DATA_DIR

    # The files are expected to already exist in mindrecord_dir, named
    # yolo.mindrecord0, yolo.mindrecord1, ...; this test only checks for them.
    if not os.path.isdir(mindrecord_dir):
        raise KeyError("mindrecord path is not exist.")

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    print("yolov3 mindrecord is ", mindrecord_file)
    # Carry the reason in the assertion itself so it shows up in the report.
    assert os.path.exists(mindrecord_file), "mindrecord file is not exist."

    # The dataset is created from the first mindrecord file, such as
    # yolo.mindrecord0.
    dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
                                  batch_size=batch_size, device_num=device_num, rank=rank)
    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!")

    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())

    # The learning-rate schedule is generated for 60 epochs although only
    # epoch_size epochs are trained below.
    total_epoch_size = 60
    lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
                       global_step=total_epoch_size * dataset_size,
                       decay_step=1000, decay_rate=0.95, steps=True))
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)

    model_callback = ModelCallback()
    time_monitor_callback = TimeMonitor(data_size=dataset_size)
    callback = [model_callback, time_monitor_callback]

    model = Model(net)
    print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
    model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=True,
                sink_size=dataset_size)

    # Assertion occurs when the loss value, overflow state or loss_scale
    # value is wrong.
    loss_value = np.array(model_callback.loss_list)
    expect_loss_value = [6850, 4250, 2750]
    print("loss value: {}".format(loss_value))
    for idx, upper_bound in enumerate(expect_loss_value):
        assert loss_value[idx] < upper_bound

    # Timing is checked on the epochs at index 2, 3 and 4 only; one of them
    # meeting the limit is enough.
    checked_epochs = (2, 3, 4)

    epoch_mseconds = [time_monitor_callback.epoch_mseconds_list[i] for i in checked_epochs]
    expect_epoch_mseconds = 1250
    print("epoch mseconds: {}".format(epoch_mseconds[0]))
    assert any(ms <= expect_epoch_mseconds for ms in epoch_mseconds)

    per_step_mseconds = [time_monitor_callback.per_step_mseconds_list[i] for i in checked_epochs]
    expect_per_step_mseconds = 130
    print("per step mseconds: {}".format(per_step_mseconds[0]))
    assert any(ms <= expect_per_step_mseconds for ms in per_step_mseconds)
    print("yolov3 test case passed.")
|