From b250fceab5e8f9f0c763d1faa054c078fc4db669 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 2 Mar 2017 14:28:15 +0800 Subject: [PATCH 01/13] Add save/load parameters. --- demo/mnist/.gitignore | 1 + demo/mnist/api_train_v2.py | 11 ++++++++++- python/paddle/v2/parameters.py | 27 +++++++++++++++++++++++++++ python/paddle/v2/trainer.py | 2 +- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index 8bd9837523..9c552159be 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -5,3 +5,4 @@ plot.png train.log *pyc .ipynb_checkpoints +params.pkl diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index a59b30ccdb..73fcb9d79d 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -1,4 +1,5 @@ import paddle.v2 as paddle +import cPickle def main(): @@ -16,7 +17,11 @@ def main(): act=paddle.activation.Softmax()) cost = paddle.layer.classification_cost(input=inference, label=label) - parameters = paddle.parameters.create(cost) + try: + with open('params.pkl', 'r') as f: + parameters = cPickle.load(f) + except IOError: + parameters = paddle.parameters.create(cost) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) @@ -34,6 +39,10 @@ def main(): event.pass_id, event.batch_id, event.cost, event.metrics, result.metrics) + with open('params.pkl', 'w') as f: + cPickle.dump( + parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) + else: pass diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 2a6026bcab..d8c3a73b0e 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -222,6 +222,33 @@ class Parameters(object): self.__gradient_machines__.append(gradient_machine) + def __getstate__(self): + params = {} + for name in self.names(): + params[name] = self.get(name) + + param_conf = {} + for name in self.__param_conf__: + conf = self.__param_conf__[name] + assert isinstance(conf, ParameterConfig) + param_conf[name] = conf.SerializeToString() + + return {'conf': param_conf, 'params': params} + + def __setstate__(self, obj): + Parameters.__init__(self) + + def __impl__(conf, params): + for name in conf: + p = ParameterConfig() + p.ParseFromString(conf[name]) + self.__append_config__(p) + for name in params: + shape = self.get_shape(name) + self.set(name, params[name].reshape(shape)) + + __impl__(**obj) + def __get_parameter_in_gradient_machine__(gradient_machine, name): """ diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 5003f55f3e..709566ca44 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -66,9 +66,9 @@ class SGD(ITrainer): self.__topology_in_proto__, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types()) assert isinstance(gm, api.GradientMachine) - parameters.append_gradient_machine(gm) self.__gradient_machine__ = gm self.__gradient_machine__.randParameters() + parameters.append_gradient_machine(gm) def train(self, reader, num_passes=1, event_handler=None, reader_dict=None): """ From fb74ae36d4b9ba21ccf98bd45b04c361029e7406 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 4 Mar 2017 17:11:49 +0800 Subject: [PATCH 02/13] Refine serialize --- demo/mnist/.gitignore | 1 + demo/mnist/api_train_v2.py | 3 +++ python/paddle/v2/parameters.py | 40 +++++++++++++++++++++++++++++++++- python/paddle/v2/trainer.py | 4 ++-- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index 9c552159be..ed074b09e7 100644 --- 
a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -6,3 +6,4 @@ train.log *pyc .ipynb_checkpoints params.pkl +params.tar diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index a72ebfa980..7a1f661318 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -103,6 +103,9 @@ def main(): cPickle.dump( parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) + with open('params.tar', 'w') as f: + parameters.serialize_to_tar(f) + elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.reader.batched( paddle.dataset.mnist.test(), batch_size=128)) diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index d8c3a73b0e..6a7b883500 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -1,7 +1,9 @@ import numpy as np import py_paddle.swig_paddle as api from paddle.proto.ParameterConfig_pb2 import ParameterConfig - +import struct +import tarfile +import cStringIO from topology import Topology __all__ = ['Parameters', 'create'] @@ -235,6 +237,42 @@ class Parameters(object): return {'conf': param_conf, 'params': params} + def serialize(self, name, f): + """ + + :param name: + :param f: + :type f: file + :return: + """ + param = self.get(name) + size = reduce(lambda a, b: a * b, param.shape) + f.write(struct.pack("IIQ", 0, 4, size)) + param = param.astype(np.float32) + f.write(param.tobytes()) + + def deserialize(self, name, f): + """ + + :param name: + :param f: + :type f: file + :return: + """ + f.read(16) # header + arr = np.fromfile(f, dtype=np.float32) + self.set(name, arr.reshape(self.get_shape(name))) + + def serialize_to_tar(self, f): + tar = tarfile.TarFile(fileobj=f, mode='w') + for nm in self.names(): + buf = cStringIO.StringIO() + self.serialize(nm, buf) + tarinfo = tarfile.TarInfo(name=nm) + buf.seek(0) + tarinfo.size = len(buf.getvalue()) + tar.addfile(tarinfo, buf) + def __setstate__(self, obj): Parameters.__init__(self) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index e878ea6e3b..7da97d79a8 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -101,7 +101,7 @@ class SGD(): for each_param in self.__gradient_machine__.getNonStaticParameters( ): updater.update(each_param) - cost_sum = out_args.sumCosts() + cost_sum = out_args.sum() cost = cost_sum / len(data_batch) updater.finishBatch(cost) batch_evaluator.finish() @@ -137,7 +137,7 @@ class SGD(): num_samples += len(data_batch) self.__gradient_machine__.forward( feeder(data_batch), out_args, api.PASS_TEST) - total_cost += out_args.sumCosts() + total_cost += out_args.sum() self.__gradient_machine__.eval(evaluator) evaluator.finish() From efe53811c569182df71b14c46aa5a7238038cba1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 4 Mar 2017 18:32:39 +0800 Subject: [PATCH 03/13] complete serialize * Test gzip --- demo/mnist/.gitignore | 1 + demo/mnist/api_train_v2.py | 14 ++++----- python/paddle/v2/parameters.py | 54 ++++++++++++++++------------------ 3 files changed, 32 insertions(+), 37 deletions(-) diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index ed074b09e7..7e61d5e3a0 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -7,3 +7,4 @@ train.log .ipynb_checkpoints params.pkl params.tar +params.tar.gz diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 7a1f661318..a11260d91b 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -1,5 +1,5 @@ import paddle.v2 as paddle -import cPickle +import gzip def 
softmax_regression(img): @@ -73,8 +73,8 @@ def main(): cost = paddle.layer.classification_cost(input=predict, label=label) try: - with open('params.pkl', 'r') as f: - parameters = cPickle.load(f) + with gzip.open('params.tar.gz', 'r') as f: + parameters = paddle.parameters.Parameters.from_tar(f) except IOError: parameters = paddle.parameters.create(cost) @@ -99,12 +99,8 @@ def main(): event.pass_id, event.batch_id, event.cost, event.metrics, result.metrics) - with open('params.pkl', 'w') as f: - cPickle.dump( - parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) - - with open('params.tar', 'w') as f: - parameters.serialize_to_tar(f) + with gzip.open('params.tar.gz', 'w') as f: + parameters.to_tar(f) elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.reader.batched( diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 6a7b883500..58be523407 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -224,19 +224,6 @@ class Parameters(object): self.__gradient_machines__.append(gradient_machine) - def __getstate__(self): - params = {} - for name in self.names(): - params[name] = self.get(name) - - param_conf = {} - for name in self.__param_conf__: - conf = self.__param_conf__[name] - assert isinstance(conf, ParameterConfig) - param_conf[name] = conf.SerializeToString() - - return {'conf': param_conf, 'params': params} - def serialize(self, name, f): """ @@ -260,10 +247,10 @@ class Parameters(object): :return: """ f.read(16) # header - arr = np.fromfile(f, dtype=np.float32) + arr = np.frombuffer(f.read(), dtype=np.float32) self.set(name, arr.reshape(self.get_shape(name))) - def serialize_to_tar(self, f): + def to_tar(self, f): tar = tarfile.TarFile(fileobj=f, mode='w') for nm in self.names(): buf = cStringIO.StringIO() @@ -273,19 +260,30 @@ class Parameters(object): tarinfo.size = len(buf.getvalue()) tar.addfile(tarinfo, buf) - def __setstate__(self, obj): - Parameters.__init__(self) - - def __impl__(conf, params): - for name in conf: - p = ParameterConfig() - p.ParseFromString(conf[name]) - self.__append_config__(p) - for name in params: - shape = self.get_shape(name) - self.set(name, params[name].reshape(shape)) - - __impl__(**obj) + conf = self.__param_conf__[nm] + confStr = conf.SerializeToString() + tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm) + tarinfo.size = len(confStr) + buf = cStringIO.StringIO(confStr) + buf.seek(0) + tar.addfile(tarinfo, fileobj=buf) + + @staticmethod + def from_tar(f): + params = Parameters() + tar = tarfile.TarFile(fileobj=f, mode='r') + for finfo in tar: + assert isinstance(finfo, tarfile.TarInfo) + if finfo.name.endswith('.protobuf'): + f = tar.extractfile(finfo) + conf = ParameterConfig() + conf.ParseFromString(f.read()) + params.__append_config__(conf) + + for param_name in params.names(): + f = tar.extractfile(param_name) + params.deserialize(param_name, f) + return params def __get_parameter_in_gradient_machine__(gradient_machine, name): From 98522dcb3390ecbccc4c544e0c9c57b8d3e4b2c9 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 5 Mar 2017 21:26:44 +0800 Subject: [PATCH 04/13] optimizer wmt14 dataset --- demo/seqToseq/api_train_v2.py | 161 +++++++++++++++----------- demo/seqToseq/seqToseq_net_v2.py | 92 --------------- python/paddle/v2/dataset/wmt14.py | 185 +++++++++++++----------------- 3 files changed, 174 insertions(+), 264 deletions(-) delete mode 100644 demo/seqToseq/seqToseq_net_v2.py diff --git a/demo/seqToseq/api_train_v2.py 
b/demo/seqToseq/api_train_v2.py index a5f59ec379..f100ef80cb 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -1,76 +1,106 @@ -import os - import paddle.v2 as paddle -from seqToseq_net_v2 import seqToseq_net_v2 - -# Data Definiation. -# TODO:This code should be merged to dataset package. -data_dir = "./data/pre-wmt14" -src_lang_dict = os.path.join(data_dir, 'src.dict') -trg_lang_dict = os.path.join(data_dir, 'trg.dict') - -source_dict_dim = len(open(src_lang_dict, "r").readlines()) -target_dict_dim = len(open(trg_lang_dict, "r").readlines()) - - -def read_to_dict(dict_path): - with open(dict_path, "r") as fin: - out_dict = { - line.strip(): line_count - for line_count, line in enumerate(fin) - } - return out_dict - - -src_dict = read_to_dict(src_lang_dict) -trg_dict = read_to_dict(trg_lang_dict) - -train_list = os.path.join(data_dir, 'train.list') -test_list = os.path.join(data_dir, 'test.list') - -UNK_IDX = 2 -START = "" -END = "" - -def _get_ids(s, dictionary): - words = s.strip().split() - return [dictionary[START]] + \ - [dictionary.get(w, UNK_IDX) for w in words] + \ - [dictionary[END]] - - -def train_reader(file_name): - def reader(): - with open(file_name, 'r') as f: - for line_count, line in enumerate(f): - line_split = line.strip().split('\t') - if len(line_split) != 2: - continue - src_seq = line_split[0] # one source sequence - src_ids = _get_ids(src_seq, src_dict) - - trg_seq = line_split[1] # one target sequence - trg_words = trg_seq.split() - trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words] - - # remove sequence whose length > 80 in training mode - if len(src_ids) > 80 or len(trg_ids) > 80: - continue - trg_ids_next = trg_ids + [trg_dict[END]] - trg_ids = [trg_dict[START]] + trg_ids - - yield src_ids, trg_ids, trg_ids_next - - return reader +def seqToseq_net(source_dict_dim, target_dict_dim): + ### Network Architecture + word_vector_dim = 512 # dimension of word vector + decoder_size = 512 # dimension of hidden unit in GRU Decoder network + encoder_size = 512 # dimension of hidden unit in GRU Encoder network + + #### Encoder + src_word_id = paddle.layer.data( + name='source_language_word', + type=paddle.data_type.integer_value_sequence(source_dict_dim)) + src_embedding = paddle.layer.embedding( + input=src_word_id, + size=word_vector_dim, + param_attr=paddle.attr.ParamAttr(name='_source_language_embedding')) + src_forward = paddle.networks.simple_gru( + input=src_embedding, size=encoder_size) + src_backward = paddle.networks.simple_gru( + input=src_embedding, size=encoder_size, reverse=True) + encoded_vector = paddle.layer.concat(input=[src_forward, src_backward]) + + #### Decoder + with paddle.layer.mixed(size=decoder_size) as encoded_proj: + encoded_proj += paddle.layer.full_matrix_projection( + input=encoded_vector) + + backward_first = paddle.layer.first_seq(input=src_backward) + + with paddle.layer.mixed( + size=decoder_size, act=paddle.activation.Tanh()) as decoder_boot: + decoder_boot += paddle.layer.full_matrix_projection( + input=backward_first) + + def gru_decoder_with_attention(enc_vec, enc_proj, current_word): + + decoder_mem = paddle.layer.memory( + name='gru_decoder', size=decoder_size, boot_layer=decoder_boot) + + context = paddle.networks.simple_attention( + encoded_sequence=enc_vec, + encoded_proj=enc_proj, + decoder_state=decoder_mem) + + with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs: + decoder_inputs += paddle.layer.full_matrix_projection(input=context) + decoder_inputs += 
paddle.layer.full_matrix_projection( + input=current_word) + + gru_step = paddle.layer.gru_step( + name='gru_decoder', + input=decoder_inputs, + output_mem=decoder_mem, + size=decoder_size) + + with paddle.layer.mixed( + size=target_dict_dim, + bias_attr=True, + act=paddle.activation.Softmax()) as out: + out += paddle.layer.full_matrix_projection(input=gru_step) + return out + + decoder_group_name = "decoder_group" + group_input1 = paddle.layer.StaticInputV2(input=encoded_vector, is_seq=True) + group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True) + group_inputs = [group_input1, group_input2] + + trg_embedding = paddle.layer.embedding( + input=paddle.layer.data( + name='target_language_word', + type=paddle.data_type.integer_value_sequence(target_dict_dim)), + size=word_vector_dim, + param_attr=paddle.attr.ParamAttr(name='_target_language_embedding')) + group_inputs.append(trg_embedding) + + # For decoder equipped with attention mechanism, in training, + # target embeding (the groudtruth) is the data input, + # while encoded source sequence is accessed to as an unbounded memory. + # Here, the StaticInput defines a read-only memory + # for the recurrent_group. + decoder = paddle.layer.recurrent_group( + name=decoder_group_name, + step=gru_decoder_with_attention, + input=group_inputs) + + lbl = paddle.layer.data( + name='target_language_next_word', + type=paddle.data_type.integer_value_sequence(target_dict_dim)) + cost = paddle.layer.classification_cost(input=decoder, label=lbl) + + return cost def main(): paddle.init(use_gpu=False, trainer_count=1) + # source and target dict dim. + dict_size = 30000 + source_dict_dim = target_dict_dim = dict_size + # define network topology - cost = seqToseq_net_v2(source_dict_dim, target_dict_dim) + cost = seqToseq_net(source_dict_dim, target_dict_dim) parameters = paddle.parameters.create(cost) # define optimize method and trainer @@ -85,10 +115,9 @@ def main(): 'target_language_word': 1, 'target_language_next_word': 2 } - wmt14_reader = paddle.reader.batched( paddle.reader.shuffle( - train_reader("data/pre-wmt14/train/train"), buf_size=8192), + paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192), batch_size=5) # define event_handler callback diff --git a/demo/seqToseq/seqToseq_net_v2.py b/demo/seqToseq/seqToseq_net_v2.py deleted file mode 100644 index 058a6789d7..0000000000 --- a/demo/seqToseq/seqToseq_net_v2.py +++ /dev/null @@ -1,92 +0,0 @@ -import paddle.v2 as paddle - - -def seqToseq_net_v2(source_dict_dim, target_dict_dim): - ### Network Architecture - word_vector_dim = 512 # dimension of word vector - decoder_size = 512 # dimension of hidden unit in GRU Decoder network - encoder_size = 512 # dimension of hidden unit in GRU Encoder network - - #### Encoder - src_word_id = paddle.layer.data( - name='source_language_word', - type=paddle.data_type.integer_value_sequence(source_dict_dim)) - src_embedding = paddle.layer.embedding( - input=src_word_id, - size=word_vector_dim, - param_attr=paddle.attr.ParamAttr(name='_source_language_embedding')) - src_forward = paddle.networks.simple_gru( - input=src_embedding, size=encoder_size) - src_backward = paddle.networks.simple_gru( - input=src_embedding, size=encoder_size, reverse=True) - encoded_vector = paddle.layer.concat(input=[src_forward, src_backward]) - - #### Decoder - with paddle.layer.mixed(size=decoder_size) as encoded_proj: - encoded_proj += paddle.layer.full_matrix_projection( - input=encoded_vector) - - backward_first = paddle.layer.first_seq(input=src_backward) - - 
with paddle.layer.mixed( - size=decoder_size, act=paddle.activation.Tanh()) as decoder_boot: - decoder_boot += paddle.layer.full_matrix_projection( - input=backward_first) - - def gru_decoder_with_attention(enc_vec, enc_proj, current_word): - - decoder_mem = paddle.layer.memory( - name='gru_decoder', size=decoder_size, boot_layer=decoder_boot) - - context = paddle.networks.simple_attention( - encoded_sequence=enc_vec, - encoded_proj=enc_proj, - decoder_state=decoder_mem) - - with paddle.layer.mixed(size=decoder_size * 3) as decoder_inputs: - decoder_inputs += paddle.layer.full_matrix_projection(input=context) - decoder_inputs += paddle.layer.full_matrix_projection( - input=current_word) - - gru_step = paddle.layer.gru_step( - name='gru_decoder', - input=decoder_inputs, - output_mem=decoder_mem, - size=decoder_size) - - with paddle.layer.mixed( - size=target_dict_dim, - bias_attr=True, - act=paddle.activation.Softmax()) as out: - out += paddle.layer.full_matrix_projection(input=gru_step) - return out - - decoder_group_name = "decoder_group" - group_input1 = paddle.layer.StaticInputV2(input=encoded_vector, is_seq=True) - group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True) - group_inputs = [group_input1, group_input2] - - trg_embedding = paddle.layer.embedding( - input=paddle.layer.data( - name='target_language_word', - type=paddle.data_type.integer_value_sequence(target_dict_dim)), - size=word_vector_dim, - param_attr=paddle.attr.ParamAttr(name='_target_language_embedding')) - group_inputs.append(trg_embedding) - - # For decoder equipped with attention mechanism, in training, - # target embeding (the groudtruth) is the data input, - # while encoded source sequence is accessed to as an unbounded memory. - # Here, the StaticInput defines a read-only memory - # for the recurrent_group. 
- decoder = paddle.layer.recurrent_group( - name=decoder_group_name, - step=gru_decoder_with_attention, - input=group_inputs) - - lbl = paddle.layer.data( - name='target_language_next_word', - type=paddle.data_type.integer_value_sequence(target_dict_dim)) - cost = paddle.layer.classification_cost(input=decoder, label=lbl) - - return cost diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index 9904848b5d..5a9dd4ca80 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -14,129 +14,102 @@ """ wmt14 dataset """ -import paddle.v2.dataset.common -import tarfile +import os import os.path -import itertools +import tarfile + +import paddle.v2.dataset.common +from wmt14_util import SeqToSeqDatasetCreater __all__ = ['train', 'test', 'build_dict'] URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' -URL_TRAIN = 'http://localhost:8000/train.tgz' -MD5_TRAIN = '72de99da2830ea5a3a2c4eb36092bbc7' - - -def word_count(f, word_freq=None): - add = paddle.v2.dataset.common.dict_add - if word_freq == None: - word_freq = {} - - for l in f: - for w in l.strip().split(): - add(word_freq, w) - add(word_freq, '<s>') - add(word_freq, '<e>') - - return word_freq - - -def get_word_dix(word_freq): - TYPO_FREQ = 50 - word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items()) - word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) - words, _ = list(zip(*word_freq_sorted)) - word_idx = dict(zip(words, xrange(len(words)))) - word_idx['<unk>'] = len(words) - return word_idx - - -def get_word_freq(train, dev): - word_freq = word_count(train, word_count(dev)) - if '<unk>' in word_freq: - # remove <unk> for now, since we will set it as last index - del word_freq['<unk>'] - return word_freq - - -def build_dict(): - base_dir = './wmt14-data' - train_en_filename = base_dir + '/train/train.en' - train_fr_filename = base_dir + '/train/train.fr' - dev_en_filename = base_dir + '/dev/ntst1213.en' - dev_fr_filename = base_dir + '/dev/ntst1213.fr' - - if not os.path.exists(train_en_filename) or not os.path.exists( - train_fr_filename): +URL_TRAIN = 'http://localhost:8989/wmt14.tgz' +MD5_TRAIN = '7373473f86016f1f48037c9c340a2d5b' + +START = "<s>" +END = "<e>" +UNK = "<unk>" +UNK_IDX = 2 + +DEFAULT_DATA_DIR = "./data" +ORIGIN_DATA_DIR = "wmt14" +INNER_DATA_DIR = "pre-wmt14" +SRC_DICT = INNER_DATA_DIR + "/src.dict" +TRG_DICT = INNER_DATA_DIR + "/trg.dict" +TRAIN_FILE = INNER_DATA_DIR + "/train/train" + + +def __process_data__(data_path, dict_size=None): + downloaded_data = os.path.join(data_path, ORIGIN_DATA_DIR) + if not os.path.exists(downloaded_data): + # 1. download and extract tgz.
with tarfile.open( paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)) as tf: - tf.extractall(base_dir) - - if not os.path.exists(dev_en_filename) or not os.path.exists( - dev_fr_filename): - with tarfile.open( - paddle.v2.dataset.common.download(URL_DEV_TEST, 'wmt14', - MD5_DEV_TEST)) as tf: - tf.extractall(base_dir) - - f_en = open(train_en_filename) - f_fr = open(train_fr_filename) - f_en_dev = open(dev_en_filename) - f_fr_dev = open(dev_fr_filename) - - word_freq_en = get_word_freq(f_en, f_en_dev) - word_freq_fr = get_word_freq(f_fr, f_fr_dev) - - f_en.close() - f_fr.close() - f_en_dev.close() - f_fr_dev.close() - - return get_word_dix(word_freq_en), get_word_dix(word_freq_fr) - - -def reader_creator(directory, path_en, path_fr, URL, MD5, dict_en, dict_fr): - def reader(): - if not os.path.exists(path_en) or not os.path.exists(path_fr): - with tarfile.open( - paddle.v2.dataset.common.download(URL, 'wmt14', MD5)) as tf: - tf.extractall(directory) - - f_en = open(path_en) - f_fr = open(path_fr) - UNK_en = dict_en[''] - UNK_fr = dict_fr[''] - - for en, fr in itertools.izip(f_en, f_fr): - src_ids = [dict_en.get(w, UNK_en) for w in en.strip().split()] - tar_ids = [ - dict_fr.get(w, UNK_fr) - for w in [''] + fr.strip().split() + [''] + tf.extractall(data_path) + + # 2. process data file to intermediate format. + processed_data = os.path.join(data_path, INNER_DATA_DIR) + if not os.path.exists(processed_data): + dict_size = dict_size or -1 + data_creator = SeqToSeqDatasetCreater(downloaded_data, processed_data) + data_creator.create_dataset(dict_size, mergeDict=False) + + +def __read_to_dict__(dict_path, count): + with open(dict_path, "r") as fin: + out_dict = dict() + for line_count, line in enumerate(fin): + if line_count <= count: + out_dict[line.strip()] = line_count + else: + break + return out_dict + + +def __reader__(file_name, src_dict, trg_dict): + with open(file_name, 'r') as f: + for line_count, line in enumerate(f): + line_split = line.strip().split('\t') + if len(line_split) != 2: + continue + src_seq = line_split[0] # one source sequence + src_words = src_seq.split() + src_ids = [ + src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END] ] + trg_seq = line_split[1] # one target sequence + trg_words = trg_seq.split() + trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words] + # remove sequence whose length > 80 in training mode - if len(src_ids) == 0 or len(tar_ids) <= 1 or len( - src_ids) > 80 or len(tar_ids) > 80: + if len(src_ids) > 80 or len(trg_ids) > 80: continue + trg_ids_next = trg_ids + [trg_dict[END]] + trg_ids = [trg_dict[START]] + trg_ids + + yield src_ids, trg_ids, trg_ids_next - yield src_ids, tar_ids[:-1], tar_ids[1:] - f_en.close() - f_fr.close() +def train(data_dir=None, dict_size=None): + data_dir = data_dir or DEFAULT_DATA_DIR + __process_data__(data_dir, dict_size) + src_lang_dict = os.path.join(data_dir, SRC_DICT) + trg_lang_dict = os.path.join(data_dir, TRG_DICT) + train_file_name = os.path.join(data_dir, TRAIN_FILE) - return reader + default_dict_size = len(open(src_lang_dict, "r").readlines()) + if dict_size > default_dict_size: + raise ValueError("dict_dim should not be larger then the " + "length of word dict") -def train(dict_en, dict_fr): - directory = './wmt14-data' - return reader_creator(directory, directory + '/train/train.en', - directory + '/train/train.fr', URL_TRAIN, MD5_TRAIN, - dict_en, dict_fr) + real_dict_dim = dict_size or default_dict_size + src_dict = __read_to_dict__(src_lang_dict, real_dict_dim) + trg_dict = 
__read_to_dict__(trg_lang_dict, real_dict_dim) -def test(dict_en, dict_fr): - directory = './wmt14-data' - return reader_creator(directory, directory + '/dev/ntst1213.en', - directory + '/dev/ntst1213.fr', URL_DEV_TEST, - MD5_DEV_TEST, dict_en, dict_fr) + return lambda: __reader__(train_file_name, src_dict, trg_dict) From f6f444ff3da1ab16d4b770d9c98615095f842715 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sun, 5 Mar 2017 21:55:46 +0800 Subject: [PATCH 05/13] optimize code --- demo/seqToseq/api_train_v2.py | 2 +- demo/seqToseq/preprocess.py | 160 +--------------------------------- 2 files changed, 2 insertions(+), 160 deletions(-) diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index f100ef80cb..177fd26d68 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -115,7 +115,7 @@ def main(): 'target_language_word': 1, 'target_language_next_word': 2 } - wmt14_reader = paddle.reader.batched( + wmt14_reader = paddle.batch( paddle.reader.shuffle( paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192), batch_size=5) diff --git a/demo/seqToseq/preprocess.py b/demo/seqToseq/preprocess.py index 03f371331a..afa7bd5e0f 100755 --- a/demo/seqToseq/preprocess.py +++ b/demo/seqToseq/preprocess.py @@ -23,167 +23,9 @@ Options: -m --mergeDict merge source and target dictionary """ import os -import sys - -import string from optparse import OptionParser -from paddle.utils.preprocess_util import save_list, DatasetCreater - - -class SeqToSeqDatasetCreater(DatasetCreater): - """ - A class to process data for sequence to sequence application. - """ - - def __init__(self, data_path, output_path): - """ - data_path: the path to store the train data, test data and gen data - output_path: the path to store the processed dataset - """ - DatasetCreater.__init__(self, data_path) - self.gen_dir_name = 'gen' - self.gen_list_name = 'gen.list' - self.output_path = output_path - - def concat_file(self, file_path, file1, file2, output_path, output): - """ - Concat file1 and file2 to be one output file - The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 - file_path: the path to store file1 and file2 - output_path: the path to store output file - """ - file1 = os.path.join(file_path, file1) - file2 = os.path.join(file_path, file2) - output = os.path.join(output_path, output) - if not os.path.exists(output): - os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) - - def cat_file(self, dir_path, suffix, output_path, output): - """ - Cat all the files in dir_path with suffix to be one output file - dir_path: the base directory to store input file - suffix: suffix of file name - output_path: the path to store output file - """ - cmd = 'cat ' - file_list = os.listdir(dir_path) - file_list.sort() - for file in file_list: - if file.endswith(suffix): - cmd += os.path.join(dir_path, file) + ' ' - output = os.path.join(output_path, output) - if not os.path.exists(output): - os.system(cmd + '> ' + output) - - def build_dict(self, file_path, dict_path, dict_size=-1): - """ - Create the dictionary for the file, Note that - 1. Valid characters include all printable characters - 2. There is distinction between uppercase and lowercase letters - 3. 
There is 3 special token: - : the start of a sequence - : the end of a sequence - : a word not included in dictionary - file_path: the path to store file - dict_path: the path to store dictionary - dict_size: word count of dictionary - if is -1, dictionary will contains all the words in file - """ - if not os.path.exists(dict_path): - dictory = dict() - with open(file_path, "r") as fdata: - for line in fdata: - line = line.split('\t') - for line_split in line: - words = line_split.strip().split() - for word in words: - if word not in dictory: - dictory[word] = 1 - else: - dictory[word] += 1 - output = open(dict_path, "w+") - output.write('\n\n\n') - count = 3 - for key, value in sorted( - dictory.items(), key=lambda d: d[1], reverse=True): - output.write(key + "\n") - count += 1 - if count == dict_size: - break - self.dict_size = count - - def create_dataset(self, - dict_size=-1, - mergeDict=False, - suffixes=['.src', '.trg']): - """ - Create seqToseq dataset - """ - # dataset_list and dir_list has one-to-one relationship - train_dataset = os.path.join(self.data_path, self.train_dir_name) - test_dataset = os.path.join(self.data_path, self.test_dir_name) - gen_dataset = os.path.join(self.data_path, self.gen_dir_name) - dataset_list = [train_dataset, test_dataset, gen_dataset] - - train_dir = os.path.join(self.output_path, self.train_dir_name) - test_dir = os.path.join(self.output_path, self.test_dir_name) - gen_dir = os.path.join(self.output_path, self.gen_dir_name) - dir_list = [train_dir, test_dir, gen_dir] - - # create directory - for dir in dir_list: - if not os.path.exists(dir): - os.mkdir(dir) - - # checkout dataset should be parallel corpora - suffix_len = len(suffixes[0]) - for dataset in dataset_list: - file_list = os.listdir(dataset) - if len(file_list) % 2 == 1: - raise RuntimeError("dataset should be parallel corpora") - file_list.sort() - for i in range(0, len(file_list), 2): - if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: - raise RuntimeError( - "source and target file name should be equal") - - # cat all the files with the same suffix in dataset - for suffix in suffixes: - for dataset in dataset_list: - outname = os.path.basename(dataset) + suffix - self.cat_file(dataset, suffix, dataset, outname) - - # concat parallel corpora and create file.list - print 'concat parallel corpora for dataset' - id = 0 - list = ['train.list', 'test.list', 'gen.list'] - for dataset in dataset_list: - outname = os.path.basename(dataset) - self.concat_file(dataset, outname + suffixes[0], - outname + suffixes[1], dir_list[id], outname) - save_list([os.path.join(dir_list[id], outname)], - os.path.join(self.output_path, list[id])) - id += 1 - # build dictionary for train data - dict = ['src.dict', 'trg.dict'] - dict_path = [ - os.path.join(self.output_path, dict[0]), - os.path.join(self.output_path, dict[1]) - ] - if mergeDict: - outname = os.path.join(train_dir, train_dataset.split('/')[-1]) - print 'build src dictionary for train data' - self.build_dict(outname, dict_path[0], dict_size) - print 'build trg dictionary for train data' - os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) - else: - outname = os.path.join(train_dataset, self.train_dir_name) - for id in range(0, 2): - suffix = suffixes[id] - print 'build ' + suffix[1:] + ' dictionary for train data' - self.build_dict(outname + suffix, dict_path[id], dict_size) - print 'dictionary size is', self.dict_size +from paddle.v2.dataset.wmt14_util import SeqToSeqDatasetCreater def main(): From 
06915d0a0507ba25e46917e4e622fcbbe3cd2668 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 6 Mar 2017 10:53:16 +0800 Subject: [PATCH 06/13] add wmt14_util.py and a small dataset on bos for test --- python/paddle/v2/dataset/wmt14.py | 3 +- python/paddle/v2/dataset/wmt14_util.py | 172 +++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/dataset/wmt14_util.py diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index 5a9dd4ca80..254f07c8dd 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -25,7 +25,8 @@ __all__ = ['train', 'test', 'build_dict'] URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' -URL_TRAIN = 'http://localhost:8989/wmt14.tgz' +# this is a small set of data for test. The original data is too large and will be add later. +URL_TRAIN = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' MD5_TRAIN = '7373473f86016f1f48037c9c340a2d5b' START = "" diff --git a/python/paddle/v2/dataset/wmt14_util.py b/python/paddle/v2/dataset/wmt14_util.py new file mode 100644 index 0000000000..0d72389164 --- /dev/null +++ b/python/paddle/v2/dataset/wmt14_util.py @@ -0,0 +1,172 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from paddle.utils.preprocess_util import save_list, DatasetCreater + + +class SeqToSeqDatasetCreater(DatasetCreater): + """ + A class to process data for sequence to sequence application. 
+ """ + + def __init__(self, data_path, output_path): + """ + data_path: the path to store the train data, test data and gen data + output_path: the path to store the processed dataset + """ + DatasetCreater.__init__(self, data_path) + self.gen_dir_name = 'gen' + self.gen_list_name = 'gen.list' + self.output_path = output_path + + def concat_file(self, file_path, file1, file2, output_path, output): + """ + Concat file1 and file2 to be one output file + The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 + file_path: the path to store file1 and file2 + output_path: the path to store output file + """ + file1 = os.path.join(file_path, file1) + file2 = os.path.join(file_path, file2) + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) + + def cat_file(self, dir_path, suffix, output_path, output): + """ + Cat all the files in dir_path with suffix to be one output file + dir_path: the base directory to store input file + suffix: suffix of file name + output_path: the path to store output file + """ + cmd = 'cat ' + file_list = os.listdir(dir_path) + file_list.sort() + for file in file_list: + if file.endswith(suffix): + cmd += os.path.join(dir_path, file) + ' ' + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system(cmd + '> ' + output) + + def build_dict(self, file_path, dict_path, dict_size=-1): + """ + Create the dictionary for the file, Note that + 1. Valid characters include all printable characters + 2. There is distinction between uppercase and lowercase letters + 3. There is 3 special token: + <s>: the start of a sequence + <e>: the end of a sequence + <unk>: a word not included in dictionary + file_path: the path to store file + dict_path: the path to store dictionary + dict_size: word count of dictionary + if is -1, dictionary will contains all the words in file + """ + if not os.path.exists(dict_path): + dictory = dict() + with open(file_path, "r") as fdata: + for line in fdata: + line = line.split('\t') + for line_split in line: + words = line_split.strip().split() + for word in words: + if word not in dictory: + dictory[word] = 1 + else: + dictory[word] += 1 + output = open(dict_path, "w+") + output.write('<s>\n<e>\n<unk>\n') + count = 3 + for key, value in sorted( + dictory.items(), key=lambda d: d[1], reverse=True): + output.write(key + "\n") + count += 1 + if count == dict_size: + break + self.dict_size = count + + def create_dataset(self, + dict_size=-1, + mergeDict=False, + suffixes=['.src', '.trg']): + """ + Create seqToseq dataset + """ + # dataset_list and dir_list has one-to-one relationship + train_dataset = os.path.join(self.data_path, self.train_dir_name) + test_dataset = os.path.join(self.data_path, self.test_dir_name) + gen_dataset = os.path.join(self.data_path, self.gen_dir_name) + dataset_list = [train_dataset, test_dataset, gen_dataset] + + train_dir = os.path.join(self.output_path, self.train_dir_name) + test_dir = os.path.join(self.output_path, self.test_dir_name) + gen_dir = os.path.join(self.output_path, self.gen_dir_name) + dir_list = [train_dir, test_dir, gen_dir] + + # create directory + for dir in dir_list: + if not os.path.exists(dir): + os.makedirs(dir) + + # checkout dataset should be parallel corpora + suffix_len = len(suffixes[0]) + for dataset in dataset_list: + file_list = os.listdir(dataset) + if len(file_list) % 2 == 1: + raise RuntimeError("dataset should be parallel corpora") + file_list.sort() + for i in range(0, len(file_list),
2): + if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: + raise RuntimeError( + "source and target file name should be equal") + + # cat all the files with the same suffix in dataset + for suffix in suffixes: + for dataset in dataset_list: + outname = os.path.basename(dataset) + suffix + self.cat_file(dataset, suffix, dataset, outname) + + # concat parallel corpora and create file.list + print 'concat parallel corpora for dataset' + id = 0 + list = ['train.list', 'test.list', 'gen.list'] + for dataset in dataset_list: + outname = os.path.basename(dataset) + self.concat_file(dataset, outname + suffixes[0], + outname + suffixes[1], dir_list[id], outname) + save_list([os.path.join(dir_list[id], outname)], + os.path.join(self.output_path, list[id])) + id += 1 + + # build dictionary for train data + dict = ['src.dict', 'trg.dict'] + dict_path = [ + os.path.join(self.output_path, dict[0]), + os.path.join(self.output_path, dict[1]) + ] + if mergeDict: + outname = os.path.join(train_dir, train_dataset.split('/')[-1]) + print 'build src dictionary for train data' + self.build_dict(outname, dict_path[0], dict_size) + print 'build trg dictionary for train data' + os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) + else: + outname = os.path.join(train_dataset, self.train_dir_name) + for id in range(0, 2): + suffix = suffixes[id] + print 'build ' + suffix[1:] + ' dictionary for train data' + self.build_dict(outname + suffix, dict_path[id], dict_size) + print 'dictionary size is', self.dict_size From c36a3f46070e8ef5102b6fb34362c50193d5f529 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 6 Mar 2017 14:51:15 +0800 Subject: [PATCH 07/13] Add unittest for serialize/deserialize. --- python/paddle/v2/parameters.py | 6 +++ python/paddle/v2/tests/run_tests.sh | 2 +- python/paddle/v2/tests/test_parameters.py | 60 +++++++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/tests/test_parameters.py diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 1fed0b8a6a..05dc5c68dd 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -124,6 +124,12 @@ class Parameters(object): if len(self.__gradient_machines__) == 0: # create new parameter in python numpy. + if len(self.__tmp_params__) != 0: + ret_list = [ + mat for name, mat in self.__tmp_params__ if name == key + ] + if len(ret_list) == 1: + return ret_list[0] return np.ndarray(shape=shape, dtype=np.float32) else: for each_gradient_machine in self.__gradient_machines__: diff --git a/python/paddle/v2/tests/run_tests.sh b/python/paddle/v2/tests/run_tests.sh index b96f54fe9c..dda1b1bd22 100755 --- a/python/paddle/v2/tests/run_tests.sh +++ b/python/paddle/v2/tests/run_tests.sh @@ -22,7 +22,7 @@ cd $SCRIPTPATH $1 -m pip install ../../../../paddle/dist/*.whl -test_list="test_data_feeder.py" +test_list="test_data_feeder.py test_parameters.py" export PYTHONPATH=$PWD/../../../../python/ diff --git a/python/paddle/v2/tests/test_parameters.py b/python/paddle/v2/tests/test_parameters.py new file mode 100644 index 0000000000..ebb182caab --- /dev/null +++ b/python/paddle/v2/tests/test_parameters.py @@ -0,0 +1,60 @@ +import unittest +import sys + +try: + import py_paddle + + del py_paddle +except ImportError: + print >> sys.stderr, "It seems swig of Paddle is not installed, this " \ + "unittest will not be run." 
+ sys.exit(0) + +import paddle.v2.parameters as parameters +from paddle.proto.ParameterConfig_pb2 import ParameterConfig +import random +import cStringIO +import numpy + + +def __rand_param_config__(name): + conf = ParameterConfig() + conf.name = name + size = 1 + for i in xrange(2): + dim = random.randint(1, 1000) + conf.dims.append(dim) + size *= dim + conf.size = size + assert conf.IsInitialized() + return conf + + +class TestParameters(unittest.TestCase): + def test_serialization(self): + params = parameters.Parameters() + params.__append_config__(__rand_param_config__("param_0")) + params.__append_config__(__rand_param_config__("param_1")) + + for name in params.names(): + param = params.get(name) + param[:] = numpy.random.uniform( + -1.0, 1.0, size=params.get_shape(name)) + params.set(name, param) + + tmp_file = cStringIO.StringIO() + params.to_tar(tmp_file) + tmp_file.seek(0) + params_dup = parameters.Parameters.from_tar(tmp_file) + + self.assertEqual(params_dup.names(), params.names()) + + for name in params.names(): + self.assertEqual(params.get_shape(name), params_dup.get_shape(name)) + p0 = params.get(name) + p1 = params_dup.get(name) + self.assertTrue(numpy.isclose(p0, p1).all()) + + +if __name__ == '__main__': + unittest.main() From 26445368a2fb2d95598d80b5fad0d880c04bd0da Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 6 Mar 2017 15:45:19 +0800 Subject: [PATCH 08/13] Rename reader_dict to feeding * Also fix some other bugs. * Fix #1495 --- demo/image_classification/api_v2_train.py | 13 ++++++----- demo/introduction/api_train_v2.py | 22 ++++++++--------- demo/mnist/api_train_v2.py | 4 ++-- demo/semantic_role_labeling/api_train_v2.py | 6 ++--- demo/sentiment/train_v2.py | 23 +++++++----------- demo/seqToseq/api_train_v2.py | 6 ++--- python/paddle/v2/data_feeder.py | 24 ++++++++++++++----- python/paddle/v2/inference.py | 16 ++++--------- python/paddle/v2/trainer.py | 26 ++++++--------------- 9 files changed, 64 insertions(+), 76 deletions(-) diff --git a/demo/image_classification/api_v2_train.py b/demo/image_classification/api_v2_train.py index e0fc0e04bb..7134fa61e8 100644 --- a/demo/image_classification/api_v2_train.py +++ b/demo/image_classification/api_v2_train.py @@ -13,8 +13,9 @@ # limitations under the License import sys + import paddle.v2 as paddle -from api_v2_vgg import vgg_bn_drop + from api_v2_resnet import resnet_cifar10 @@ -23,7 +24,7 @@ def main(): classdim = 10 # PaddlePaddle init - paddle.init(use_gpu=True, trainer_count=1) + paddle.init(use_gpu=False, trainer_count=1) image = paddle.layer.data( name="image", type=paddle.data_type.dense_vector(datadim)) @@ -68,8 +69,8 @@ def main(): result = trainer.test( reader=paddle.batch( paddle.dataset.cifar.test10(), batch_size=128), - reader_dict={'image': 0, - 'label': 1}) + feeding={'image': 0, + 'label': 1}) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) # Create trainer @@ -83,8 +84,8 @@ def main(): batch_size=128), num_passes=5, event_handler=event_handler, - reader_dict={'image': 0, - 'label': 1}) + feeding={'image': 0, + 'label': 1}) if __name__ == '__main__': diff --git a/demo/introduction/api_train_v2.py b/demo/introduction/api_train_v2.py index 75dd65f9fc..84125c3b4b 100644 --- a/demo/introduction/api_train_v2.py +++ b/demo/introduction/api_train_v2.py @@ -30,26 +30,26 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) + 
print "Pass %d, Batch %d, Cost %f" % ( + event.pass_id, event.batch_id, event.cost) if isinstance(event, paddle.event.EndPass): - result = trainer.test( - reader=paddle.reader.batched( - uci_housing.test(), batch_size=2), - reader_dict={'x': 0, + if (event.pass_id + 1) % 10 == 0: + result = trainer.test( + reader=paddle.batch( + uci_housing.test(), batch_size=2), + feeding={'x': 0, 'y': 1}) - if event.pass_id % 10 == 0: - print "Test %d, %s" % (event.pass_id, result.metrics) + print "Test %d, %.2f" % (event.pass_id, result.cost) # training trainer.train( - reader=paddle.reader.batched( + reader=paddle.batch( paddle.reader.shuffle( uci_housing.train(), buf_size=500), batch_size=2), - reader_dict={'x': 0, - 'y': 1}, + feeding={'x': 0, + 'y': 1}, event_handler=event_handler, num_passes=30) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 072b2a08da..68761be80f 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -92,7 +92,7 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 1000 == 0: - result = trainer.test(reader=paddle.reader.batched( + result = trainer.test(reader=paddle.batch( paddle.dataset.mnist.test(), batch_size=256)) print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( @@ -103,7 +103,7 @@ def main(): parameters.to_tar(f) elif isinstance(event, paddle.event.EndPass): - result = trainer.test(reader=paddle.reader.batched( + result = trainer.test(reader=paddle.batch( paddle.dataset.mnist.test(), batch_size=128)) print "Test with Pass %d, Cost %f, %s\n" % ( event.pass_id, result.cost, result.metrics) diff --git a/demo/semantic_role_labeling/api_train_v2.py b/demo/semantic_role_labeling/api_train_v2.py index 15db922b97..036cad4b0a 100644 --- a/demo/semantic_role_labeling/api_train_v2.py +++ b/demo/semantic_role_labeling/api_train_v2.py @@ -163,11 +163,11 @@ def main(): update_equation=optimizer) parameters.set('emb', load_parameter(conll05.get_embedding(), 44068, 32)) - trn_reader = paddle.reader.batched( + trn_reader = paddle.batch( paddle.reader.shuffle( conll05.test(), buf_size=8192), batch_size=10) - reader_dict = { + feeding = { 'word_data': 0, 'ctx_n2_data': 1, 'ctx_n1_data': 2, @@ -183,7 +183,7 @@ def main(): reader=trn_reader, event_handler=event_handler, num_passes=10000, - reader_dict=reader_dict) + feeding=feeding) if __name__ == '__main__': diff --git a/demo/sentiment/train_v2.py b/demo/sentiment/train_v2.py index 3a266e74ea..fd7243cbe6 100644 --- a/demo/sentiment/train_v2.py +++ b/demo/sentiment/train_v2.py @@ -18,11 +18,7 @@ from paddle.trainer_config_helpers.poolings import MaxPooling import paddle.v2 as paddle -def convolution_net(input_dim, - class_dim=2, - emb_dim=128, - hid_dim=128, - is_predict=False): +def convolution_net(input_dim, class_dim=2, emb_dim=128, hid_dim=128): data = paddle.layer.data("word", paddle.data_type.integer_value_sequence(input_dim)) emb = paddle.layer.embedding(input=data, size=emb_dim) @@ -42,8 +38,7 @@ def stacked_lstm_net(input_dim, class_dim=2, emb_dim=128, hid_dim=512, - stacked_num=3, - is_predict=False): + stacked_num=3): """ A Wrapper for sentiment classification task. This network uses bi-directional recurrent network, @@ -110,7 +105,7 @@ def stacked_lstm_net(input_dim, if __name__ == '__main__': # init - paddle.init(use_gpu=True, trainer_count=4) + paddle.init(use_gpu=False, trainer_count=4) # network config print 'load dictionary...' 
@@ -143,11 +138,11 @@ if __name__ == '__main__': sys.stdout.flush() if isinstance(event, paddle.event.EndPass): result = trainer.test( - reader=paddle.reader.batched( + reader=paddle.batch( lambda: paddle.dataset.imdb.test(word_dict), batch_size=128), - reader_dict={'word': 0, - 'label': 1}) + feeding={'word': 0, + 'label': 1}) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) # create trainer @@ -156,11 +151,11 @@ if __name__ == '__main__': update_equation=adam_optimizer) trainer.train( - reader=paddle.reader.batched( + reader=paddle.batch( paddle.reader.shuffle( lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000), batch_size=100), event_handler=event_handler, - reader_dict={'word': 0, - 'label': 1}, + feeding={'word': 0, + 'label': 1}, num_passes=10) diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index a5f59ec379..5b7506b152 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -80,13 +80,13 @@ def main(): update_equation=optimizer) # define data reader - reader_dict = { + feeding = { 'source_language_word': 0, 'target_language_word': 1, 'target_language_next_word': 2 } - wmt14_reader = paddle.reader.batched( + wmt14_reader = paddle.batch( paddle.reader.shuffle( train_reader("data/pre-wmt14/train/train"), buf_size=8192), batch_size=5) @@ -103,7 +103,7 @@ def main(): reader=wmt14_reader, event_handler=event_handler, num_passes=10000, - reader_dict=reader_dict) + feeding=feeding) if __name__ == '__main__': diff --git a/python/paddle/v2/data_feeder.py b/python/paddle/v2/data_feeder.py index b7465238be..ba77fecf21 100644 --- a/python/paddle/v2/data_feeder.py +++ b/python/paddle/v2/data_feeder.py @@ -14,11 +14,18 @@ from py_paddle import DataProviderConverter -import data_type +import paddle.trainer.PyDataProvider2 as pydp2 __all__ = ['DataFeeder'] +def default_feeding_map(data_types): + reader_dict = dict() + for i, tp in enumerate(data_types): + reader_dict[tp[0]] = i + return reader_dict + + class DataFeeder(DataProviderConverter): """ DataFeeder converts the data returned by paddle.reader into a data structure @@ -60,16 +67,21 @@ class DataFeeder(DataProviderConverter): :type data_types: list :param reader_dict: A dictionary to specify the position of each data in the input data. 
- :type reader_dict: dict + :type feeding: dict """ - def __init__(self, data_types, reader_dict): + def __init__(self, data_types, feeding=None): self.input_names = [] input_types = [] - self.reader_dict = reader_dict + if feeding is None: + feeding = default_feeding_map(data_types) + + self.feeding = feeding for each in data_types: self.input_names.append(each[0]) - assert isinstance(each[1], data_type.InputType) + if not isinstance(each[1], pydp2.InputType): + raise TypeError("second item in each data_type should be an " + "InputType") input_types.append(each[1]) DataProviderConverter.__init__(self, input_types) @@ -90,7 +102,7 @@ class DataFeeder(DataProviderConverter): for each in data: reorder = [] for name in self.input_names: - reorder.append(each[self.reader_dict[name]]) + reorder.append(each[self.feeding[name]]) retv.append(reorder) return retv diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 476fd3fa45..7d889bce7f 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -21,10 +21,8 @@ class Inference(object): self.__gradient_machine__ = gm self.__data_types__ = topo.data_type() - def iter_infer(self, reader, reader_dict=None): - if reader_dict is None: - reader_dict = self.default_reader_dict() - feeder = DataFeeder(self.__data_types__, reader_dict) + def iter_infer(self, reader, feeding=None): + feeder = DataFeeder(self.__data_types__, feeding) self.__gradient_machine__.start() for data_batch in reader(): yield self.__gradient_machine__.forwardTest(feeder(data_batch)) @@ -47,13 +45,7 @@ class Inference(object): else: return retv - def default_reader_dict(self): - reader_dict = dict() - for i, tp in enumerate(self.__data_types__): - reader_dict[tp[0]] = i - return reader_dict - -def infer(output, parameters, reader, reader_dict=None, field='value'): +def infer(output, parameters, reader, feeding=None, field='value'): inferer = Inference(output=output, parameters=parameters) - return inferer.infer(field=field, reader=reader, reader_dict=reader_dict) + return inferer.infer(field=field, reader=reader, feeding=feeding) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 187abaf9a3..7bd3e2c565 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -61,7 +61,7 @@ class SGD(object): self.__gradient_machine__.randParameters() parameters.append_gradient_machine(gm) - def train(self, reader, num_passes=1, event_handler=None, reader_dict=None): + def train(self, reader, num_passes=1, event_handler=None, feeding=None): """ Training method. Will train num_passes of input data. @@ -70,14 +70,13 @@ class SGD(object): :param event_handler: Event handler. A method will be invoked when event occurred. :type event_handler: (BaseEvent) => None + :param feeding: Feeding is a map of neural network input name and array + index that reader returns. 
+ :type feeding: dict :return: """ if event_handler is None: event_handler = default_event_handler - - if reader_dict is None: - reader_dict = self.default_reader_dict() - __check_train_args__(**locals()) updater = self.__optimizer__.create_local_updater() @@ -89,9 +88,7 @@ class SGD(object): pass_evaluator = self.__gradient_machine__.makeEvaluator() assert isinstance(pass_evaluator, api.Evaluator) out_args = api.Arguments.createArguments(0) - - feeder = DataFeeder(self.__data_types__, reader_dict) - + feeder = DataFeeder(self.__data_types__, feeding) for pass_id in xrange(num_passes): event_handler(v2_event.BeginPass(pass_id)) pass_evaluator.start() @@ -125,17 +122,8 @@ class SGD(object): event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) self.__gradient_machine__.finish() - def default_reader_dict(self): - reader_dict = dict() - for i, tp in enumerate(self.__data_types__): - reader_dict[tp[0]] = i - return reader_dict - - def test(self, reader, reader_dict=None): - if reader_dict is None: - reader_dict = self.default_reader_dict() - - feeder = DataFeeder(self.__data_types__, reader_dict) + def test(self, reader, feeding=None): + feeder = DataFeeder(self.__data_types__, feeding) evaluator = self.__gradient_machine__.makeEvaluator() out_args = api.Arguments.createArguments(0) evaluator.start() From 82437970594e331f63bed25c2f3ab42b413e68d9 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Mon, 6 Mar 2017 18:54:47 +0800 Subject: [PATCH 09/13] add relu in layer_math.py --- .../trainer_config_helpers/layer_math.py | 1 + .../tests/configs/math_ops.py | 3 +- .../tests/configs/protostr/math_ops.protostr | 32 ++++++++++++++----- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layer_math.py b/python/paddle/trainer_config_helpers/layer_math.py index 2d9e36f2b0..544b443825 100644 --- a/python/paddle/trainer_config_helpers/layer_math.py +++ b/python/paddle/trainer_config_helpers/layer_math.py @@ -39,6 +39,7 @@ register_unary_math_op('abs', act.AbsActivation()) register_unary_math_op('sigmoid', act.SigmoidActivation()) register_unary_math_op('tanh', act.TanhActivation()) register_unary_math_op('square', act.SquareActivation()) +register_unary_math_op('relu', act.ReluActivation()) def add(layeroutput, other): diff --git a/python/paddle/trainer_config_helpers/tests/configs/math_ops.py b/python/paddle/trainer_config_helpers/tests/configs/math_ops.py index 3331c10d64..24c901c8ee 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/math_ops.py +++ b/python/paddle/trainer_config_helpers/tests/configs/math_ops.py @@ -7,8 +7,9 @@ x = layer_math.exp(x) x = layer_math.log(x) x = layer_math.abs(x) x = layer_math.sigmoid(x) +x = layer_math.tanh(x) x = layer_math.square(x) -x = layer_math.square(x) +x = layer_math.relu(x) y = 1 + x y = y + 1 y = x + y diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/math_ops.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/math_ops.protostr index da8da1b541..9b8a2ad968 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/math_ops.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/math_ops.protostr @@ -65,13 +65,28 @@ layers { } } } +layers { + name: "__tanh_0__" + type: "mixed" + size: 100 + active_type: "tanh" + inputs { + input_layer_name: "__sigmoid_0__" + proj_conf { + type: "identity" + name: "___tanh_0__.w0" + input_size: 100 + output_size: 100 + } + } +} layers { name: "__square_0__" type: "mixed" size: 100 
active_type: "square" inputs { - input_layer_name: "__sigmoid_0__" + input_layer_name: "__tanh_0__" proj_conf { type: "identity" name: "___square_0__.w0" @@ -81,15 +96,15 @@ layers { } } layers { - name: "__square_1__" + name: "__relu_0__" type: "mixed" size: 100 - active_type: "square" + active_type: "relu" inputs { input_layer_name: "__square_0__" proj_conf { type: "identity" - name: "___square_1__.w0" + name: "___relu_0__.w0" input_size: 100 output_size: 100 } @@ -101,7 +116,7 @@ layers { size: 100 active_type: "" inputs { - input_layer_name: "__square_1__" + input_layer_name: "__relu_0__" } slope: 1.0 intercept: 1 @@ -123,7 +138,7 @@ layers { size: 100 active_type: "" inputs { - input_layer_name: "__square_1__" + input_layer_name: "__relu_0__" proj_conf { type: "identity" name: "___mixed_0__.w0" @@ -147,7 +162,7 @@ layers { size: 100 active_type: "" inputs { - input_layer_name: "__square_1__" + input_layer_name: "__relu_0__" } slope: -1.0 intercept: 0.0 @@ -339,8 +354,9 @@ sub_models { layer_names: "__log_0__" layer_names: "__abs_0__" layer_names: "__sigmoid_0__" + layer_names: "__tanh_0__" layer_names: "__square_0__" - layer_names: "__square_1__" + layer_names: "__relu_0__" layer_names: "__slope_intercept_layer_0__" layer_names: "__slope_intercept_layer_1__" layer_names: "__mixed_0__" From 96a2e44aa3d0e1c566c08bee01e284f59277ef3c Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 6 Mar 2017 20:46:25 +0800 Subject: [PATCH 10/13] optimize seq2seq-dataset --- demo/sentiment/preprocess.py | 166 +++++++++++++++++++++++- python/paddle/v2/dataset/wmt14.py | 149 ++++++++++----------- python/paddle/v2/dataset/wmt14_util.py | 172 ------------------------- 3 files changed, 229 insertions(+), 258 deletions(-) delete mode 100644 python/paddle/v2/dataset/wmt14_util.py diff --git a/demo/sentiment/preprocess.py b/demo/sentiment/preprocess.py index 29b3682b74..59c3b5febe 100755 --- a/demo/sentiment/preprocess.py +++ b/demo/sentiment/preprocess.py @@ -12,22 +12,176 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import sys -import random import operator -import numpy as np -from subprocess import Popen, PIPE -from os.path import join as join_path from optparse import OptionParser +from os.path import join as join_path +from subprocess import Popen, PIPE +import numpy as np from paddle.utils.preprocess_util import * +from paddle.utils.preprocess_util import save_list, DatasetCreater """ Usage: run following command to show help message. python preprocess.py -h """ +class SeqToSeqDatasetCreater(DatasetCreater): + """ + A class to process data for sequence to sequence application. 
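# --- Editor's aside on the layer_math change in the previous patch; this is an
# illustrative sketch only, not part of this diff and not PaddlePaddle's actual
# implementation. register_unary_math_op exposes an activation as a helper such
# as layer_math.relu; the registration idea, reduced to a toy with a stand-in
# relu on plain numbers, looks roughly like this:
import sys

def register_unary_op(name, fn):
    # publish `fn` under `name` at module level, the way register_unary_math_op
    # publishes relu/tanh/square/... on the layer_math module
    setattr(sys.modules[__name__], name, fn)

register_unary_op('relu', lambda v: v if v > 0 else 0.0)
print(relu(-3.0))   # -> 0.0
print(relu(2.5))    # -> 2.5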
+ """ + + def __init__(self, data_path, output_path): + """ + data_path: the path to store the train data, test data and gen data + output_path: the path to store the processed dataset + """ + DatasetCreater.__init__(self, data_path) + self.gen_dir_name = 'gen' + self.gen_list_name = 'gen.list' + self.output_path = output_path + + def concat_file(self, file_path, file1, file2, output_path, output): + """ + Concat file1 and file2 to be one output file + The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 + file_path: the path to store file1 and file2 + output_path: the path to store output file + """ + file1 = os.path.join(file_path, file1) + file2 = os.path.join(file_path, file2) + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) + + def cat_file(self, dir_path, suffix, output_path, output): + """ + Cat all the files in dir_path with suffix to be one output file + dir_path: the base directory to store input file + suffix: suffix of file name + output_path: the path to store output file + """ + cmd = 'cat ' + file_list = os.listdir(dir_path) + file_list.sort() + for file in file_list: + if file.endswith(suffix): + cmd += os.path.join(dir_path, file) + ' ' + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system(cmd + '> ' + output) + + def build_dict(self, file_path, dict_path, dict_size=-1): + """ + Create the dictionary for the file, Note that + 1. Valid characters include all printable characters + 2. There is distinction between uppercase and lowercase letters + 3. There is 3 special token: + : the start of a sequence + : the end of a sequence + : a word not included in dictionary + file_path: the path to store file + dict_path: the path to store dictionary + dict_size: word count of dictionary + if is -1, dictionary will contains all the words in file + """ + if not os.path.exists(dict_path): + dictory = dict() + with open(file_path, "r") as fdata: + for line in fdata: + line = line.split('\t') + for line_split in line: + words = line_split.strip().split() + for word in words: + if word not in dictory: + dictory[word] = 1 + else: + dictory[word] += 1 + output = open(dict_path, "w+") + output.write('\n\n\n') + count = 3 + for key, value in sorted( + dictory.items(), key=lambda d: d[1], reverse=True): + output.write(key + "\n") + count += 1 + if count == dict_size: + break + self.dict_size = count + + def create_dataset(self, + dict_size=-1, + mergeDict=False, + suffixes=['.src', '.trg']): + """ + Create seqToseq dataset + """ + # dataset_list and dir_list has one-to-one relationship + train_dataset = os.path.join(self.data_path, self.train_dir_name) + test_dataset = os.path.join(self.data_path, self.test_dir_name) + gen_dataset = os.path.join(self.data_path, self.gen_dir_name) + dataset_list = [train_dataset, test_dataset, gen_dataset] + + train_dir = os.path.join(self.output_path, self.train_dir_name) + test_dir = os.path.join(self.output_path, self.test_dir_name) + gen_dir = os.path.join(self.output_path, self.gen_dir_name) + dir_list = [train_dir, test_dir, gen_dir] + + # create directory + for dir in dir_list: + if not os.path.exists(dir): + os.makedirs(dir) + + # checkout dataset should be parallel corpora + suffix_len = len(suffixes[0]) + for dataset in dataset_list: + file_list = os.listdir(dataset) + if len(file_list) % 2 == 1: + raise RuntimeError("dataset should be parallel corpora") + file_list.sort() + for i in range(0, len(file_list), 
2): + if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: + raise RuntimeError( + "source and target file name should be equal") + + # cat all the files with the same suffix in dataset + for suffix in suffixes: + for dataset in dataset_list: + outname = os.path.basename(dataset) + suffix + self.cat_file(dataset, suffix, dataset, outname) + + # concat parallel corpora and create file.list + print 'concat parallel corpora for dataset' + id = 0 + list = ['train.list', 'test.list', 'gen.list'] + for dataset in dataset_list: + outname = os.path.basename(dataset) + self.concat_file(dataset, outname + suffixes[0], + outname + suffixes[1], dir_list[id], outname) + save_list([os.path.join(dir_list[id], outname)], + os.path.join(self.output_path, list[id])) + id += 1 + + # build dictionary for train data + dict = ['src.dict', 'trg.dict'] + dict_path = [ + os.path.join(self.output_path, dict[0]), + os.path.join(self.output_path, dict[1]) + ] + if mergeDict: + outname = os.path.join(train_dir, train_dataset.split('/')[-1]) + print 'build src dictionary for train data' + self.build_dict(outname, dict_path[0], dict_size) + print 'build trg dictionary for train data' + os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) + else: + outname = os.path.join(train_dataset, self.train_dir_name) + for id in range(0, 2): + suffix = suffixes[id] + print 'build ' + suffix[1:] + ' dictionary for train data' + self.build_dict(outname + suffix, dict_path[id], dict_size) + print 'dictionary size is', self.dict_size + + def save_dict(dict, filename, is_reverse=True): """ Save dictionary into file. diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index 254f07c8dd..f8637c0a00 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -14,103 +14,92 @@ """ wmt14 dataset """ -import os -import os.path import tarfile import paddle.v2.dataset.common -from wmt14_util import SeqToSeqDatasetCreater __all__ = ['train', 'test', 'build_dict'] URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' # this is a small set of data for test. The original data is too large and will be add later. -URL_TRAIN = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' -MD5_TRAIN = '7373473f86016f1f48037c9c340a2d5b' +URL_TRAIN = 'http://localhost:8989/wmt14.tgz' +MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' START = "" END = "" UNK = "" UNK_IDX = 2 -DEFAULT_DATA_DIR = "./data" -ORIGIN_DATA_DIR = "wmt14" -INNER_DATA_DIR = "pre-wmt14" -SRC_DICT = INNER_DATA_DIR + "/src.dict" -TRG_DICT = INNER_DATA_DIR + "/trg.dict" -TRAIN_FILE = INNER_DATA_DIR + "/train/train" - - -def __process_data__(data_path, dict_size=None): - downloaded_data = os.path.join(data_path, ORIGIN_DATA_DIR) - if not os.path.exists(downloaded_data): - # 1. download and extract tgz. - with tarfile.open( - paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', - MD5_TRAIN)) as tf: - tf.extractall(data_path) - - # 2. process data file to intermediate format. 
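# --- Editor's sketch, not part of the patch: the rewrite below drops this
# intermediate preprocessing step and instead streams parallel sentences
# straight out of the downloaded archive through a reader_creator closure.
# A self-contained toy version of that pattern follows; the archive path and
# member name are hypothetical, and the real readers additionally map words
# through src.dict/trg.dict with reserved start/end/unknown tokens (assumed
# to be '<s>', '<e>' and '<unk>').
import tarfile

def toy_reader_creator(tar_path, member_suffix):
    def reader():
        # re-open the tar on every pass and yield one (source, target) pair
        # of token lists per tab-separated line
        with tarfile.open(tar_path, mode='r') as f:
            names = [m.name for m in f if m.name.endswith(member_suffix)]
            for name in names:
                for line in f.extractfile(name):
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        yield parts[0].split(), parts[1].split()
    return reader

# hypothetical usage:
#   train_reader = toy_reader_creator('wmt14.tgz', 'train/train')
#   for src_words, trg_words in train_reader():
#       ...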
- processed_data = os.path.join(data_path, INNER_DATA_DIR) - if not os.path.exists(processed_data): - dict_size = dict_size or -1 - data_creator = SeqToSeqDatasetCreater(downloaded_data, processed_data) - data_creator.create_dataset(dict_size, mergeDict=False) - - -def __read_to_dict__(dict_path, count): - with open(dict_path, "r") as fin: + +def __read_to_dict__(tar_file, dict_size): + def __to_dict__(fd, size): out_dict = dict() - for line_count, line in enumerate(fin): - if line_count <= count: + for line_count, line in enumerate(fd): + if line_count < size: out_dict[line.strip()] = line_count else: break - return out_dict - - -def __reader__(file_name, src_dict, trg_dict): - with open(file_name, 'r') as f: - for line_count, line in enumerate(f): - line_split = line.strip().split('\t') - if len(line_split) != 2: - continue - src_seq = line_split[0] # one source sequence - src_words = src_seq.split() - src_ids = [ - src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END] + return out_dict + + with tarfile.open(tar_file, mode='r') as f: + names = [ + each_item.name for each_item in f + if each_item.name.endswith("src.dict") + ] + assert len(names) == 1 + src_dict = __to_dict__(f.extractfile(names[0]), dict_size) + names = [ + each_item.name for each_item in f + if each_item.name.endswith("trg.dict") + ] + assert len(names) == 1 + trg_dict = __to_dict__(f.extractfile(names[0]), dict_size) + return src_dict, trg_dict + + +def reader_creator(tar_file, file_name, dict_size): + def reader(): + src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) + with tarfile.open(tar_file, mode='r') as f: + names = [ + each_item.name for each_item in f + if each_item.name.endswith(file_name) ] - - trg_seq = line_split[1] # one target sequence - trg_words = trg_seq.split() - trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words] - - # remove sequence whose length > 80 in training mode - if len(src_ids) > 80 or len(trg_ids) > 80: - continue - trg_ids_next = trg_ids + [trg_dict[END]] - trg_ids = [trg_dict[START]] + trg_ids - - yield src_ids, trg_ids, trg_ids_next - - -def train(data_dir=None, dict_size=None): - data_dir = data_dir or DEFAULT_DATA_DIR - __process_data__(data_dir, dict_size) - src_lang_dict = os.path.join(data_dir, SRC_DICT) - trg_lang_dict = os.path.join(data_dir, TRG_DICT) - train_file_name = os.path.join(data_dir, TRAIN_FILE) - - default_dict_size = len(open(src_lang_dict, "r").readlines()) - - if dict_size > default_dict_size: - raise ValueError("dict_dim should not be larger then the " - "length of word dict") - - real_dict_dim = dict_size or default_dict_size - - src_dict = __read_to_dict__(src_lang_dict, real_dict_dim) - trg_dict = __read_to_dict__(trg_lang_dict, real_dict_dim) - - return lambda: __reader__(train_file_name, src_dict, trg_dict) + for name in names: + for line in f.extractfile(name): + line_split = line.strip().split('\t') + if len(line_split) != 2: + continue + src_seq = line_split[0] # one source sequence + src_words = src_seq.split() + src_ids = [ + src_dict.get(w, UNK_IDX) + for w in [START] + src_words + [END] + ] + + trg_seq = line_split[1] # one target sequence + trg_words = trg_seq.split() + trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words] + + # remove sequence whose length > 80 in training mode + if len(src_ids) > 80 or len(trg_ids) > 80: + continue + trg_ids_next = trg_ids + [trg_dict[END]] + trg_ids = [trg_dict[START]] + trg_ids + + yield src_ids, trg_ids, trg_ids_next + + return reader + + +def train(dict_size): + return reader_creator( + 
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN), + 'train/train', dict_size) + + +def test(dict_size): + return reader_creator( + paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN), + 'test/test', dict_size) diff --git a/python/paddle/v2/dataset/wmt14_util.py b/python/paddle/v2/dataset/wmt14_util.py deleted file mode 100644 index 0d72389164..0000000000 --- a/python/paddle/v2/dataset/wmt14_util.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -from paddle.utils.preprocess_util import save_list, DatasetCreater - - -class SeqToSeqDatasetCreater(DatasetCreater): - """ - A class to process data for sequence to sequence application. - """ - - def __init__(self, data_path, output_path): - """ - data_path: the path to store the train data, test data and gen data - output_path: the path to store the processed dataset - """ - DatasetCreater.__init__(self, data_path) - self.gen_dir_name = 'gen' - self.gen_list_name = 'gen.list' - self.output_path = output_path - - def concat_file(self, file_path, file1, file2, output_path, output): - """ - Concat file1 and file2 to be one output file - The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 - file_path: the path to store file1 and file2 - output_path: the path to store output file - """ - file1 = os.path.join(file_path, file1) - file2 = os.path.join(file_path, file2) - output = os.path.join(output_path, output) - if not os.path.exists(output): - os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) - - def cat_file(self, dir_path, suffix, output_path, output): - """ - Cat all the files in dir_path with suffix to be one output file - dir_path: the base directory to store input file - suffix: suffix of file name - output_path: the path to store output file - """ - cmd = 'cat ' - file_list = os.listdir(dir_path) - file_list.sort() - for file in file_list: - if file.endswith(suffix): - cmd += os.path.join(dir_path, file) + ' ' - output = os.path.join(output_path, output) - if not os.path.exists(output): - os.system(cmd + '> ' + output) - - def build_dict(self, file_path, dict_path, dict_size=-1): - """ - Create the dictionary for the file, Note that - 1. Valid characters include all printable characters - 2. There is distinction between uppercase and lowercase letters - 3. 
There is 3 special token: - : the start of a sequence - : the end of a sequence - : a word not included in dictionary - file_path: the path to store file - dict_path: the path to store dictionary - dict_size: word count of dictionary - if is -1, dictionary will contains all the words in file - """ - if not os.path.exists(dict_path): - dictory = dict() - with open(file_path, "r") as fdata: - for line in fdata: - line = line.split('\t') - for line_split in line: - words = line_split.strip().split() - for word in words: - if word not in dictory: - dictory[word] = 1 - else: - dictory[word] += 1 - output = open(dict_path, "w+") - output.write('\n\n\n') - count = 3 - for key, value in sorted( - dictory.items(), key=lambda d: d[1], reverse=True): - output.write(key + "\n") - count += 1 - if count == dict_size: - break - self.dict_size = count - - def create_dataset(self, - dict_size=-1, - mergeDict=False, - suffixes=['.src', '.trg']): - """ - Create seqToseq dataset - """ - # dataset_list and dir_list has one-to-one relationship - train_dataset = os.path.join(self.data_path, self.train_dir_name) - test_dataset = os.path.join(self.data_path, self.test_dir_name) - gen_dataset = os.path.join(self.data_path, self.gen_dir_name) - dataset_list = [train_dataset, test_dataset, gen_dataset] - - train_dir = os.path.join(self.output_path, self.train_dir_name) - test_dir = os.path.join(self.output_path, self.test_dir_name) - gen_dir = os.path.join(self.output_path, self.gen_dir_name) - dir_list = [train_dir, test_dir, gen_dir] - - # create directory - for dir in dir_list: - if not os.path.exists(dir): - os.makedirs(dir) - - # checkout dataset should be parallel corpora - suffix_len = len(suffixes[0]) - for dataset in dataset_list: - file_list = os.listdir(dataset) - if len(file_list) % 2 == 1: - raise RuntimeError("dataset should be parallel corpora") - file_list.sort() - for i in range(0, len(file_list), 2): - if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: - raise RuntimeError( - "source and target file name should be equal") - - # cat all the files with the same suffix in dataset - for suffix in suffixes: - for dataset in dataset_list: - outname = os.path.basename(dataset) + suffix - self.cat_file(dataset, suffix, dataset, outname) - - # concat parallel corpora and create file.list - print 'concat parallel corpora for dataset' - id = 0 - list = ['train.list', 'test.list', 'gen.list'] - for dataset in dataset_list: - outname = os.path.basename(dataset) - self.concat_file(dataset, outname + suffixes[0], - outname + suffixes[1], dir_list[id], outname) - save_list([os.path.join(dir_list[id], outname)], - os.path.join(self.output_path, list[id])) - id += 1 - - # build dictionary for train data - dict = ['src.dict', 'trg.dict'] - dict_path = [ - os.path.join(self.output_path, dict[0]), - os.path.join(self.output_path, dict[1]) - ] - if mergeDict: - outname = os.path.join(train_dir, train_dataset.split('/')[-1]) - print 'build src dictionary for train data' - self.build_dict(outname, dict_path[0], dict_size) - print 'build trg dictionary for train data' - os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) - else: - outname = os.path.join(train_dataset, self.train_dir_name) - for id in range(0, 2): - suffix = suffixes[id] - print 'build ' + suffix[1:] + ' dictionary for train data' - self.build_dict(outname + suffix, dict_path[id], dict_size) - print 'dictionary size is', self.dict_size From 2ad9fd646f1cf612fb721d869f4b4c93e7aa5c2a Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 6 Mar 
2017 20:52:18 +0800 Subject: [PATCH 11/13] revert preprocess.py --- demo/sentiment/preprocess.py | 166 ++--------------------------------- demo/seqToseq/preprocess.py | 160 ++++++++++++++++++++++++++++++++- 2 files changed, 165 insertions(+), 161 deletions(-) diff --git a/demo/sentiment/preprocess.py b/demo/sentiment/preprocess.py index 59c3b5febe..29b3682b74 100755 --- a/demo/sentiment/preprocess.py +++ b/demo/sentiment/preprocess.py @@ -12,176 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys +import random import operator -from optparse import OptionParser -from os.path import join as join_path +import numpy as np from subprocess import Popen, PIPE +from os.path import join as join_path +from optparse import OptionParser -import numpy as np from paddle.utils.preprocess_util import * -from paddle.utils.preprocess_util import save_list, DatasetCreater """ Usage: run following command to show help message. python preprocess.py -h """ -class SeqToSeqDatasetCreater(DatasetCreater): - """ - A class to process data for sequence to sequence application. - """ - - def __init__(self, data_path, output_path): - """ - data_path: the path to store the train data, test data and gen data - output_path: the path to store the processed dataset - """ - DatasetCreater.__init__(self, data_path) - self.gen_dir_name = 'gen' - self.gen_list_name = 'gen.list' - self.output_path = output_path - - def concat_file(self, file_path, file1, file2, output_path, output): - """ - Concat file1 and file2 to be one output file - The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 - file_path: the path to store file1 and file2 - output_path: the path to store output file - """ - file1 = os.path.join(file_path, file1) - file2 = os.path.join(file_path, file2) - output = os.path.join(output_path, output) - if not os.path.exists(output): - os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) - - def cat_file(self, dir_path, suffix, output_path, output): - """ - Cat all the files in dir_path with suffix to be one output file - dir_path: the base directory to store input file - suffix: suffix of file name - output_path: the path to store output file - """ - cmd = 'cat ' - file_list = os.listdir(dir_path) - file_list.sort() - for file in file_list: - if file.endswith(suffix): - cmd += os.path.join(dir_path, file) + ' ' - output = os.path.join(output_path, output) - if not os.path.exists(output): - os.system(cmd + '> ' + output) - - def build_dict(self, file_path, dict_path, dict_size=-1): - """ - Create the dictionary for the file, Note that - 1. Valid characters include all printable characters - 2. There is distinction between uppercase and lowercase letters - 3. 
There is 3 special token: - : the start of a sequence - : the end of a sequence - : a word not included in dictionary - file_path: the path to store file - dict_path: the path to store dictionary - dict_size: word count of dictionary - if is -1, dictionary will contains all the words in file - """ - if not os.path.exists(dict_path): - dictory = dict() - with open(file_path, "r") as fdata: - for line in fdata: - line = line.split('\t') - for line_split in line: - words = line_split.strip().split() - for word in words: - if word not in dictory: - dictory[word] = 1 - else: - dictory[word] += 1 - output = open(dict_path, "w+") - output.write('\n\n\n') - count = 3 - for key, value in sorted( - dictory.items(), key=lambda d: d[1], reverse=True): - output.write(key + "\n") - count += 1 - if count == dict_size: - break - self.dict_size = count - - def create_dataset(self, - dict_size=-1, - mergeDict=False, - suffixes=['.src', '.trg']): - """ - Create seqToseq dataset - """ - # dataset_list and dir_list has one-to-one relationship - train_dataset = os.path.join(self.data_path, self.train_dir_name) - test_dataset = os.path.join(self.data_path, self.test_dir_name) - gen_dataset = os.path.join(self.data_path, self.gen_dir_name) - dataset_list = [train_dataset, test_dataset, gen_dataset] - - train_dir = os.path.join(self.output_path, self.train_dir_name) - test_dir = os.path.join(self.output_path, self.test_dir_name) - gen_dir = os.path.join(self.output_path, self.gen_dir_name) - dir_list = [train_dir, test_dir, gen_dir] - - # create directory - for dir in dir_list: - if not os.path.exists(dir): - os.makedirs(dir) - - # checkout dataset should be parallel corpora - suffix_len = len(suffixes[0]) - for dataset in dataset_list: - file_list = os.listdir(dataset) - if len(file_list) % 2 == 1: - raise RuntimeError("dataset should be parallel corpora") - file_list.sort() - for i in range(0, len(file_list), 2): - if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: - raise RuntimeError( - "source and target file name should be equal") - - # cat all the files with the same suffix in dataset - for suffix in suffixes: - for dataset in dataset_list: - outname = os.path.basename(dataset) + suffix - self.cat_file(dataset, suffix, dataset, outname) - - # concat parallel corpora and create file.list - print 'concat parallel corpora for dataset' - id = 0 - list = ['train.list', 'test.list', 'gen.list'] - for dataset in dataset_list: - outname = os.path.basename(dataset) - self.concat_file(dataset, outname + suffixes[0], - outname + suffixes[1], dir_list[id], outname) - save_list([os.path.join(dir_list[id], outname)], - os.path.join(self.output_path, list[id])) - id += 1 - - # build dictionary for train data - dict = ['src.dict', 'trg.dict'] - dict_path = [ - os.path.join(self.output_path, dict[0]), - os.path.join(self.output_path, dict[1]) - ] - if mergeDict: - outname = os.path.join(train_dir, train_dataset.split('/')[-1]) - print 'build src dictionary for train data' - self.build_dict(outname, dict_path[0], dict_size) - print 'build trg dictionary for train data' - os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) - else: - outname = os.path.join(train_dataset, self.train_dir_name) - for id in range(0, 2): - suffix = suffixes[id] - print 'build ' + suffix[1:] + ' dictionary for train data' - self.build_dict(outname + suffix, dict_path[id], dict_size) - print 'dictionary size is', self.dict_size - - def save_dict(dict, filename, is_reverse=True): """ Save dictionary into file. 
diff --git a/demo/seqToseq/preprocess.py b/demo/seqToseq/preprocess.py index afa7bd5e0f..03f371331a 100755 --- a/demo/seqToseq/preprocess.py +++ b/demo/seqToseq/preprocess.py @@ -23,9 +23,167 @@ Options: -m --mergeDict merge source and target dictionary """ import os +import sys + +import string from optparse import OptionParser +from paddle.utils.preprocess_util import save_list, DatasetCreater + + +class SeqToSeqDatasetCreater(DatasetCreater): + """ + A class to process data for sequence to sequence application. + """ + + def __init__(self, data_path, output_path): + """ + data_path: the path to store the train data, test data and gen data + output_path: the path to store the processed dataset + """ + DatasetCreater.__init__(self, data_path) + self.gen_dir_name = 'gen' + self.gen_list_name = 'gen.list' + self.output_path = output_path + + def concat_file(self, file_path, file1, file2, output_path, output): + """ + Concat file1 and file2 to be one output file + The i-th line of output = i-th line of file1 + '\t' + i-th line of file2 + file_path: the path to store file1 and file2 + output_path: the path to store output file + """ + file1 = os.path.join(file_path, file1) + file2 = os.path.join(file_path, file2) + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system('paste ' + file1 + ' ' + file2 + ' > ' + output) + + def cat_file(self, dir_path, suffix, output_path, output): + """ + Cat all the files in dir_path with suffix to be one output file + dir_path: the base directory to store input file + suffix: suffix of file name + output_path: the path to store output file + """ + cmd = 'cat ' + file_list = os.listdir(dir_path) + file_list.sort() + for file in file_list: + if file.endswith(suffix): + cmd += os.path.join(dir_path, file) + ' ' + output = os.path.join(output_path, output) + if not os.path.exists(output): + os.system(cmd + '> ' + output) + + def build_dict(self, file_path, dict_path, dict_size=-1): + """ + Create the dictionary for the file, Note that + 1. Valid characters include all printable characters + 2. There is distinction between uppercase and lowercase letters + 3. 
There is 3 special token: + : the start of a sequence + : the end of a sequence + : a word not included in dictionary + file_path: the path to store file + dict_path: the path to store dictionary + dict_size: word count of dictionary + if is -1, dictionary will contains all the words in file + """ + if not os.path.exists(dict_path): + dictory = dict() + with open(file_path, "r") as fdata: + for line in fdata: + line = line.split('\t') + for line_split in line: + words = line_split.strip().split() + for word in words: + if word not in dictory: + dictory[word] = 1 + else: + dictory[word] += 1 + output = open(dict_path, "w+") + output.write('\n\n\n') + count = 3 + for key, value in sorted( + dictory.items(), key=lambda d: d[1], reverse=True): + output.write(key + "\n") + count += 1 + if count == dict_size: + break + self.dict_size = count + + def create_dataset(self, + dict_size=-1, + mergeDict=False, + suffixes=['.src', '.trg']): + """ + Create seqToseq dataset + """ + # dataset_list and dir_list has one-to-one relationship + train_dataset = os.path.join(self.data_path, self.train_dir_name) + test_dataset = os.path.join(self.data_path, self.test_dir_name) + gen_dataset = os.path.join(self.data_path, self.gen_dir_name) + dataset_list = [train_dataset, test_dataset, gen_dataset] + + train_dir = os.path.join(self.output_path, self.train_dir_name) + test_dir = os.path.join(self.output_path, self.test_dir_name) + gen_dir = os.path.join(self.output_path, self.gen_dir_name) + dir_list = [train_dir, test_dir, gen_dir] + + # create directory + for dir in dir_list: + if not os.path.exists(dir): + os.mkdir(dir) + + # checkout dataset should be parallel corpora + suffix_len = len(suffixes[0]) + for dataset in dataset_list: + file_list = os.listdir(dataset) + if len(file_list) % 2 == 1: + raise RuntimeError("dataset should be parallel corpora") + file_list.sort() + for i in range(0, len(file_list), 2): + if file_list[i][:-suffix_len] != file_list[i + 1][:-suffix_len]: + raise RuntimeError( + "source and target file name should be equal") + + # cat all the files with the same suffix in dataset + for suffix in suffixes: + for dataset in dataset_list: + outname = os.path.basename(dataset) + suffix + self.cat_file(dataset, suffix, dataset, outname) + + # concat parallel corpora and create file.list + print 'concat parallel corpora for dataset' + id = 0 + list = ['train.list', 'test.list', 'gen.list'] + for dataset in dataset_list: + outname = os.path.basename(dataset) + self.concat_file(dataset, outname + suffixes[0], + outname + suffixes[1], dir_list[id], outname) + save_list([os.path.join(dir_list[id], outname)], + os.path.join(self.output_path, list[id])) + id += 1 -from paddle.v2.dataset.wmt14_util import SeqToSeqDatasetCreater + # build dictionary for train data + dict = ['src.dict', 'trg.dict'] + dict_path = [ + os.path.join(self.output_path, dict[0]), + os.path.join(self.output_path, dict[1]) + ] + if mergeDict: + outname = os.path.join(train_dir, train_dataset.split('/')[-1]) + print 'build src dictionary for train data' + self.build_dict(outname, dict_path[0], dict_size) + print 'build trg dictionary for train data' + os.system('cp ' + dict_path[0] + ' ' + dict_path[1]) + else: + outname = os.path.join(train_dataset, self.train_dir_name) + for id in range(0, 2): + suffix = suffixes[id] + print 'build ' + suffix[1:] + ' dictionary for train data' + self.build_dict(outname + suffix, dict_path[id], dict_size) + print 'dictionary size is', self.dict_size def main(): From 
1940f58f5393497c4fe7b58e66b695d272ab1e77 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 6 Mar 2017 20:59:29 +0800 Subject: [PATCH 12/13] add bos url --- python/paddle/v2/dataset/wmt14.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index f8637c0a00..f5a16d5147 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -23,7 +23,7 @@ __all__ = ['train', 'test', 'build_dict'] URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' # this is a small set of data for test. The original data is too large and will be add later. -URL_TRAIN = 'http://localhost:8989/wmt14.tgz' +URL_TRAIN = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz' MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6' START = "" From 929a390f6c14821be9b4dafd5d150225a23b8cf6 Mon Sep 17 00:00:00 2001 From: gangliao Date: Tue, 7 Mar 2017 11:24:32 +0800 Subject: [PATCH 13/13] Set VGG in image classification demo --- demo/image_classification/api_v2_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/demo/image_classification/api_v2_train.py b/demo/image_classification/api_v2_train.py index 7134fa61e8..53cffa6fb4 100644 --- a/demo/image_classification/api_v2_train.py +++ b/demo/image_classification/api_v2_train.py @@ -16,7 +16,7 @@ import sys import paddle.v2 as paddle -from api_v2_resnet import resnet_cifar10 +from api_v2_vgg import vgg_bn_drop def main(): @@ -31,9 +31,9 @@ def main(): # Add neural network config # option 1. resnet - net = resnet_cifar10(image, depth=32) + # net = resnet_cifar10(image, depth=32) # option 2. vgg - # net = vgg_bn_drop(image) + net = vgg_bn_drop(image) out = paddle.layer.fc(input=net, size=classdim,