@@ -21,6 +21,19 @@ from parallel_executor_test_base import TestParallelExecutorBase
 import unittest
 import math
 import os
+import numpy as np
+
+# FIXME(zcd): If the neural net has dropout_op, the outputs of ParallelExecutor
+# and Executor are different, because ParallelExecutor copies the dropout_op of
+# the neural net N times (N is the number of devices). This makes the random
+# numbers generated by ParallelExecutor and Executor differ. So, to compare
+# the losses of ParallelExecutor and Executor, we should remove the
+# dropout_op.
+remove_dropout = False
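
The FIXME above can be made concrete with a few lines of plain numpy: each device draws an independent dropout mask, so stitching per-device results together generally disagrees with a single-executor pass over the same activations. A minimal sketch; the seeds, the 2-device split, and all names are illustrative only, not part of the patch:

    import numpy as np

    activations = np.ones(8, dtype=np.float32)
    keep_prob = 0.8
    # Single executor: one random stream masks the whole batch.
    single = activations * (np.random.RandomState(0).rand(8) < keep_prob)
    # ParallelExecutor: the dropout_op is replicated per device, each copy
    # with its own random stream (2 devices x 4 elements here).
    masks = [np.random.RandomState(d).rand(4) < keep_prob for d in (1, 2)]
    parallel = np.concatenate(
        [activations[i * 4:(i + 1) * 4] * m for i, m in enumerate(masks)])
    print(single.sum(), parallel.sum())  # 7.0 vs. 8.0 with these seeds

Different sums mean different activations downstream, hence different losses, which is why the comparison tests set remove_dropout.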
+
+# FIXME(zcd): If the neural net has batch_norm, the outputs of ParallelExecutor
+# and Executor are different.
+remove_bn = False
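
batch_norm diverges for a related but non-random reason: each device normalizes with the mean and variance of its own sub-batch rather than of the full batch. A small numpy sketch; the 12x4 shape and the 4-way split are arbitrary, not part of the patch:

    import numpy as np

    np.random.seed(0)
    batch = np.random.randn(12, 4).astype(np.float32)
    # Executor: statistics over the full 12-sample batch.
    full_var = batch.var(axis=0)
    # ParallelExecutor: each of 4 devices sees only a 3-sample sub-batch.
    device_vars = [part.var(axis=0) for part in np.split(batch, 4)]
    # Law of total variance: within-device variance understates full-batch
    # variance whenever per-device means differ, so normalized outputs differ.
    print(full_var - np.mean(device_vars, axis=0))  # positive gap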
 
 
 def squeeze_excitation(input, num_channels, reduction_ratio):
@@ -53,7 +66,8 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
         groups=groups,
         act=None,
         bias_attr=False)
-    return fluid.layers.batch_norm(input=conv, act=act, momentum=0.1)
+    return conv if remove_bn else fluid.layers.batch_norm(
+        input=conv, act=act, momentum=0.1)
 
 
 def shortcut(input, ch_out, stride):
@@ -92,13 +106,14 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
     return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
 
 
-def SE_ResNeXt50Small(batch_size=2, use_feed=False):
-    assert not use_feed, "SE_ResNeXt doesn't support feed yet"
+batch_size = 12
+img_shape = [3, 224, 224]
 
-    img = fluid.layers.fill_constant(
-        shape=[batch_size, 3, 224, 224], dtype='float32', value=0.0)
-    label = fluid.layers.fill_constant(
-        shape=[batch_size, 1], dtype='int64', value=0.0)
+
+def SE_ResNeXt50Small(use_feed):
+    img = fluid.layers.data(name='image', shape=img_shape, dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
 
     conv = conv_bn_layer(
         input=img, num_filters=16, filter_size=3, stride=2, act='relu')
@@ -127,7 +142,8 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
     reshape = fluid.layers.reshape(
         x=conv, shape=[-1, shape[1], shape[2] * shape[3]])
     pool = fluid.layers.reduce_mean(input=reshape, dim=2)
-    dropout = fluid.layers.dropout(x=pool, dropout_prob=0.2)
+    dropout = pool if remove_dropout else fluid.layers.dropout(
+        x=pool, dropout_prob=0.2, seed=1)
     # Classifier layer:
     prediction = fluid.layers.fc(input=dropout, size=1000, act='softmax')
     loss = fluid.layers.cross_entropy(input=prediction, label=label)
@@ -135,75 +151,135 @@ def SE_ResNeXt50Small(batch_size=2, use_feed=False):
     return loss
 
 
-class TestResnet(TestParallelExecutorBase):
-    def check_resnet_convergence_with_learning_rate_decay(self,
-                                                          use_cuda=True,
-                                                          use_reduce=False,
-                                                          iter=20):
-        if use_cuda and not core.is_compiled_with_cuda():
-            return
-
-        os.environ['CPU_NUM'] = str(4)
-
-        def cosine_decay(learning_rate, step_each_epoch, epochs=120):
-            """
-            Applies cosine decay to the learning rate.
-            lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
-            """
-            global_step = _decay_step_counter()
-
-            with init_on_cpu():
-                epoch = ops.floor(global_step / step_each_epoch)
-                decayed_lr = learning_rate * \
-                    (ops.cos(epoch * (math.pi / epochs)) + 1)/2
-            return decayed_lr
-
-        def optimizer(learning_rate=0.01):
-            optimizer = fluid.optimizer.Momentum(
-                learning_rate=cosine_decay(
-                    learning_rate=learning_rate, step_each_epoch=2, epochs=1),
-                momentum=0.9,
-                regularization=fluid.regularizer.L2Decay(1e-4))
-            return optimizer
+def _cosine_decay(learning_rate, step_each_epoch, epochs=120):
+    """
+    Applies cosine decay to the learning rate.
+    lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
+    """
+    global_step = _decay_step_counter()
+
+    with init_on_cpu():
+        epoch = ops.floor(global_step / step_each_epoch)
+        decayed_lr = learning_rate * \
+            (ops.cos(epoch * (math.pi / epochs)) + 1)/2
+    return decayed_lr
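
For reference, _cosine_decay halves the shifted cosine so the rate falls from learning_rate toward 0 as `epoch` sweeps from 0 to `epochs`. The same arithmetic in plain Python, evaluated with the settings the tests pass in (step_each_epoch=2, epochs=1); base_lr=0.1 is an arbitrary illustrative choice:

    import math

    def cosine_lr(step, base_lr=0.1, step_each_epoch=2, epochs=1):
        # Integer epoch index, as ops.floor does in the graph version.
        epoch = step // step_each_epoch
        return base_lr * (math.cos(epoch * math.pi / epochs) + 1) / 2

    print([cosine_lr(s) for s in range(4)])  # [0.1, 0.1, 0.0, 0.0]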
+
+
+def _optimizer(learning_rate=0.01):
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=_cosine_decay(
+            learning_rate=learning_rate, step_each_epoch=2, epochs=1),
+        momentum=0.9,
+        regularization=fluid.regularizer.L2Decay(1e-4))
+    return optimizer
+
+
+class TestResnet(TestParallelExecutorBase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ['CPU_NUM'] = str(4)
+        global remove_dropout
+        global remove_bn
+        remove_dropout = False
+        remove_bn = False
+
+    def _init_data(self, batch_size=2, random=True):
+        np.random.seed(5)
+        if random:
+            img = np.random.random(
+                size=[batch_size] + img_shape).astype(np.float32)
+        else:
+            img = np.ones(shape=[batch_size] + img_shape, dtype='float32')
+        label = [np.random.randint(0, 999) for _ in range(batch_size)]
+        label = np.array(label).astype(np.int64).reshape(-1, 1)
+        return img, label
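
_init_data reseeds numpy on every call, so it always returns the same batch; each executor variant below therefore trains on identical inputs and the loss comparisons are like-for-like. A standalone sketch of that property (names here are illustrative, not part of the patch):

    import numpy as np

    def make_batch(batch_size=2, img_shape=(3, 224, 224)):
        np.random.seed(5)  # reseeding makes the draw reproducible
        img = np.random.random((batch_size,) + img_shape).astype(np.float32)
        label = np.random.randint(0, 999, (batch_size, 1)).astype(np.int64)
        return img, label

    a_img, a_label = make_batch()
    b_img, b_label = make_batch()
    assert np.array_equal(a_img, b_img) and np.array_equal(a_label, b_label)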
+
+    def _compare_reduce_and_allreduce(self,
+                                      model,
+                                      use_cuda,
+                                      iter=20,
+                                      delta2=1e-4):
+        if use_cuda and not core.is_compiled_with_cuda():
+            return
+
-        import functools
+        global remove_bn
+        remove_bn = True
+
-        batch_size = 2
+        img, label = self._init_data(batch_size=batch_size)
+        all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
+            model,
+            feed_dict={"image": img,
+                       "label": label},
+            iter=iter,
+            batch_size=batch_size,
+            use_cuda=use_cuda,
+            use_reduce=False,
+            optimizer=_optimizer)
+        reduce_first_loss, reduce_last_loss = self.check_network_convergence(
+            model,
+            feed_dict={"image": img,
+                       "label": label},
+            iter=iter,
+            batch_size=batch_size,
+            use_cuda=use_cuda,
+            use_reduce=True,
+            optimizer=_optimizer)
+
+        for loss in zip(all_reduce_first_loss, reduce_first_loss):
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+        for loss in zip(all_reduce_last_loss, reduce_last_loss):
+            self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
+
+    def _check_resnet_convergence(self,
+                                  model,
+                                  use_cuda=True,
+                                  use_reduce=False,
+                                  iter=20,
+                                  delta2=1e-6):
+        if use_cuda and not core.is_compiled_with_cuda():
+            return
+
+        global remove_dropout
+        global remove_bn
+        remove_dropout = True
+        remove_bn = True
+
+        img, label = self._init_data(batch_size=batch_size)
         single_first_loss, single_last_loss = self.check_network_convergence(
-            functools.partial(
-                SE_ResNeXt50Small, batch_size=batch_size),
+            model,
+            feed_dict={"image": img,
+                       "label": label},
             iter=iter,
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer,
+            optimizer=_optimizer,
             use_parallel_executor=False)
 
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
-            functools.partial(
-                SE_ResNeXt50Small, batch_size=batch_size),
+            model,
+            feed_dict={"image": img,
+                       "label": label},
             iter=iter,
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer)
+            optimizer=_optimizer)
 
-        for p_f in parallel_first_loss:
-            self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
-        for p_l in parallel_last_loss:
-            self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
+        self.assertAlmostEquals(
+            np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
+        self.assertAlmostEquals(
+            np.mean(parallel_last_loss), single_last_loss[0], delta=delta2)
 
     def test_seresnext_with_learning_rate_decay(self):
-        self.check_resnet_convergence_with_learning_rate_decay(True, False)
-        self.check_resnet_convergence_with_learning_rate_decay(
-            False, False, iter=5)
+        self._check_resnet_convergence(model=SE_ResNeXt50Small, use_cuda=True)
+        self._check_resnet_convergence(
+            model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3)
 
-    def test_seresnext_with_new_strategy_with_learning_rate_decay(self):
-        self.check_resnet_convergence_with_learning_rate_decay(True, True)
-        self.check_resnet_convergence_with_learning_rate_decay(
-            False, True, iter=5)
+    def test_seresnext_with_new_strategy(self):
+        # self._compare_reduce_and_allreduce(
+        #     model=SE_ResNeXt50Small, use_cuda=True)
+        self._compare_reduce_and_allreduce(
+            model=SE_ResNeXt50Small, use_cuda=False, iter=5, delta2=1e-2)
 
 
 if __name__ == '__main__':