[Dy2stat] Add Yolov3 as Unit Test (#24879)

Add YOLOv3 as a ProgramTranslator unit test. The YOLOv3 code is adapted from PaddlePaddle/models/dygraph/yolov3.
revert-24981-add_device_attr_for_regulization
Huihuang Zheng 5 years ago committed by GitHub
parent 29de0d97a5
commit 5a5497a5e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,184 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from paddle.fluid.dygraph.base import to_variable
class ConvBNLayer(fluid.dygraph.Layer):
    """Convolution followed by batch normalization and an optional leaky ReLU.

    Args:
        ch_in: number of input channels handed to ``Conv2D``.
        ch_out: number of convolution filters; also the channel count
            handed to ``BatchNorm``.
        filter_size: convolution kernel size.
        stride: convolution stride.
        groups: convolution groups.
        padding: convolution padding.
        act: activation selector. Only the value ``'leaky'`` applies a
            leaky ReLU (alpha 0.1); any other value returns the
            batch-normalized output unchanged.
        is_test: forwarded to ``BatchNorm``.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=1,
                 groups=1,
                 padding=0,
                 act="leaky",
                 is_test=True):
        super(ConvBNLayer, self).__init__()

        # Convolution weights are drawn from Normal(0, 0.02); the layer is
        # built with bias_attr=False and act=None, so the activation is
        # applied explicitly in forward().
        conv_weight_attr = ParamAttr(
            initializer=fluid.initializer.Normal(0., 0.02))
        self.conv = Conv2D(
            num_channels=ch_in,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            param_attr=conv_weight_attr,
            bias_attr=False,
            act=None)

        # Both batch-norm parameters use an L2Decay regularizer with a zero
        # coefficient.
        bn_weight_attr = ParamAttr(
            initializer=fluid.initializer.Normal(0., 0.02),
            regularizer=L2Decay(0.))
        bn_bias_attr = ParamAttr(
            initializer=fluid.initializer.Constant(0.0),
            regularizer=L2Decay(0.))
        self.batch_norm = BatchNorm(
            num_channels=ch_out,
            is_test=is_test,
            param_attr=bn_weight_attr,
            bias_attr=bn_bias_attr)

        self.act = act

    def forward(self, inputs):
        """Apply conv -> batch norm -> (optional) leaky ReLU to ``inputs``."""
        features = self.conv(inputs)
        features = self.batch_norm(features)
        if self.act == 'leaky':
            features = fluid.layers.leaky_relu(x=features, alpha=0.1)
        return features
class DownSample(fluid.dygraph.Layer):
    """A single ``ConvBNLayer`` configured by default with a 3x3 kernel,
    stride 2 and padding 1.

    Args:
        ch_in: input channel count.
        ch_out: output channel count; also stored as ``self.ch_out``.
        filter_size: kernel size forwarded to ``ConvBNLayer``.
        stride: stride forwarded to ``ConvBNLayer``.
        padding: padding forwarded to ``ConvBNLayer``.
        is_test: forwarded to ``ConvBNLayer``.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=2,
                 padding=1,
                 is_test=True):
        super(DownSample, self).__init__()
        self.ch_out = ch_out
        self.conv_bn_layer = ConvBNLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            is_test=is_test)

    def forward(self, inputs):
        """Return the wrapped ``ConvBNLayer`` applied to ``inputs``."""
        reduced = self.conv_bn_layer(inputs)
        return reduced
class BasicBlock(fluid.dygraph.Layer):
    """Residual block: a 1x1 ``ConvBNLayer`` to ``ch_out`` channels, a 3x3
    ``ConvBNLayer`` to ``ch_out * 2`` channels, then an element-wise add
    with the block input.

    Args:
        ch_in: input channel count.
            NOTE(review): the add in forward() presumably needs
            ``ch_in == ch_out * 2`` -- confirm against callers.
        ch_out: channel count after the 1x1 convolution.
        is_test: forwarded to both ``ConvBNLayer`` instances.
    """

    def __init__(self, ch_in, ch_out, is_test=True):
        super(BasicBlock, self).__init__()
        # 1x1 convolution, stride 1, no padding.
        self.conv1 = ConvBNLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=1,
            stride=1,
            padding=0,
            is_test=is_test)
        # 3x3 convolution, stride 1, padding 1, doubling the channel count.
        self.conv2 = ConvBNLayer(
            ch_in=ch_out,
            ch_out=ch_out * 2,
            filter_size=3,
            stride=1,
            padding=1,
            is_test=is_test)

    def forward(self, inputs):
        """Return ``inputs + conv2(conv1(inputs))``."""
        narrowed = self.conv1(inputs)
        widened = self.conv2(narrowed)
        summed = fluid.layers.elementwise_add(x=inputs, y=widened, act=None)
        return summed
class LayerWarp(fluid.dygraph.Layer):
    """A chain of ``count`` ``BasicBlock`` layers applied one after another.

    Args:
        ch_in: input channel count of the first block.
        ch_out: ``ch_out`` of every block; the remaining blocks take
            ``ch_out * 2`` input channels.
        count: total number of blocks in the chain (the first block is
            always created, so values below 1 still yield one block).
        is_test: forwarded to every ``BasicBlock``.
    """

    def __init__(self, ch_in, ch_out, count, is_test=True):
        super(LayerWarp, self).__init__()
        self.ch_out = ch_out
        self.basicblock0 = BasicBlock(ch_in, ch_out, is_test=is_test)

        # Blocks 1..count-1 are registered through add_sublayer and also
        # kept in a plain list so forward() can walk them in order.
        self.res_out_list = []
        for block_idx in range(1, count):
            block = BasicBlock(ch_out * 2, ch_out, is_test=is_test)
            registered = self.add_sublayer("basic_block_%d" % (block_idx),
                                           block)
            self.res_out_list.append(registered)

    def forward(self, inputs):
        """Run ``inputs`` through the first block and then every extra block."""
        hidden = self.basicblock0(inputs)
        for block in self.res_out_list:
            hidden = block(hidden)
        return hidden
# Per-stage block counts, keyed by network depth. Each entry is the `count`
# handed to one LayerWarp stage by DarkNet53_conv_body.
DarkNet_cfg = {53: [1, 2, 8, 8, 4]}
class DarkNet53_conv_body(fluid.dygraph.Layer):
    """DarkNet-53 convolutional body built from ``ConvBNLayer``,
    ``DownSample`` and ``LayerWarp`` stages.

    Args:
        ch_in: channel count of the input fed to the first convolution.
        is_test: forwarded to every sublayer.
    """

    def __init__(self, ch_in=3, is_test=True):
        super(DarkNet53_conv_body, self).__init__()
        # Block counts of the (at most five) stages.
        self.stages = DarkNet_cfg[53][0:5]

        # Stem: 3x3 convolution to 32 channels, then a DownSample to 64.
        self.conv0 = ConvBNLayer(
            ch_in=ch_in,
            ch_out=32,
            filter_size=3,
            stride=1,
            padding=1,
            is_test=is_test)
        self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2, is_test=is_test)

        self.darknet53_conv_block_list = []
        self.downsample_list = []

        # Input channel count of each stage; stage i uses ch_out = 32 * 2**i.
        stage_ch_in = [64, 128, 256, 512, 1024]
        for i, block_count in enumerate(self.stages):
            stage_layer = LayerWarp(
                int(stage_ch_in[i]), 32 * (2**i), block_count, is_test=is_test)
            self.darknet53_conv_block_list.append(
                self.add_sublayer("stage_%d" % (i), stage_layer))

        # One DownSample between each pair of consecutive stages; they are
        # registered only after all stages, as in the original ordering.
        for i in range(len(self.stages) - 1):
            between_stages = DownSample(
                ch_in=32 * (2**(i + 1)),
                ch_out=32 * (2**(i + 2)),
                is_test=is_test)
            self.downsample_list.append(
                self.add_sublayer("stage_%d_downsample" % i, between_stages))

    def forward(self, inputs):
        """Return the outputs of the last three stages, last stage first."""
        feat = self.conv0(inputs)
        feat = self.downsample0(feat)
        stage_outputs = []
        for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
            feat = conv_block_i(feat)
            stage_outputs.append(feat)
            # No DownSample follows the final stage.
            if i < len(self.stages) - 1:
                feat = self.downsample_list[i](feat)
        return stage_outputs[-1:-4:-1]

@ -0,0 +1,174 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import time
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph import to_variable
from yolov3 import cfg, YOLOv3
# Seed Python's and NumPy's global random generators at import time.
random.seed(0)
np.random.seed(0)
class SmoothedValue(object):
    """Accumulate a series of values and report their global running mean.

    Each added value is first reduced with ``np.mean`` and so may be a
    scalar or an array. No windowing is performed: the reported mean always
    covers every value added so far.
    """

    def __init__(self):
        # Sum of the per-call means and the number of calls to add_value().
        self.loss_sum = 0.0
        self.iter_cnt = 0

    def add_value(self, value):
        """Add ``np.mean(value)`` to the running sum and count one sample."""
        self.loss_sum += np.mean(value)
        self.iter_cnt += 1

    def get_mean_value(self):
        """Return the mean of all added values, or 0.0 if none were added."""
        # Guard against ZeroDivisionError when queried before any add_value().
        if self.iter_cnt == 0:
            return 0.0
        return self.loss_sum / self.iter_cnt
class FakeDataReader(object):
    """Pre-generate ``cfg.max_iter`` batches of random samples.

    Every sample is a list ``[img, gt_boxes, gt_labels, gt_scores]`` drawn
    from NumPy's global random generator, so the data depends on the seed
    state at construction time.
    """

    def __init__(self):
        side = cfg.input_size
        quarter = side / 4
        half = side / 2

        self.total_iter = cfg.max_iter
        self.generator_out = []
        for _batch_idx in range(self.total_iter):
            samples = []
            for _sample_idx in range(cfg.batch_size):
                # Draw order (image, box, label) is kept fixed so that the
                # seeded sequence stays reproducible.
                img = np.random.normal(0.485, 0.229, [3, side, side])
                first_pair = np.random.randint(
                    low=quarter, high=half, size=[1, 2])
                # Second coordinate pair is the first one shifted by a
                # quarter of the input size.
                second_pair = first_pair + quarter
                gt_boxes = np.concatenate((first_pair, second_pair), axis=1)
                gt_labels = np.random.randint(
                    low=0, high=cfg.class_num, size=[1])
                gt_scores = np.zeros([1])
                samples.append([img, gt_boxes, gt_labels, gt_scores])
            self.generator_out.append(samples)

    def reader(self):
        """Return a zero-argument generator function yielding the stored
        batches in order."""

        def generator():
            for batch_idx in range(self.total_iter):
                yield self.generator_out[batch_idx]

        return generator
# Module-level FakeDataReader instance, built (and its random data drawn) at
# import time.
# NOTE(review): nothing in the visible code reads this name -- train() builds
# its own FakeDataReader -- confirm whether it is still needed.
fake_data_reader = FakeDataReader()
def train(to_static):
    """Train YOLOv3 on fake data and return the running mean loss.

    Args:
        to_static: passed to ``ProgramTranslator().enable`` to switch
            dygraph-to-static translation on or off.

    Returns:
        np.ndarray with one entry per iteration: the mean of all losses
        observed up to and including that iteration.
    """
    program_translator = ProgramTranslator()
    program_translator.enable(to_static)

    # Reseed before the FakeDataReader below is built so that every call
    # trains on the same generated data.
    random.seed(0)
    np.random.seed(0)

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        fluid.default_startup_program().random_seed = 1000
        fluid.default_main_program().random_seed = 1000
        model = YOLOv3(3, is_train=True)

        # Piecewise-decayed learning rate: one value per interval between
        # boundaries, each scaled by a further factor of gamma.
        boundaries = cfg.lr_steps
        gamma = cfg.lr_gamma
        step_num = len(cfg.lr_steps)
        learning_rate = cfg.learning_rate
        values = [learning_rate * (gamma**i) for i in range(step_num + 1)]

        lr = fluid.dygraph.PiecewiseDecay(
            boundaries=boundaries, values=values, begin=0)

        # Linear warm-up from 0 to the base learning rate.
        lr = fluid.layers.linear_lr_warmup(
            learning_rate=lr,
            warmup_steps=cfg.warm_up_iter,
            start_lr=0.0,
            end_lr=cfg.learning_rate, )

        optimizer = fluid.optimizer.Momentum(
            learning_rate=lr,
            regularization=fluid.regularizer.L2Decay(cfg.weight_decay),
            momentum=cfg.momentum,
            parameter_list=model.parameters())

        start_time = time.time()

        train_reader = FakeDataReader().reader()
        smoothed_loss = SmoothedValue()
        ret = []
        for iter_id, data in enumerate(train_reader()):
            # The reported time is the gap between the starts of two
            # consecutive iterations (set-up time for iteration 0).
            prev_start_time = start_time
            start_time = time.time()

            # Each sample is [img, gt_boxes, gt_labels, gt_scores]; stack
            # every field across the batch.
            img = np.array([x[0] for x in data]).astype('float32')
            img = to_variable(img)

            gt_box = np.array([x[1] for x in data]).astype('float32')
            gt_box = to_variable(gt_box)

            gt_label = np.array([x[2] for x in data]).astype('int32')
            gt_label = to_variable(gt_label)

            gt_score = np.array([x[3] for x in data]).astype('float32')
            gt_score = to_variable(gt_score)

            loss = model(img, gt_box, gt_label, gt_score, None, None)
            smoothed_loss.add_value(np.mean(loss.numpy()))
            mean_loss = smoothed_loss.get_mean_value()

            print("Iter {:d}, loss {:.6f}, time {:.5f}".format(
                iter_id, mean_loss, start_time - prev_start_time))
            ret.append(mean_loss)

            loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()

        return np.array(ret)
class TestYolov3(unittest.TestCase):
    """Check that YOLOv3 training gives matching losses in dygraph mode and
    after dygraph-to-static translation."""

    def test_dygraph_static_same_loss(self):
        """Train once per mode and compare the loss curves with np.allclose."""
        dygraph_loss = train(to_static=False)
        static_loss = train(to_static=True)

        losses_match = np.allclose(dygraph_loss, static_loss)
        failure_msg = "dygraph_loss: {} \nstatic_loss: {}".format(
            dygraph_loss, static_loss)
        self.assertTrue(losses_match, msg=failure_msg)
# Run the unit tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading…
Cancel
Save