# mindspore/tests/st/model_zoo_tests/yolov3/test_yolov3.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train YOLOv3 example ########################
Train YOLOv3 and get network model files (.ckpt):
python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train
If mindrecord_dir is empty, it will generate the mindrecord file from image_dir and anno_path.
Note: if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
This test itself does not generate MindRecord files; it expects them to already exist under DATA_DIR.
"""
import os
import time
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train import Model
from mindspore.common.initializer import initializer
from mindspore.train.callback import Callback
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
from src.dataset import create_yolo_dataset
from src.config import ConfigYOLOV3ResNet18
# Seed NumPy's global random number generator with a fixed value so any
# NumPy-driven randomness is repeatable between runs. Whether this also pins
# MindSpore-side randomness (e.g. dataset shuffling) is not shown here — TODO confirm.
np.random.seed(1)
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
    """Build an exponentially decaying learning-rate schedule.

    The rate for step ``k`` (``0 <= k < global_step``) is
    ``learning_rate * decay_rate ** (k / decay_step)``. When ``steps`` is true,
    the exponent uses floor division, so the rate is held constant within each
    block of ``decay_step`` steps (staircase decay).

    Args:
        learning_rate: base learning rate at step 0.
        start_step: index of the first step to keep; earlier entries are dropped.
        global_step: total number of steps the schedule is generated for.
        decay_step: number of steps per decay period.
        decay_rate: multiplicative factor applied once per decay period.
        steps: use staircase (integer) decay instead of continuous decay.

    Returns:
        1-D ``np.float32`` array holding the rates for steps
        ``start_step .. global_step - 1``.
    """
    def _exponent(step):
        # Staircase decay floors the period count; continuous decay keeps it fractional.
        return step // decay_step if steps else step / decay_step

    schedule = [learning_rate * (decay_rate ** _exponent(step)) for step in range(global_step)]
    return np.array(schedule).astype(np.float32)[start_step:]
def init_net_param(network, init_value='ones', skip_names=('beta', 'gamma', 'bias')):
    """Re-initialize the trainable parameters of ``network`` in place.

    Each trainable parameter whose ``data`` is a ``Tensor`` is overwritten with
    ``initializer(init_value, shape, dtype)``, reusing that parameter's own
    shape and dtype. Parameters whose name contains any substring listed in
    ``skip_names`` are left unchanged. The default skips names containing
    'beta', 'gamma' or 'bias'.

    Args:
        network: object exposing ``trainable_params()`` (a MindSpore cell).
        init_value: initializer spec forwarded to ``initializer`` (default ``'ones'``).
        skip_names: iterable of name substrings. Matching parameters are not
            re-initialized. Pass a tuple/list, not a bare string, which would
            be iterated character by character.
    """
    for param in network.trainable_params():
        if not isinstance(param.data, Tensor):
            continue
        if any(token in param.name for token in skip_names):
            continue
        param.set_data(initializer(init_value, param.data.shape, param.data.dtype))
class ModelCallback(Callback):
    """Training callback that records the network output after every step.

    Each output is converted with ``asnumpy()`` and appended to ``loss_list``.
    ``test_yolov3`` uses these outputs as loss values. The raw output is also
    printed together with the current epoch number.
    """

    def __init__(self):
        super().__init__()
        # Network outputs collected in the order step_end is invoked.
        self.loss_list = []

    def step_end(self, run_context):
        """Store and print the network output reported for the finished step."""
        params = run_context.original_args()
        outputs = params.net_outputs
        self.loss_list.append(outputs.asnumpy())
        message = "epoch: {}, outputs are: {}".format(params.cur_epoch_num, str(outputs))
        print(message)
class TimeMonitor(Callback):
    """Time Monitor: records how long each epoch takes.

    For every epoch it appends the elapsed time in milliseconds to
    ``epoch_mseconds_list``. It also appends that time divided by
    ``data_size`` to ``per_step_mseconds_list``.

    Args:
        data_size: divisor that turns an epoch's duration into a per-step
            figure. ``test_yolov3`` passes the dataset size (steps per epoch).
    """

    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size
        self.epoch_mseconds_list = []
        self.per_step_mseconds_list = []
        # Start timestamp of the epoch in progress. Set by epoch_begin;
        # declared here so the attribute always exists.
        self.epoch_time = None

    def epoch_begin(self, run_context):
        """Record the start time of the epoch."""
        # perf_counter is monotonic, so the measured interval is unaffected by
        # system clock adjustments (time.time() is not guaranteed to be).
        self.epoch_time = time.perf_counter()

    def epoch_end(self, run_context):
        """Append this epoch's duration (ms) and its per-step average (ms).

        Raises:
            RuntimeError: if called without a preceding ``epoch_begin``.
        """
        if self.epoch_time is None:
            raise RuntimeError("TimeMonitor.epoch_end called before epoch_begin")
        epoch_mseconds = (time.perf_counter() - self.epoch_time) * 1000
        self.epoch_mseconds_list.append(epoch_mseconds)
        self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
# Directory expected to already contain the YOLOv3 MindRecord files.
# test_yolov3 reads "yolo.mindrecord0" from here and does not create it.
DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2017/mindrecord_train/yolov3"
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_yolov3():
    """Train YOLOv3-ResNet18 briefly on Ascend and check loss and speed.

    Steps:
    - Read the existing MindRecord file ``yolo.mindrecord0`` under ``DATA_DIR``.
    - Train for ``epoch_size`` epochs in dataset-sink mode.
    - Assert that the first three recorded losses are below fixed thresholds.
    - Assert that at least one of the epochs at indices 2-4 meets the
      epoch-time budget, and likewise the per-step-time budget.

    Raises:
        FileNotFoundError: if the MindRecord directory or file is missing.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    rank = 0
    device_num = 1
    lr_init = 0.001
    epoch_size = 5
    batch_size = 32
    loss_scale = float(1024)

    # The test only consumes an existing MindRecord dataset; it never generates one.
    mindrecord_dir = DATA_DIR
    if not os.path.isdir(mindrecord_dir):
        raise FileNotFoundError("mindrecord path is not exist.")
    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    print("yolov3 mindrecord is ", mindrecord_file)
    if not os.path.exists(mindrecord_file):
        raise FileNotFoundError("mindrecord file is not exist.")

    # The dataset is created from the first MindRecord file, e.g. yolo.mindrecord0.
    dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
                                  batch_size=batch_size, device_num=device_num, rank=rank)
    dataset_size = dataset.get_dataset_size()
    print("Create dataset done!")

    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())

    # The LR schedule is built for 60 epochs even though only `epoch_size` are trained.
    total_epoch_size = 60
    lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
                       global_step=total_epoch_size * dataset_size,
                       decay_step=1000, decay_rate=0.95, steps=True))
    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
    net = TrainingWrapper(net, opt, loss_scale)

    model_callback = ModelCallback()
    time_monitor_callback = TimeMonitor(data_size=dataset_size)
    callback = [model_callback, time_monitor_callback]
    model = Model(net)
    print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
    model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=True,
                sink_size=dataset_size)

    # These thresholds are meant to catch a wrong loss value, overflow state or loss_scale value.
    loss_value = np.array(model_callback.loss_list)
    expect_loss_value = [6850, 4250, 2750]
    print("loss value: {}".format(loss_value))
    assert len(loss_value) >= len(expect_loss_value), \
        "expected at least {} recorded losses, got {}".format(len(expect_loss_value), len(loss_value))
    for actual, expected in zip(loss_value, expect_loss_value):
        assert actual < expected

    # Timing checks skip the first two epochs (the first includes graph
    # compilation). Passing only requires one of epochs 2-4 to meet each budget.
    timed_epochs = slice(2, 5)

    epoch_mseconds = time_monitor_callback.epoch_mseconds_list[timed_epochs]
    expect_epoch_mseconds = 1250
    print("epoch mseconds: {}".format(epoch_mseconds))
    assert len(epoch_mseconds) == 3, "expected timings for epochs 2-4, got {}".format(epoch_mseconds)
    assert any(ms <= expect_epoch_mseconds for ms in epoch_mseconds)

    per_step_mseconds = time_monitor_callback.per_step_mseconds_list[timed_epochs]
    expect_per_step_mseconds = 130
    print("per step mseconds: {}".format(per_step_mseconds))
    assert len(per_step_mseconds) == 3, \
        "expected per-step timings for epochs 2-4, got {}".format(per_step_mseconds)
    assert any(ms <= expect_per_step_mseconds for ms in per_step_mseconds)

    print("yolov3 test case passed.")