|
|
|
@ -205,7 +205,8 @@ class TestParallelExecutorBase(unittest.TestCase):
|
|
|
|
|
allow_op_delay=False,
|
|
|
|
|
feed_dict=None,
|
|
|
|
|
seed=None,
|
|
|
|
|
use_parallel_executor=True):
|
|
|
|
|
use_parallel_executor=True,
|
|
|
|
|
use_nccl_allreduce=True):
|
|
|
|
|
def run_executor(exe, feed, fetch_list, program=None):
|
|
|
|
|
if isinstance(exe, fluid.ParallelExecutor):
|
|
|
|
|
res = exe.run(fetch_list=fetch_list, feed=feed)
|
|
|
|
@ -234,7 +235,10 @@ class TestParallelExecutorBase(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
if use_parallel_executor:
|
|
|
|
|
exe = fluid.ParallelExecutor(
|
|
|
|
|
True, loss_name=loss.name, allow_op_delay=allow_op_delay)
|
|
|
|
|
True,
|
|
|
|
|
loss_name=loss.name,
|
|
|
|
|
allow_op_delay=allow_op_delay,
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
else:
|
|
|
|
|
exe = fluid.Executor(place=place)
|
|
|
|
|
|
|
|
|
@ -280,17 +284,25 @@ class TestMNIST(TestParallelExecutorBase):
|
|
|
|
|
fluid.recordio_writer.convert_reader_to_recordio_file(
|
|
|
|
|
'./mnist.recordio', reader, feeder)
|
|
|
|
|
|
|
|
|
|
def test_simple_fc(self):
|
|
|
|
|
def check_simple_fc_convergence(self, use_nccl_allreduce=True):
|
|
|
|
|
self.check_network_convergence(simple_fc_net)
|
|
|
|
|
self.check_network_convergence(simple_fc_net, allow_op_delay=True)
|
|
|
|
|
|
|
|
|
|
img = numpy.zeros(shape=[32, 784], dtype='float32')
|
|
|
|
|
label = numpy.ones(shape=[32, 1], dtype='int64')
|
|
|
|
|
self.check_network_convergence(
|
|
|
|
|
simple_fc_net, feed_dict={"image": img,
|
|
|
|
|
"label": label})
|
|
|
|
|
simple_fc_net,
|
|
|
|
|
feed_dict={"image": img,
|
|
|
|
|
"label": label},
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
def test_simple_fc_with_nccl_allreduce(self):
|
|
|
|
|
self.check_simple_fc_convergence(True)
|
|
|
|
|
|
|
|
|
|
def test_simple_fc_parallel_accuracy(self):
|
|
|
|
|
def test_simple_fc_with_reduce_op(self):
|
|
|
|
|
self.check_simple_fc_convergence(False)
|
|
|
|
|
|
|
|
|
|
def check_simple_fc_parallel_accuracy(self, use_nccl_allreduce=True):
|
|
|
|
|
img = numpy.zeros(shape=[32, 784], dtype='float32')
|
|
|
|
|
label = numpy.ones(shape=[32, 1], dtype='int64')
|
|
|
|
|
single_first_loss, single_last_loss = self.check_network_convergence(
|
|
|
|
@ -304,20 +316,35 @@ class TestMNIST(TestParallelExecutorBase):
|
|
|
|
|
seed=1000,
|
|
|
|
|
feed_dict={"image": img,
|
|
|
|
|
"label": label},
|
|
|
|
|
use_parallel_executor=True)
|
|
|
|
|
use_parallel_executor=True,
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
for p_f in parallel_first_loss:
|
|
|
|
|
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
|
|
|
|
|
for p_l in parallel_last_loss:
|
|
|
|
|
self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
|
|
|
|
|
|
|
|
|
|
def test_batchnorm_fc(self):
|
|
|
|
|
def test_simple_fc_parallel_accuracy_with_nccl_allreduce(self):
|
|
|
|
|
self.check_simple_fc_parallel_accuracy(True)
|
|
|
|
|
|
|
|
|
|
def test_simple_fc_parallel_accuracy_with_reduce_op(self):
|
|
|
|
|
self.check_simple_fc_parallel_accuracy(False)
|
|
|
|
|
|
|
|
|
|
def check_batchnorm_fc_convergence(self, use_nccl_allreduce):
|
|
|
|
|
self.check_network_convergence(fc_with_batchnorm)
|
|
|
|
|
img = numpy.zeros(shape=[32, 784], dtype='float32')
|
|
|
|
|
label = numpy.ones(shape=[32, 1], dtype='int64')
|
|
|
|
|
self.check_network_convergence(
|
|
|
|
|
fc_with_batchnorm, feed_dict={"image": img,
|
|
|
|
|
"label": label})
|
|
|
|
|
fc_with_batchnorm,
|
|
|
|
|
feed_dict={"image": img,
|
|
|
|
|
"label": label},
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
def test_batchnorm_fc_with_nccl_allreduce(self):
|
|
|
|
|
self.check_batchnorm_fc_convergence(True)
|
|
|
|
|
|
|
|
|
|
def test_batchnorm_fc_with_reduce_op(self):
|
|
|
|
|
self.check_batchnorm_fc_convergence(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestResnet(TestParallelExecutorBase):
|
|
|
|
@ -339,14 +366,21 @@ class TestResnet(TestParallelExecutorBase):
|
|
|
|
|
# fluid.recordio_writer.convert_reader_to_recordio_file(
|
|
|
|
|
# "./flowers.recordio", reader, feeder, compressor=fluid.core.RecordIOWriter.Compressor.NoCompress)
|
|
|
|
|
|
|
|
|
|
def test_resnet(self):
|
|
|
|
|
def check_resnet_convergence(self, use_nccl_allreduce):
|
|
|
|
|
import functools
|
|
|
|
|
batch_size = 2
|
|
|
|
|
self.check_network_convergence(
|
|
|
|
|
functools.partial(
|
|
|
|
|
SE_ResNeXt50Small, batch_size=batch_size),
|
|
|
|
|
iter=20,
|
|
|
|
|
batch_size=batch_size)
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
def test_resnet_with_nccl_allreduce(self):
|
|
|
|
|
self.check_resnet_convergence(True)
|
|
|
|
|
|
|
|
|
|
def test_resnet_with_reduce_op(self):
|
|
|
|
|
self.check_resnet_convergence(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelHyperParams(object):
|
|
|
|
@ -510,7 +544,7 @@ class TestTransformer(TestParallelExecutorBase):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorTestingDuringTraining(unittest.TestCase):
|
|
|
|
|
def test_parallel_testing(self):
|
|
|
|
|
def check_network_convergence(self, use_nccl_allreduce):
|
|
|
|
|
main = fluid.Program()
|
|
|
|
|
startup = fluid.Program()
|
|
|
|
|
with fluid.program_guard(main, startup):
|
|
|
|
@ -531,12 +565,16 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
|
|
|
|
|
feed_dict = {'image': image, 'label': label}
|
|
|
|
|
|
|
|
|
|
train_exe = fluid.ParallelExecutor(
|
|
|
|
|
use_cuda=True, loss_name=loss.name, main_program=main)
|
|
|
|
|
use_cuda=True,
|
|
|
|
|
loss_name=loss.name,
|
|
|
|
|
main_program=main,
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
test_exe = fluid.ParallelExecutor(
|
|
|
|
|
use_cuda=True,
|
|
|
|
|
main_program=test_program,
|
|
|
|
|
share_vars_from=train_exe)
|
|
|
|
|
share_vars_from=train_exe,
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
for i in xrange(5):
|
|
|
|
|
test_loss, = test_exe.run([loss.name], feed=feed_dict)
|
|
|
|
@ -550,6 +588,12 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
|
|
|
|
|
"Train loss: " + str(train_loss) + "\n Test loss:" +
|
|
|
|
|
str(test_loss))
|
|
|
|
|
|
|
|
|
|
def test_parallel_testing_with_nccl_allreduce(self):
|
|
|
|
|
self.check_network_convergence(use_nccl_allreduce=True)
|
|
|
|
|
|
|
|
|
|
def test_parallel_testing_with_reduce_op(self):
|
|
|
|
|
self.check_network_convergence(use_nccl_allreduce=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddle.dataset.conll05 as conll05
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
@ -568,21 +612,26 @@ embedding_name = 'emb'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
|
|
|
|
|
**ignored):
|
|
|
|
|
is_sparse, use_nccl_allreduce, **ignored):
|
|
|
|
|
# 8 features
|
|
|
|
|
predicate_embedding = fluid.layers.embedding(
|
|
|
|
|
input=predicate,
|
|
|
|
|
is_sparse=is_sparse,
|
|
|
|
|
size=[pred_dict_len, word_dim],
|
|
|
|
|
dtype='float32',
|
|
|
|
|
param_attr='vemb')
|
|
|
|
|
|
|
|
|
|
mark_embedding = fluid.layers.embedding(
|
|
|
|
|
input=mark, size=[mark_dict_len, mark_dim], dtype='float32')
|
|
|
|
|
input=mark,
|
|
|
|
|
is_sparse=is_sparse,
|
|
|
|
|
size=[mark_dict_len, mark_dim],
|
|
|
|
|
dtype='float32')
|
|
|
|
|
|
|
|
|
|
word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
|
|
|
|
|
emb_layers = [
|
|
|
|
|
fluid.layers.embedding(
|
|
|
|
|
size=[word_dict_len, word_dim],
|
|
|
|
|
is_sparse=is_sparse,
|
|
|
|
|
input=x,
|
|
|
|
|
param_attr=fluid.ParamAttr(
|
|
|
|
|
name=embedding_name, trainable=False)) for x in word_input
|
|
|
|
@ -632,7 +681,7 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCRFModel(unittest.TestCase):
|
|
|
|
|
def test_all(self):
|
|
|
|
|
def check_network_convergence(self, is_sparse, use_nccl_allreduce):
|
|
|
|
|
main = fluid.Program()
|
|
|
|
|
startup = fluid.Program()
|
|
|
|
|
with fluid.program_guard(main, startup):
|
|
|
|
@ -652,6 +701,7 @@ class TestCRFModel(unittest.TestCase):
|
|
|
|
|
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
|
|
|
|
|
mark = fluid.layers.data(
|
|
|
|
|
name='mark_data', shape=[1], dtype='int64', lod_level=1)
|
|
|
|
|
|
|
|
|
|
feature_out = db_lstm(**locals())
|
|
|
|
|
target = fluid.layers.data(
|
|
|
|
|
name='target', shape=[1], dtype='int64', lod_level=1)
|
|
|
|
@ -679,7 +729,10 @@ class TestCRFModel(unittest.TestCase):
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
exe.run(startup)
|
|
|
|
|
|
|
|
|
|
pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
|
|
|
|
|
pe = fluid.ParallelExecutor(
|
|
|
|
|
use_cuda=True,
|
|
|
|
|
loss_name=avg_cost.name,
|
|
|
|
|
use_nccl_allreduce=use_nccl_allreduce)
|
|
|
|
|
|
|
|
|
|
feeder = fluid.DataFeeder(
|
|
|
|
|
feed_list=[
|
|
|
|
@ -694,3 +747,13 @@ class TestCRFModel(unittest.TestCase):
|
|
|
|
|
print map(numpy.array,
|
|
|
|
|
pe.run(feed=feeder.feed(cur_batch),
|
|
|
|
|
fetch_list=[avg_cost.name]))[0]
|
|
|
|
|
|
|
|
|
|
def test_update_sparse_parameter(self):
|
|
|
|
|
self.check_network_convergence(is_sparse=True, use_nccl_allreduce=False)
|
|
|
|
|
|
|
|
|
|
def test_update_dense_parameter_with_nccl_allreduce(self):
|
|
|
|
|
self.check_network_convergence(is_sparse=False, use_nccl_allreduce=True)
|
|
|
|
|
|
|
|
|
|
def test_update_dense_parameter_with_reduce_op(self):
|
|
|
|
|
self.check_network_convergence(
|
|
|
|
|
is_sparse=False, use_nccl_allreduce=False)
|
|
|
|
|