Add distributed unit tests for text_classification/simnet-bow/ctr (#12812)
* add dist ut for text_classification
* add dist ut for text_classification
* add simnet bow unittest
* add dist ut for simnet bow
* add training data url for simnet bow
* add training data url for simnet bow
* modify simnet test_reader to train reader
* add test_dist_ctr
* test_dist_ctr can run now
* dense update is good
* add unit test for selected rows
* debug unit test
* fix dist sparse update problem
* Constant args at init
* optimize code
* simnet optimize
* fix DebugStringEx
* optimize sum_op.h
* add ScaleOpVarTypeInference
* clean code
* fix test_dist_transpiler.py
* code optimize
* modify delta
* fix sparse update bug
* dist test use one cpu
* update some data
* remove unused code
* add use cuda config
* unit test fix
* unit test fix
* unit test fix
* unit test fix
* dist_word2vec use CPU
* unit test fix
* unit test fix
* code clean
* code clean
* merge develop
* api spec update
* Revert: api spec update
* replace simnet data with fake
* replace simnet data with fake
* update dim
* add batch auc
* code clean
* code clean
* modify print to stderr
* update simnet delta -> 1e-5
* update RUN_STEP
* add use_reader_alloc
* add use_reader_alloc
* add use_reader_alloc
* modify delta
* add use_reader_alloc
* fix stderr write
* python3 compatibility test=develop
* python3 compatibility, test=develop
* Update dist_text_classification.py
* test=develop

revert-13637-optimize-opyreader
parent 1ab7b55162
commit 97cf1eb6d7
dist_ctr.py
@@ -0,0 +1,109 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import paddle
import paddle.fluid as fluid

import dist_ctr_reader
from test_dist_base import TestDistRunnerBase, runtime_main

IS_SPARSE = True

# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1


class TestDistCTR2x2(TestDistRunnerBase):
    def get_model(self, batch_size=2):
        dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta()

        # network definition
        dnn_data = fluid.layers.data(
            name="dnn_data",
            shape=[-1, 1],
            dtype="int64",
            lod_level=1,
            append_batch_size=False)
        lr_data = fluid.layers.data(
            name="lr_data",
            shape=[-1, 1],
            dtype="int64",
            lod_level=1,
            append_batch_size=False)
        label = fluid.layers.data(
            name="click",
            shape=[-1, 1],
            dtype="int64",
            lod_level=0,
            append_batch_size=False)

        # build dnn model
        dnn_layer_dims = [128, 64, 32, 1]
        dnn_embedding = fluid.layers.embedding(
            is_distributed=False,
            input=dnn_data,
            size=[dnn_input_dim, dnn_layer_dims[0]],
            param_attr=fluid.ParamAttr(
                name="deep_embedding",
                initializer=fluid.initializer.Constant(value=0.01)),
            is_sparse=IS_SPARSE)
        dnn_pool = fluid.layers.sequence_pool(
            input=dnn_embedding, pool_type="sum")
        dnn_out = dnn_pool
        for i, dim in enumerate(dnn_layer_dims[1:]):
            fc = fluid.layers.fc(
                input=dnn_out,
                size=dim,
                act="relu",
                param_attr=fluid.ParamAttr(
                    initializer=fluid.initializer.Constant(value=0.01)),
                name='dnn-fc-%d' % i)
            dnn_out = fc

        # build lr model
        lr_embedding = fluid.layers.embedding(
            is_distributed=False,
            input=lr_data,
            size=[lr_input_dim, 1],
            param_attr=fluid.ParamAttr(
                name="wide_embedding",
                initializer=fluid.initializer.Constant(value=0.01)),
            is_sparse=IS_SPARSE)
        lr_pool = fluid.layers.sequence_pool(
            input=lr_embedding, pool_type="sum")

        merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)

        predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
        acc = fluid.layers.accuracy(input=predict, label=label)
        auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
                                                              label=label)
        cost = fluid.layers.cross_entropy(input=predict, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        inference_program = paddle.fluid.default_main_program().clone()

        sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001)
        sgd_optimizer.minimize(avg_cost)

        dataset = dist_ctr_reader.Dataset()
        train_reader = paddle.batch(dataset.train(), batch_size=batch_size)
        test_reader = paddle.batch(dataset.test(), batch_size=batch_size)

        return inference_program, avg_cost, train_reader, test_reader, None, predict


if __name__ == "__main__":
    runtime_main(TestDistCTR2x2)
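The model above is a small wide-and-deep CTR network: a sparse deep embedding tower plus a one-dimensional wide (LR) embedding, concatenated into a softmax click predictor. For local debugging, a sketch like the following can exercise get_model() outside the distributed harness; the feed names come from the fluid.layers.data definitions above, while the executor/feeder boilerplate is standard fluid usage rather than anything mandated by TestDistRunnerBase.

# Hypothetical local smoke test for dist_ctr.py's network (illustrative only,
# not part of this PR).
import paddle.fluid as fluid

model = TestDistCTR2x2()
_, avg_cost, train_reader, _, _, _ = model.get_model(batch_size=2)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

feeder = fluid.DataFeeder(feed_list=["dnn_data", "lr_data", "click"],
                          place=place)
for batch in train_reader():
    loss, = exe.run(fluid.default_main_program(),
                    feed=feeder.feed(batch),
                    fetch_list=[avg_cost])
    print("first-batch loss:", loss)
    break  # one batch is enough for a smoke test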
dist_ctr_reader.py
@@ -0,0 +1,172 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import paddle
import tarfile

logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)

DATA_URL = "http://paddle-ctr-data.cdn.bcebos.com/avazu_ctr_data.tgz"
DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e"
"""
avazu_ctr_data/train.txt
avazu_ctr_data/infer.txt
avazu_ctr_data/test.txt
avazu_ctr_data/data.meta.txt
"""


def read_data(file_name):
    path = paddle.dataset.common.download(DATA_URL, "avazu_ctr_data", DATA_MD5)
    tar = tarfile.open(path, "r:gz")
    tar_info = None
    for member in tar.getmembers():
        if member.name.endswith(file_name):
            tar_info = member
    if tar_info is None:
        # Guard against a missing archive member; extractfile(None) would
        # otherwise fail with a confusing error below.
        raise ValueError("%s not found in %s" % (file_name, path))
    f = tar.extractfile(tar_info)
    ret_lines = [line.decode('utf-8') for line in f.readlines()]
    return ret_lines


class TaskMode(object):
    TRAIN_MODE = 0
    TEST_MODE = 1
    INFER_MODE = 2

    def __init__(self, mode):
        self.mode = mode

    def is_train(self):
        return self.mode == self.TRAIN_MODE

    def is_test(self):
        return self.mode == self.TEST_MODE

    def is_infer(self):
        return self.mode == self.INFER_MODE

    @staticmethod
    def create_train():
        return TaskMode(TaskMode.TRAIN_MODE)

    @staticmethod
    def create_test():
        return TaskMode(TaskMode.TEST_MODE)

    @staticmethod
    def create_infer():
        return TaskMode(TaskMode.INFER_MODE)


class ModelType(object):
    CLASSIFICATION = 0
    REGRESSION = 1

    def __init__(self, mode):
        self.mode = mode

    def is_classification(self):
        return self.mode == self.CLASSIFICATION

    def is_regression(self):
        return self.mode == self.REGRESSION

    @staticmethod
    def create_classification():
        return ModelType(ModelType.CLASSIFICATION)

    @staticmethod
    def create_regression():
        return ModelType(ModelType.REGRESSION)


def load_dnn_input_record(sent):
    return list(map(int, sent.split()))


def load_lr_input_record(sent):
    # Each token looks like "feature_id:value"; only the id is kept.
    res = []
    for pair in [x.split(':') for x in sent.split()]:
        res.append(int(pair[0]))
    return res


feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2}


class Dataset(object):
    def train(self):
        '''
        Load trainset.
        '''
        file_name = "train.txt"
        logger.info("load trainset from %s" % file_name)
        mode = TaskMode.create_train()
        return self._parse_creator(file_name, mode)

    def test(self):
        '''
        Load testset.
        '''
        file_name = "test.txt"
        logger.info("load testset from %s" % file_name)
        mode = TaskMode.create_test()
        return self._parse_creator(file_name, mode)

    def infer(self):
        '''
        Load infer set.
        '''
        file_name = "infer.txt"
        logger.info("load inferset from %s" % file_name)
        mode = TaskMode.create_infer()
        return self._parse_creator(file_name, mode)

    def _parse_creator(self, file_name, mode):
        '''
        Parse dataset.
        '''

        def _parse():
            data = read_data(file_name)
            for line_id, line in enumerate(data):
                fs = line.strip().split('\t')
                dnn_input = load_dnn_input_record(fs[0])
                lr_input = load_lr_input_record(fs[1])
                if not mode.is_infer():
                    click = int(fs[2])
                    yield [dnn_input, lr_input, click]
                else:
                    yield [dnn_input, lr_input]

        return _parse


def load_data_meta():
    '''
    Load data meta info from path, return (dnn_input_dim, lr_input_dim).
    '''
    lines = read_data('data.meta.txt')
    err_info = "wrong meta format"
    assert len(lines) == 2, err_info
    assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[1], \
        err_info
    res = list(map(int, [line.split(':')[1] for line in lines]))
    logger.info('dnn input dim: %d' % res[0])
    logger.info('lr input dim: %d' % res[1])
    return res
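For reference, _parse above implies each raw line carries three tab-separated fields, and load_data_meta expects a two-line meta file. An illustrative sketch of both formats (the ids and dimensions are made up, not real Avazu values):

# One training line: "<dnn ids>\t<lr id:value pairs>\t<click label>"
line = "5 12 7\t101:1 205:1\t1"
fs = line.split('\t')
assert load_dnn_input_record(fs[0]) == [5, 12, 7]
assert load_lr_input_record(fs[1]) == [101, 205]
assert int(fs[2]) == 1

# data.meta.txt, as the asserts in load_data_meta() require:
#     dnn_input_dim:1000
#     lr_input_dim:1000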
dist_simnet_bow.py
@@ -0,0 +1,238 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import time
import math
import random

import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main

DTYPE = "int64"
DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/simnet.train.1000'
DATA_MD5 = '24e49366eb0611c552667989de2f57d5'

# For Net
base_lr = 0.2
emb_lr = base_lr * 3
dict_dim = 1500
emb_dim = 128
hid_dim = 128
margin = 0.1
sample_rate = 1

# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1


def get_acc(cos_q_nt, cos_q_pt, batch_size):
    # Fraction of pairs where the positive title outscores the negative one.
    cond = fluid.layers.less_than(cos_q_nt, cos_q_pt)
    cond = fluid.layers.cast(cond, dtype='float64')
    cond_3 = fluid.layers.reduce_sum(cond)
    acc = fluid.layers.elementwise_div(
        cond_3,
        fluid.layers.fill_constant(
            shape=[1], value=batch_size * 1.0, dtype='float64'),
        name="simnet_acc")
    return acc


def get_loss(cos_q_pt, cos_q_nt):
    # Pairwise hinge loss: mean(max(0, margin - cos_q_pt + cos_q_nt)).
    loss_op1 = fluid.layers.elementwise_sub(
        fluid.layers.fill_constant_batch_size_like(
            input=cos_q_pt, shape=[-1, 1], value=margin, dtype='float32'),
        cos_q_pt)
    loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt)
    loss_op3 = fluid.layers.elementwise_max(
        fluid.layers.fill_constant_batch_size_like(
            input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'),
        loss_op2)
    avg_cost = fluid.layers.mean(loss_op3)
    return avg_cost


def get_optimizer():
    # SGD optimizer
    optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
    return optimizer


def train_network(batch_size, is_distributed=False, is_sparse=False):
    # query
    q = fluid.layers.data(
        name="query_ids", shape=[1], dtype="int64", lod_level=1)
    ## embedding
    q_emb = fluid.layers.embedding(
        input=q,
        is_distributed=is_distributed,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01),
            name="__emb__",
            learning_rate=emb_lr),
        is_sparse=is_sparse)
    ## vsum
    q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
    q_ss = fluid.layers.softsign(q_sum)
    ## fc layer after conv
    q_fc = fluid.layers.fc(
        input=q_ss,
        size=hid_dim,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01),
            name="__q_fc__",
            learning_rate=base_lr))
    # label data
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    # pt
    pt = fluid.layers.data(
        name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
    ## embedding
    pt_emb = fluid.layers.embedding(
        input=pt,
        is_distributed=is_distributed,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01),
            name="__emb__",
            learning_rate=emb_lr),
        is_sparse=is_sparse)
    ## vsum
    pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
    pt_ss = fluid.layers.softsign(pt_sum)
    ## fc layer
    pt_fc = fluid.layers.fc(
        input=pt_ss,
        size=hid_dim,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01),
            name="__fc__",
            learning_rate=base_lr),
        bias_attr=fluid.ParamAttr(name="__fc_b__"))
    # nt
    nt = fluid.layers.data(
        name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
    ## embedding
    nt_emb = fluid.layers.embedding(
        input=nt,
        is_distributed=is_distributed,
        size=[dict_dim, emb_dim],
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01),
            name="__emb__",
            learning_rate=emb_lr),
        is_sparse=is_sparse)
    ## vsum
    nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
    nt_ss = fluid.layers.softsign(nt_sum)
    ## fc layer
    nt_fc = fluid.layers.fc(
        input=nt_ss,
        size=hid_dim,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01),
            name="__fc__",
            learning_rate=base_lr),
        bias_attr=fluid.ParamAttr(name="__fc_b__"))
    cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc)
    cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc)
    # loss
    avg_cost = get_loss(cos_q_pt, cos_q_nt)
    # acc
    acc = get_acc(cos_q_nt, cos_q_pt, batch_size)
    return [avg_cost, acc, cos_q_pt]


def combination(x, y):
    # Pair the first positive-title group with every negative-title group.
    res = [[[xi, yi] for yi in y] for xi in x]
    return res[0]


def get_one_data(file_list):
    for file in file_list:
        contents = []
        with open(file, "r") as fin:
            for i in fin:
                contents.append(i.strip())
        for index, q in enumerate(contents):
            try:
                one_data = [[int(j) for j in i.split(" ")]
                            for i in q.split(";")[:-1]]
                if one_data[1][0] + one_data[1][1] != len(one_data) - 3:
                    # Malformed record: group counts do not match; skip it.
                    continue
                tmp = combination(one_data[3:3 + one_data[1][0]],
                                  one_data[3 + one_data[1][0]:])
            except Exception:
                continue

            for each in tmp:
                yield [one_data[2], 0, each[0], each[1]]


def get_batch_reader(file_list, batch_size):
    def batch_reader():
        res = []
        for i in get_one_data(file_list):
            if random.random() <= sample_rate:
                res.append(i)
            if len(res) >= batch_size:
                yield res
                res = []

    return batch_reader


def get_train_reader(batch_size):
    # The training data set.
    train_file = os.path.join(paddle.dataset.common.DATA_HOME, "simnet",
                              "train")
    train_reader = get_batch_reader([train_file], batch_size)
    train_feed = ["query_ids", "pos_title_ids", "neg_title_ids", "label"]
    return train_reader, train_feed


class TestDistSimnetBow2x2(TestDistRunnerBase):
    def get_model(self, batch_size=2):
        # Train program
        avg_cost, acc, predict = \
            train_network(batch_size,
                          bool(int(os.environ["IS_DISTRIBUTED"])),
                          bool(int(os.environ["IS_SPARSE"])))

        inference_program = fluid.default_main_program().clone()

        # Optimization
        opt = get_optimizer()
        opt.minimize(avg_cost)

        # Reader
        train_reader, _ = get_train_reader(batch_size)
        return inference_program, avg_cost, train_reader, train_reader, acc, predict


if __name__ == "__main__":
    paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train")
    runtime_main(TestDistSimnetBow2x2)
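get_loss above composes fill_constant_batch_size_like, elementwise_sub/add/max, and mean into a pairwise hinge; spelled out, it is mean(max(0, margin - cos_q_pt + cos_q_nt)). A NumPy sketch of the same arithmetic, for intuition only (not used by the test):

import numpy as np

def hinge_loss(cos_q_pt, cos_q_nt, margin=0.1):
    # Zero loss once the positive title beats the negative by >= margin.
    return np.maximum(0.0, margin - cos_q_pt + cos_q_nt).mean()

print(hinge_loss(np.array([0.9]), np.array([0.2])))  # 0.0, well separated
print(hinge_loss(np.array([0.3]), np.array([0.4])))  # 0.2, mis-ordered pair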
dist_text_classification.py
@@ -0,0 +1,231 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import time
import math

import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
import six
import tarfile
import string
import re
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main

DTYPE = "float32"
VOCAB_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/imdb.vocab'
VOCAB_MD5 = '23c86a0533c0151b6f12fa52b106dcc2'
DATA_URL = 'http://paddle-dist-ce-data.bj.bcebos.com/text_classification.tar.gz'
DATA_MD5 = '29ebfc94f11aea9362bbb7f5e9d86b8a'


# Load dictionary.
def load_vocab(filename):
    vocab = {}
    if six.PY2:
        with open(filename, 'r') as f:
            for idx, line in enumerate(f):
                vocab[line.strip()] = idx
    else:
        with open(filename, 'r', encoding="utf-8") as f:
            for idx, line in enumerate(f):
                vocab[line.strip()] = idx
    return vocab


def get_worddict(dict_path):
    word_dict = load_vocab(dict_path)
    word_dict["<unk>"] = len(word_dict)
    dict_dim = len(word_dict)
    return word_dict, dict_dim


def conv_net(input,
             dict_dim,
             emb_dim=128,
             window_size=3,
             num_filters=128,
             fc0_dim=96,
             class_dim=2):
    emb = fluid.layers.embedding(
        input=input,
        size=[dict_dim, emb_dim],
        is_sparse=False,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01)))

    conv_3 = fluid.nets.sequence_conv_pool(
        input=emb,
        num_filters=num_filters,
        filter_size=window_size,
        act="tanh",
        pool_type="max",
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01)))

    fc_0 = fluid.layers.fc(
        input=[conv_3],
        size=fc0_dim,
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01)))

    prediction = fluid.layers.fc(
        input=[fc_0],
        size=class_dim,
        act="softmax",
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Constant(value=0.01)))

    return prediction


def inference_network(dict_dim):
    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    out = conv_net(data, dict_dim)
    return out


def get_reader(word_dict, batch_size):
    # The training data set.
    train_reader = paddle.batch(train(word_dict), batch_size=batch_size)

    # The testing data set.
    test_reader = paddle.batch(test(word_dict), batch_size=batch_size)

    return train_reader, test_reader


def get_optimizer(learning_rate):
    optimizer = fluid.optimizer.SGD(learning_rate=learning_rate)
    return optimizer


class TestDistTextClassification2x2(TestDistRunnerBase):
    def get_model(self, batch_size=2):
        vocab = os.path.join(paddle.dataset.common.DATA_HOME,
                             "text_classification", "imdb.vocab")
        word_dict, dict_dim = get_worddict(vocab)

        # Input data
        data = fluid.layers.data(
            name="words", shape=[1], dtype="int64", lod_level=1)
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')

        # Train program
        predict = conv_net(data, dict_dim)
        cost = fluid.layers.cross_entropy(input=predict, label=label)
        avg_cost = fluid.layers.mean(x=cost)
        acc = fluid.layers.accuracy(input=predict, label=label)
        inference_program = fluid.default_main_program().clone()

        # Optimization
        opt = get_optimizer(learning_rate=0.001)
        opt.minimize(avg_cost)

        # Reader
        train_reader, test_reader = get_reader(word_dict, batch_size)

        return inference_program, avg_cost, train_reader, test_reader, acc, predict


def tokenize(pattern):
    """
    Read files that match the given pattern. Tokenize and yield each file.
    """

    with tarfile.open(
            paddle.dataset.common.download(DATA_URL, 'text_classification',
                                           DATA_MD5)) as tarf:
        # Note that we should use tarfile.next(), which does
        # sequential access of member files, rather than
        # tarfile.extractfile, which does random access and might
        # thrash hard disks.
        tf = tarf.next()
        while tf is not None:
            if bool(pattern.match(tf.name)):
                # newline and punctuation removal and ad-hoc tokenization
                yield tarf.extractfile(tf).read().rstrip(six.b(
                    "\n\r")).translate(
                        None, six.b(string.punctuation)).lower().split()
            tf = tarf.next()


def reader_creator(pos_pattern, neg_pattern, word_idx):
    UNK = word_idx['<unk>']
    INS = []

    def load(pattern, out, label):
        for doc in tokenize(pattern):
            out.append(([word_idx.get(w, UNK) for w in doc], label))

    load(pos_pattern, INS, 0)
    load(neg_pattern, INS, 1)

    def reader():
        for doc, label in INS:
            yield doc, label

    return reader


def train(word_idx):
    """
    IMDB training set creator.

    It returns a reader creator; each sample in the reader is a zero-based ID
    sequence and a label in [0, 1].

    :param word_idx: word dictionary
    :type word_idx: dict
    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        re.compile(r"train/pos/.*\.txt$"),
        re.compile(r"train/neg/.*\.txt$"), word_idx)


def test(word_idx):
    """
    IMDB test set creator.

    It returns a reader creator; each sample in the reader is a zero-based ID
    sequence and a label in [0, 1].

    :param word_idx: word dictionary
    :type word_idx: dict
    :return: Test reader creator
    :rtype: callable
    """
    return reader_creator(
        re.compile(r"test/pos/.*\.txt$"),
        re.compile(r"test/neg/.*\.txt$"), word_idx)


if __name__ == "__main__":
    paddle.dataset.common.download(VOCAB_URL, 'text_classification', VOCAB_MD5)
    paddle.dataset.common.download(DATA_URL, 'text_classification', DATA_MD5)
    runtime_main(TestDistTextClassification2x2)
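The readers produced by train()/test() yield (word-id list, label) pairs. A hedged inspection sketch, reusing only names defined in this file (the downloads are no-ops if the assets are already cached):

# Illustrative only: peek at one IMDB sample.
import os
import paddle

paddle.dataset.common.download(VOCAB_URL, 'text_classification', VOCAB_MD5)
paddle.dataset.common.download(DATA_URL, 'text_classification', DATA_MD5)

vocab_path = os.path.join(paddle.dataset.common.DATA_HOME,
                          "text_classification", "imdb.vocab")
word_dict, dict_dim = get_worddict(vocab_path)

doc, label = next(train(word_dict)())
print(len(doc), label)  # tokenized review length and its 0/1 label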
File diff suppressed because it is too large
test_dist_ctr.py
@@ -0,0 +1,31 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import unittest
from test_dist_base import TestDistBase


class TestDistCTR2x2(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_cuda = False

    def test_dist_ctr(self):
        self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)


if __name__ == "__main__":
    unittest.main()
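For orientation: delta is the tolerance TestDistBase.check_with_place applies when comparing a local run against the 2-trainer/2-pserver run (the "2x2" in the class names); the implementation lives in test_dist_base.py, whose diff is suppressed above. A hedged sketch of the contract, not the actual code:

# Assumed shape of the comparison (illustrative; see test_dist_base.py):
#   for local_loss, dist_loss in zip(local_losses, dist_losses):
#       assert abs(local_loss - dist_loss) <= delta
# With CPU-only training and fixed seeds (per the commit notes "dist test use
# one cpu"), the CTR test can afford a tight delta=1e-7.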
test_dist_simnet_bow.py
@@ -0,0 +1,79 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import unittest

from test_dist_base import TestDistBase


class TestDistSimnetBowDense2x2(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_cuda = False

    def test_simnet_bow(self):
        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'}
        self.check_with_place(
            "dist_simnet_bow.py",
            delta=1e-5,
            check_error_log=False,
            need_envs=need_envs)


class TestDistSimnetBow2x2DenseAsync(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._use_cuda = False

    def test_simnet_bow(self):
        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '0'}
        self.check_with_place(
            "dist_simnet_bow.py",
            delta=100,
            check_error_log=False,
            need_envs=need_envs)


class TestDistSimnetBowSparse2x2(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_cuda = False

    def test_simnet_bow(self):
        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'}
        self.check_with_place(
            "dist_simnet_bow.py",
            delta=1e-5,
            check_error_log=False,
            need_envs=need_envs)


class TestDistSimnetBow2x2SparseAsync(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._use_cuda = False

    def test_simnet_bow(self):
        need_envs = {"IS_DISTRIBUTED": '0', "IS_SPARSE": '1'}
        self.check_with_place(
            "dist_simnet_bow.py",
            delta=100,
            check_error_log=False,
            need_envs=need_envs)


if __name__ == "__main__":
    unittest.main()
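The four cases above sweep sync vs. async mode against dense vs. sparse updates. The need_envs dict reaches dist_simnet_bow.py through os.environ, as its get_model shows; summarizing the mapping:

# In dist_simnet_bow.py (see its get_model above):
#   bool(int(os.environ["IS_DISTRIBUTED"])) -> embedding is_distributed flag
#   bool(int(os.environ["IS_SPARSE"]))      -> embedding is_sparse flag
# IS_SPARSE='1' therefore exercises the selected-rows (sparse) gradient path
# that this PR's sparse-update fixes target; the async cases use delta=100,
# which effectively only checks that training runs to completion.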
test_dist_text_classification.py
@@ -0,0 +1,40 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import unittest
from test_dist_base import TestDistBase


class TestDistTextClassification2x2(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_cuda = False

    def test_text_classification(self):
        self.check_with_place("dist_text_classification.py", delta=1e-6)


class TestDistTextClassification2x2Async(TestDistBase):
    def _setup_config(self):
        self._sync_mode = False
        self._use_cuda = False

    # Renamed from test_se_resnext, a copy-paste leftover from the
    # SE-ResNeXt distributed test.
    def test_text_classification(self):
        self.check_with_place("dist_text_classification.py", delta=100)


if __name__ == "__main__":
    unittest.main()