From 8a521c0b4da5118098d57e34c1c4150e276f140a Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Mon, 16 Jul 2018 17:44:14 +0800
Subject: [PATCH 1/4] Remove buggy get_test_program and refine c++ reader demo

---
 python/paddle/fluid/io.py                     |  98 ---------------
 .../convert_data_to_recordio.py               |   8 +-
 .../tests/demo/text_classification/train.py   | 115 +++++++++---------
 3 files changed, 62 insertions(+), 159 deletions(-)

diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py
index 0eb1194e27..32368d3c0c 100644
--- a/python/paddle/fluid/io.py
+++ b/python/paddle/fluid/io.py
@@ -789,101 +789,3 @@ def get_parameter_value_by_name(name, executor, program=None):
         program = default_main_program()
     var = program.global_block().var(name)
     return get_parameter_value(var, executor)
-
-
-def get_test_program(filelist, program=None, startup_program=None):
-    """
-    Transpile current train program to a program to read test dataset
-    if the program is using reader ops like "open_files_op".
-    """
-
-    def _copy_reader_var_(block, var, new_name=None):
-        if new_name == None:
-            new_name = var.name
-        new_var = block.create_var(
-            name=str(new_name), type=core.VarDesc.VarType.READER)
-        new_var.desc.set_shapes(var.desc.shapes())
-        new_var.desc.set_dtypes(var.desc.dtypes())
-        new_var.persistable = True
-        return new_var
-
-    def _get_test_reader_name(train_reader_name):
-        return train_reader_name + "_test"
-
-    def _is_reader_op(op):
-        block = op.block
-        if "Out" in op.output_names:
-            reader_out = block.vars[op.output("Out")[0]]
-            if reader_out.type == core.VarDesc.VarType.READER:
-                return True
-        return False
-
-    if program == None:
-        program = default_main_program()
-    if startup_program == None:
-        startup_program = default_startup_program()
-    startup_block = startup_program.global_block()
-
-    # 1. find out the orignal reader var name
-    startup_reader_op_list = []
-
-    for op in startup_block.ops:
-        if _is_reader_op(op):
-            startup_reader_op_list.append(op)
-
-    if len(startup_reader_op_list) == 0:
-        return program
-
-    root_reader_op = startup_reader_op_list[0]
-    train_test_reader_map = {}
-    # 2. add operators to startup to read open and read test data files
-    for op in startup_reader_op_list:
-        assert (len(op.output("Out")) == 1)
-        train_reader_name = op.output("Out")[0]
-        train_reader = startup_block.vars[train_reader_name]
-        test_reader = _copy_reader_var_(
-            startup_block,
-            train_reader,
-            new_name=_get_test_reader_name(train_reader_name))
-        train_test_reader_map[train_reader.name] = test_reader
-
-        test_op_inputs = {}
-        for name in op.input_names:
-            train_arg_names = op.input(name)
-            test_arg_vars = []
-            for arg_name in train_arg_names:
-                arg_var = train_test_reader_map[
-                    arg_name] if name == "UnderlyingReader" else startup_block.vars[
-                        arg_name]
-                test_arg_vars.append(arg_var)
-            test_op_inputs[name] = test_arg_vars
-
-        test_op = startup_block.append_op(
-            type=op.type,
-            inputs=test_op_inputs,
-            outputs={'Out': [test_reader]},
-            attrs=op.attrs)
-        # root reader op's filelist attr for read test files
-        if op.type == root_reader_op.type:
-            test_op.set_attr("file_names", filelist)
-        if op.type == "create_multi_pass_reader":
-            test_op.set_attr("pass_num", 1)
-
-    # 3. rename reader vars in inference program to different name
-    # to avoid read from train data.
-    main_block = program.global_block()
-    for var in main_block.vars.values():
-        if var.type == core.VarDesc.VarType.READER:
-            main_block.rename_var(
-                str(var.name), str(_get_test_reader_name(var.name)))
-
-    for op in main_block.ops:
-        if op.type == root_reader_op.type:
-            test_op.set_attr("file_names", filelist)
-        if op.type == "create_multi_pass_reader":
-            test_op.set_attr("pass_num", 1)
-
-    startup_program.sync_with_cpp()
-    program.sync_with_cpp()
-
-    return program
diff --git a/python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py b/python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py
index 9425d472a4..2dd8f352f7 100644
--- a/python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py
+++ b/python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py
@@ -31,8 +31,12 @@ def load_vocab(filename):
 
 
 # load word dict with paddle inner function
-word_dict = load_vocab(sys.argv[1])
-word_dict["<unk>"] = len(word_dict)
+if len(sys.argv) > 1:
+    word_dict = load_vocab(sys.argv[1])
+    word_dict["<unk>"] = len(word_dict)
+else:
+    word_dict = paddle.dataset.imdb.word_dict()
+
 print "Dict dim = ", len(word_dict)
 
 # input text data
diff --git a/python/paddle/fluid/tests/demo/text_classification/train.py b/python/paddle/fluid/tests/demo/text_classification/train.py
index e408684c6e..9e930b67a4 100644
--- a/python/paddle/fluid/tests/demo/text_classification/train.py
+++ b/python/paddle/fluid/tests/demo/text_classification/train.py
@@ -19,7 +19,7 @@ import sys
 TRAIN_FILES = ['train.recordio']
 TEST_FILES = ['test.recordio']
 
-DICT_DIM = 89528
+DICT_DIM = 5147
 
 # embedding dim
 emb_dim = 128
@@ -33,33 +33,24 @@ hid_dim2 = 96
 # class num
 class_dim = 2
 
+# epoch num
+epoch_num = 10
 
-def network_cfg(is_train, pass_num=100):
-    with fluid.unique_name.guard():
-        train_file_obj = fluid.layers.open_files(
-            filenames=TRAIN_FILES,
-            pass_num=pass_num,
-            shapes=[[-1, 1], [-1, 1]],
-            lod_levels=[1, 0],
-            dtypes=['int64', 'int64'],
-            thread_num=1)
-
-        test_file_obj = fluid.layers.open_files(
-            filenames=TEST_FILES,
-            pass_num=1,
-            shapes=[[-1, 1], [-1, 1]],
-            lod_levels=[1, 0],
-            dtypes=['int64', 'int64'],
-            thread_num=1)
-
-        if is_train:
-            file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)
-        else:
-            file_obj = test_file_obj
+def build_program(is_train):
+    file_obj_handle = fluid.layers.io.open_files(
+        filenames=TRAIN_FILES if is_train else TEST_FILES,
+        shapes=[[-1, 1], [-1, 1]],
+        lod_levels=[1, 0],
+        dtypes=['int64', 'int64'],
+        thread_num=1)
+    if is_train:
+        file_obj = fluid.layers.io.shuffle(file_obj_handle, buffer_size=1000)
+    else:
+        file_obj = file_obj_handle
+    file_obj = fluid.layers.io.double_buffer(file_obj)
 
-        file_obj = fluid.layers.double_buffer(
-            file_obj,
-            name="train_double_buffer" if is_train else 'test_double_buffer')
+    with fluid.unique_name.guard():
 
         data, label = fluid.layers.read_file(file_obj)
@@ -90,58 +81,64 @@
 
         if is_train:
             # SGD optimizer
-            sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.01)
+            sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.001)
             sgd_optimizer.minimize(avg_cost)
 
-        return {
-            'loss': avg_cost,
-            'log': [avg_cost, acc],
-            'file': train_file_obj if is_train else test_file_obj
-        }
+    return {'loss': avg_cost, 'log': [avg_cost, acc], 'file': file_obj_handle}
 
 
 def main():
     train = fluid.Program()
     startup = fluid.Program()
+    test = fluid.Program()
 
     with fluid.program_guard(train, startup):
-        train_args = network_cfg(is_train=True)
-
-    test = fluid.Program()
+        train_args = build_program(is_train=True)
 
-    with fluid.program_guard(test, fluid.Program()):
-        test_args = network_cfg(is_train=False)
+    with fluid.program_guard(test, startup):
+        test_args = build_program(is_train=False)
 
+    use_cuda = fluid.core.is_compiled_with_cuda()
     # startup
-    place = fluid.CUDAPlace(0)
+    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
     exe = fluid.Executor(place=place)
     exe.run(startup)
 
     train_exe = fluid.ParallelExecutor(
-        use_cuda=True, loss_name=train_args['loss'].name, main_program=train)
+        use_cuda=use_cuda,
+        loss_name=train_args['loss'].name,
+        main_program=train)
+    test_exe = fluid.ParallelExecutor(
+        use_cuda=use_cuda, main_program=test, share_vars_from=train_exe)
 
     fetch_var_list = [var.name for var in train_args['log']]
-    for i in xrange(sys.maxint):
-        result = map(numpy.array,
-                     train_exe.run(fetch_list=fetch_var_list
-                                   if i % 1000 == 0 else []))
-        if len(result) != 0:
-            print 'Train: ', result
-
-        if i % 1000 == 0:
-            test_exe = fluid.ParallelExecutor(
-                use_cuda=True, main_program=test, share_vars_from=train_exe)
-            loss = []
-            acc = []
-            try:
-                while True:
-                    loss_np, acc_np = map(
-                        numpy.array, test_exe.run(fetch_list=fetch_var_list))
-                    loss.append(loss_np[0])
-                    acc.append(acc_np[0])
-            except:
-                test_args['file'].reset()
-            print 'TEST: ', numpy.mean(loss), numpy.mean(acc)
+    for epoch_id in range(epoch_num):
+        # train
+        try:
+            batch_id = 0
+            while True:
+                result = map(numpy.array,
+                             train_exe.run(fetch_list=fetch_var_list
+                                           if batch_id % 10 == 0 else []))
+                if len(result) != 0:
+                    print 'Train loss: ', result
+                batch_id += 1
+        except fluid.core.EOFException:
+            print 'End of epoch', epoch_id
+            train_args['file'].reset()
+
+        # test
+        loss = []
+        acc = []
+        try:
+            while True:
+                loss_np, acc_np = map(numpy.array,
+                                      test_exe.run(fetch_list=fetch_var_list))
+                loss.append(loss_np[0])
+                acc.append(acc_np[0])
+        except:
+            test_args['file'].reset()
+        print 'TEST: ', numpy.mean(loss), numpy.mean(acc)
 
 
 if __name__ == '__main__':

From 0388d1cb2ac73b28896e28b510cf929da7fa7776 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Tue, 24 Jul 2018 16:36:33 +0800
Subject: [PATCH 2/4] some update

---
 .../.gitignore                                |  0
 .../convert_data_to_recordio.py               |  2 +-
 .../train.py                                  | 41 ++++++++-----------
 3 files changed, 18 insertions(+), 25 deletions(-)
 rename python/paddle/fluid/tests/demo/{text_classification => file_reader}/.gitignore (100%)
 rename python/paddle/fluid/tests/demo/{text_classification => file_reader}/convert_data_to_recordio.py (96%)
 rename python/paddle/fluid/tests/demo/{text_classification => file_reader}/train.py (78%)

diff --git a/python/paddle/fluid/tests/demo/text_classification/.gitignore b/python/paddle/fluid/tests/demo/file_reader/.gitignore
similarity index 100%
rename from python/paddle/fluid/tests/demo/text_classification/.gitignore
rename to python/paddle/fluid/tests/demo/file_reader/.gitignore
diff --git a/python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py b/python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py
similarity index 96%
rename from python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py
rename to python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py
index aaa713df88..b839e14889 100644
--- a/python/paddle/fluid/tests/demo/text_classification/convert_data_to_recordio.py
+++ b/python/paddle/fluid/tests/demo/file_reader/convert_data_to_recordio.py
@@ -50,7 +50,7 @@ feeder = fluid.DataFeeder(feed_list=[data, label], place=fluid.CPUPlace())
 BATCH_SIZE = 128
 train_reader = paddle.batch(
     paddle.reader.shuffle(
-        paddle.dataset.imdb.train(word_dict), buf_size=10000),
+        paddle.dataset.imdb.train(word_dict), buf_size=25000),
     batch_size=BATCH_SIZE)
 
 test_reader = paddle.batch(
diff --git a/python/paddle/fluid/tests/demo/text_classification/train.py b/python/paddle/fluid/tests/demo/file_reader/train.py
similarity index 78%
rename from python/paddle/fluid/tests/demo/text_classification/train.py
rename to python/paddle/fluid/tests/demo/file_reader/train.py
index 9e930b67a4..bc3a6dc81d 100644
--- a/python/paddle/fluid/tests/demo/text_classification/train.py
+++ b/python/paddle/fluid/tests/demo/file_reader/train.py
@@ -27,9 +27,6 @@ emb_dim = 128
 # hidden dim
 hid_dim = 128
 
-# hidden dim2
-hid_dim2 = 96
-
 # class num
 class_dim = 2
 
@@ -42,13 +39,9 @@ def build_program(is_train):
         filenames=TRAIN_FILES if is_train else TEST_FILES,
         shapes=[[-1, 1], [-1, 1]],
         lod_levels=[1, 0],
-        dtypes=['int64', 'int64'],
-        thread_num=1)
-    if is_train:
-        file_obj = fluid.layers.io.shuffle(file_obj_handle, buffer_size=1000)
-    else:
-        file_obj = file_obj_handle
-    file_obj = fluid.layers.io.double_buffer(file_obj)
+        dtypes=['int64', 'int64'])
+
+    file_obj = fluid.layers.io.double_buffer(file_obj_handle)
 
     with fluid.unique_name.guard():
 
@@ -56,22 +49,24 @@
 
         emb = fluid.layers.embedding(input=data, size=[DICT_DIM, emb_dim])
 
-        # sequence conv with window size = 3
-        win_size = 3
         conv_3 = fluid.nets.sequence_conv_pool(
             input=emb,
             num_filters=hid_dim,
-            filter_size=win_size,
+            filter_size=3,
             act="tanh",
-            pool_type="max")
+            pool_type="sqrt")
 
-        # fc layer after conv
-        fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2)
+        conv_4 = fluid.nets.sequence_conv_pool(
+            input=emb,
+            num_filters=hid_dim,
+            filter_size=4,
+            act="tanh",
+            pool_type="sqrt")
 
-        # probability of each class
-        prediction = fluid.layers.fc(input=[fc_1],
+        prediction = fluid.layers.fc(input=[conv_3, conv_4],
                                      size=class_dim,
                                      act="softmax")
 
+        # cross entropy loss
         cost = fluid.layers.cross_entropy(input=prediction, label=label)
@@ -117,11 +112,9 @@
         try:
             batch_id = 0
             while True:
-                result = map(numpy.array,
-                             train_exe.run(fetch_list=fetch_var_list
-                                           if batch_id % 10 == 0 else []))
-                if len(result) != 0:
-                    print 'Train loss: ', result
+                loss, acc = map(numpy.array,
+                                train_exe.run(fetch_list=fetch_var_list))
+                print 'Train epoch', epoch_id, 'batch', batch_id, 'loss:', loss, 'acc:', acc
                 batch_id += 1
         except fluid.core.EOFException:
             print 'End of epoch', epoch_id
@@ -138,7 +131,7 @@
                 acc.append(acc_np[0])
         except:
             test_args['file'].reset()
-        print 'TEST: ', numpy.mean(loss), numpy.mean(acc)
+        print 'Test loss:', numpy.mean(loss), 'acc:', numpy.mean(acc)
 
 
 if __name__ == '__main__':

From aa3618ed3e585802fea5e895430c8d6ff02beb45 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Wed, 25 Jul 2018 10:11:03 +0800
Subject: [PATCH 3/4] fix _create_prefetch_block in distribute_transpiler

---
 python/paddle/fluid/transpiler/distribute_transpiler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index fc58703eca..e7698d8c52 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -887,7 +887,8 @@ class DistributeTranspiler(object):
         # create table optimize block in pserver program
         table_opt_op = [
             op for op in self.optimize_ops
-            if op.input("Param")[0] == self.table_name
+            if 'Param' in op.input_names and op.input("Param")[0] ==
+            self.table_name
         ][0]
         table_opt_block = pserver_program.create_block(pre_block_idx)
         # only support sgd now

From a4c5223713ab1c2d1683071bd8218f2a37cc15e1 Mon Sep 17 00:00:00 2001
From: chengduo
Date: Wed, 25 Jul 2018 13:44:13 +0800
Subject: [PATCH 4/4] Update test_pe_mnist threshold (#12348)

* update test_pe_mnist threshold

* clean code
---
 .../unittests/test_parallel_executor_mnist.py | 85 ++++++-------------
 1 file changed, 28 insertions(+), 57 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
index b21e16439a..76389d916f 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
@@ -107,44 +107,24 @@ class TestMNIST(TestParallelExecutorBase):
         label = np.ones(shape=[32, 1], dtype='int64')
         return img, label
 
-    # simple_fc
-    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
+    def _compare_reduce_and_allreduce(self, model, use_cuda, random_data=True):
         if use_cuda and not core.is_compiled_with_cuda():
             return
-        self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
         self.check_network_convergence(
-            simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
-
-        img, label = self._init_data()
-
+            model, use_cuda=use_cuda, use_reduce=True)
         self.check_network_convergence(
-            simple_fc_net,
-            feed_dict={"image": img,
-                       "label": label},
-            use_cuda=use_cuda,
-            use_reduce=use_reduce)
+            model, use_cuda=use_cuda, allow_op_delay=True, use_reduce=True)
 
-    def check_simple_fc_convergence_with_Reduce(self, use_cuda):
-        if use_cuda and not core.is_compiled_with_cuda():
-            return
-        self.check_network_convergence(
-            simple_fc_net, use_cuda=use_cuda, use_reduce=True)
-        self.check_network_convergence(
-            simple_fc_net,
-            use_cuda=use_cuda,
-            allow_op_delay=True,
-            use_reduce=True)
-
-        img, label = self._init_data()
+        img, label = self._init_data(random_data)
 
         all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
-            simple_fc_net,
+            model,
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
             use_reduce=False)
         reduce_first_loss, reduce_last_loss = self.check_network_convergence(
-            simple_fc_net,
+            model,
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
@@ -153,7 +133,24 @@ class TestMNIST(TestParallelExecutorBase):
         for loss in zip(all_reduce_first_loss, reduce_first_loss):
             self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
         for loss in zip(all_reduce_last_loss, reduce_last_loss):
-            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
+
+    # simple_fc
+    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
+        if use_cuda and not core.is_compiled_with_cuda():
+            return
+        self.check_network_convergence(simple_fc_net, use_cuda=use_cuda)
+        self.check_network_convergence(
+            simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
+
+        img, label = self._init_data()
+
+        self.check_network_convergence(
+            simple_fc_net,
+            feed_dict={"image": img,
+                       "label": label},
+            use_cuda=use_cuda,
+            use_reduce=use_reduce)
 
     def test_simple_fc(self):
         # use_cuda
         self.check_simple_fc_convergence(True)
         self.check_simple_fc_convergence(False)
 
     def test_simple_fc_with_new_strategy(self):
         # use_cuda, use_reduce
-        self.check_simple_fc_convergence_with_Reduce(True)
-        self.check_simple_fc_convergence_with_Reduce(False)
+        self._compare_reduce_and_allreduce(simple_fc_net, True)
+        self._compare_reduce_and_allreduce(simple_fc_net, False)
 
     def check_simple_fc_parallel_accuracy(self, use_cuda):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -209,39 +206,13 @@ class TestMNIST(TestParallelExecutorBase):
                        "label": label},
             use_cuda=use_cuda)
 
-    def check_batchnorm_fc_convergence_use_reduce(self, use_cuda):
-        if use_cuda and not core.is_compiled_with_cuda():
-            return
-        self.check_network_convergence(
-            fc_with_batchnorm, use_cuda=use_cuda, use_reduce=True)
-
-        img, label = self._init_data()
-
-        all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
-            fc_with_batchnorm,
-            feed_dict={"image": img,
-                       "label": label},
-            use_cuda=use_cuda,
-            use_reduce=False)
-        reduce_first_loss, reduce_last_loss = self.check_network_convergence(
-            fc_with_batchnorm,
-            feed_dict={"image": img,
-                       "label": label},
-            use_cuda=use_cuda,
-            use_reduce=True)
-
-        for loss in zip(all_reduce_first_loss, reduce_first_loss):
-            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
-        for loss in zip(all_reduce_last_loss, reduce_last_loss):
-            self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
-
     def test_batchnorm_fc(self):
         self.check_batchnorm_fc_convergence(True)
         self.check_batchnorm_fc_convergence(False)
 
     def test_batchnorm_fc_with_new_strategy(self):
-        self.check_batchnorm_fc_convergence_use_reduce(True)
-        self.check_batchnorm_fc_convergence_use_reduce(False)
+        self._compare_reduce_and_allreduce(fc_with_batchnorm, True)
+        self._compare_reduce_and_allreduce(fc_with_batchnorm, False)
 
 
 if __name__ == '__main__':
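
A note on the reader pattern that patches 1 and 2 converge on: run the reader-backed program until the file is exhausted, catch fluid.core.EOFException, and reset() the open_files handle before the next epoch. The sketch below is a minimal, single-device illustration of that pattern and is not part of any patch above. It assumes the Fluid 0.14-era API that the demo itself uses (fluid.layers.io.open_files, double_buffer, fluid.layers.read_file, reader.reset()) and a 'train.recordio' file such as the one produced by convert_data_to_recordio.py; the mean-of-labels fetch is a hypothetical stand-in for a real network.

    import numpy
    import paddle.fluid as fluid

    # Open the RecordIO file once and keep the handle; train.py keeps the
    # same handle in train_args['file'] so it can reset() it per epoch.
    file_obj_handle = fluid.layers.io.open_files(
        filenames=['train.recordio'],
        shapes=[[-1, 1], [-1, 1]],
        lod_levels=[1, 0],
        dtypes=['int64', 'int64'])
    file_obj = fluid.layers.io.double_buffer(file_obj_handle)
    data, label = fluid.layers.read_file(file_obj)

    # Hypothetical stand-in for a model: fetch the mean label per batch.
    avg_label = fluid.layers.mean(fluid.layers.cast(label, dtype='float32'))

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    for epoch_id in range(2):
        try:
            while True:  # drain the file; the reader raises EOF per pass
                avg, = exe.run(fluid.default_main_program(),
                               fetch_list=[avg_label])
                print 'epoch', epoch_id, 'mean label:', avg
        except fluid.core.EOFException:
            file_obj_handle.reset()  # rewind the reader for the next epoch

Resetting the open_files handle (rather than the double-buffered wrapper) mirrors what train.py does with train_args['file'], and catching the specific fluid.core.EOFException is what patch 2 uses for the train loop in place of the bare except kept on the test side.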