Sequence tagging demo (#225)

avx_docs
emailweixu, committed by qingqing01
parent 9c5c38fa2a
commit d6944dec16

@ -0,0 +1,21 @@
#!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"

# Fetch the CoNLL-2000 chunking train/test sets (gzip-compressed).
wget http://www.cnts.ua.ac.be/conll2000/chunking/train.txt.gz
wget http://www.cnts.ua.ac.be/conll2000/chunking/test.txt.gz

File diff suppressed because it is too large

@ -0,0 +1,84 @@
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *

import math

define_py_data_sources2(train_list="data/train.list",
                        test_list="data/test.list",
                        module="dataprovider",
                        obj="process")

batch_size = 1
settings(
    learning_method=MomentumOptimizer(),
    batch_size=batch_size,
    regularization=L2Regularization(batch_size * 1e-4),
    average_window=0.5,
    learning_rate=1e-1,
    learning_rate_decay_a=1e-5,
    learning_rate_decay_b=0.25,
)

num_label_types = 23


def get_simd_size(size):
    # Round up to the next multiple of 8.
    return int(math.ceil(float(size) / 8)) * 8


# Currently, in order to use sparse_update=True,
# the size has to be aligned (the 23 labels are padded to 24).
num_label_types = get_simd_size(num_label_types)

features = data_layer(name="features", size=76328)
word = data_layer(name="word", size=6778)
pos = data_layer(name="pos", size=44)
chunk = data_layer(name="chunk",
                   size=num_label_types)

# Linear projection from the sparse features directly to the label scores.
crf_input = fc_layer(
    input=features,
    size=num_label_types,
    act=LinearActivation(),
    bias_attr=False,
    param_attr=ParamAttr(initial_std=0, sparse_update=True))

# CRF cost for training; shares its transition weights ("crfw") with
# the decoding layer below.
crf = crf_layer(
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw", initial_std=0),
)

# Viterbi decoding for evaluation, reusing the same "crfw" parameter.
crf_decoding = crf_decoding_layer(
    size=num_label_types,
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw"),
)

sum_evaluator(
    name="error",
    input=crf_decoding,
)

chunk_evaluator(
    name="chunk_f1",
    input=[crf_decoding, chunk],
    chunk_scheme="IOB",
    num_chunk_types=11,
)

inputs(word, pos, chunk, features)
outputs(crf)
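A quick sanity check of the padding helper above; the alignment requirement for sparse_update is why the 23 labels become 24:

```python
import math

def get_simd_size(size):
    # Round up to the next multiple of 8, as in linear_crf.py.
    return int(math.ceil(float(size) / 8)) * 8

assert get_simd_size(23) == 24  # the 23 chunk labels are padded to 24
assert get_simd_size(24) == 24  # already-aligned sizes are unchanged
```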

@ -0,0 +1,45 @@
# Sequence Tagging
This demo trains a sequence model that assigns a tag to each token in a sentence. The task is the <a href="http://www.cnts.ua.ac.be/conll2000/chunking">CoNLL-2000 Text Chunking</a> shared task.
## Download data
```bash
cd demo/sequence_tagging
./data/get_data.sh
```
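The downloaded train.txt.gz and test.txt.gz use the CoNLL-2000 column format: one token per line with its part-of-speech tag and IOB chunk tag, and a blank line between sentences. A fragment of the example sentence from the task page:

```
He        PRP  B-NP
reckons   VBZ  B-VP
the       DT   B-NP
current   JJ   I-NP
account   NN   I-NP
deficit   NN   I-NP
will      MD   B-VP
narrow    VB   I-VP
.         .    O
```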
## Train model
```bash
cd demo/sequence_tagging
./train.sh
```
train.sh trains the RNN-CRF model (rnn_crf.py); this commit also adds an analogous script that trains the linear CRF model (linear_crf.py) on CPU.
## Model description
We provide two models. One is a linear CRF model (linear_crf.py), which is equivalent to the one at <a href="http://leon.bottou.org/projects/sgd#stochastic_gradient_crfs">leon.bottou.org/projects/sgd</a>. The other is a stacked bidirectional RNN with a CRF output layer (rnn_crf.py).
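Both configs end in the same CRF head: 23 label types (11 chunk types in IOB encoding give 11 × 2 B-/I- tags plus O) are scored per token, a crf_layer turns those scores into the training cost, and a crf_decoding_layer reuses the same transition parameter (named "crfw") for Viterbi decoding at evaluation time. A minimal sketch of that shared pattern, lifted from linear_crf.py:

```python
# Training cost: CRF negative log-likelihood over the label sequence.
crf = crf_layer(
    input=crf_input,   # per-token scores for the (padded) label set
    label=chunk,
    param_attr=ParamAttr(name="crfw", initial_std=0),
)

# Evaluation: Viterbi decoding with the same transition matrix,
# shared through the parameter name "crfw".
crf_decoding = crf_decoding_layer(
    size=num_label_types,
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw"),
)
```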
<center>
<table border="2" cellspacing="0" cellpadding="6" rules="all" frame="border">
<thead>
<th scope="col" class="left">Model name</th>
<th scope="col" class="left">Number of parameters</th>
<th scope="col" class="left">F1 score</th>
</thead>
<tbody>
<tr>
<td class="left">linear_crf</td>
<td class="left"> 1.8M </td>
<td class="left"> 0.937</td>
</tr>
<tr>
<td class="left">rnn_crf</td>
<td class="left"> 960K </td>
<td class="left">0.941</td>
</tr>
</tbody>
</table>
</center>
<br>
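The parameter counts come from different places: linear_crf spends nearly all of its ~1.8M weights on the sparse feature projection (76328 features × 24 padded labels ≈ 1.83M), while rnn_crf's ~960K weights are dominated by the 6778 × 128 word embedding, with the 128 × 128 recurrent and hidden matrices making up most of the rest.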

@ -0,0 +1,130 @@
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *

import math

define_py_data_sources2(train_list="data/train.list",
                        test_list="data/test.list",
                        module="dataprovider",
                        obj="process")

batch_size = 16
settings(
    learning_method=MomentumOptimizer(),
    batch_size=batch_size,
    regularization=L2Regularization(batch_size * 1e-5),
    average_window=0.5,
    learning_rate=2e-3,
    learning_rate_decay_a=5e-7,
    learning_rate_decay_b=0.5,
)

word_dim = 128
hidden_dim = 128
with_rnn = True

initial_std = 1 / math.sqrt(hidden_dim)
param_attr = ParamAttr(initial_std=initial_std)
# The CRF layers below must run on CPU (device=-1); everything else uses
# the default device 0 and relies on --parallel_nn at training time.
cpu_layer_attr = ExtraLayerAttribute(device=-1)

default_device(0)

num_label_types = 23

features = data_layer(name="features", size=76328)
word = data_layer(name="word", size=6778)
pos = data_layer(name="pos", size=44)
chunk = data_layer(name="chunk",
                   size=num_label_types,
                   layer_attr=cpu_layer_attr)

emb = embedding_layer(
    input=word, size=word_dim, param_attr=ParamAttr(initial_std=0))

hidden1 = mixed_layer(
    size=hidden_dim,
    act=STanhActivation(),
    bias_attr=True,
    input=[full_matrix_projection(emb),
           table_projection(pos, param_attr=param_attr)]
)

if with_rnn:
    # Forward recurrent layer over the first hidden layer.
    rnn1 = recurrent_layer(
        act=ReluActivation(),
        bias_attr=True,
        input=hidden1,
        param_attr=ParamAttr(initial_std=0),
    )

hidden2 = mixed_layer(
    size=hidden_dim,
    act=STanhActivation(),
    bias_attr=True,
    input=[full_matrix_projection(hidden1)
           ] + ([
               full_matrix_projection(rnn1, param_attr=ParamAttr(initial_std=0))
           ] if with_rnn else []),
)

if with_rnn:
    # Backward recurrent layer (reverse=True); together with rnn1 this
    # forms the stacked bidirectional RNN.
    rnn2 = recurrent_layer(
        reverse=True,
        act=ReluActivation(),
        bias_attr=True,
        input=hidden2,
        param_attr=ParamAttr(initial_std=0),
    )

crf_input = mixed_layer(
    size=num_label_types,
    bias_attr=False,
    input=[
        full_matrix_projection(hidden2),
    ] + ([
        full_matrix_projection(rnn2, param_attr=ParamAttr(initial_std=0))
    ] if with_rnn else []),
)

crf = crf_layer(
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw", initial_std=0),
    layer_attr=cpu_layer_attr,
)

crf_decoding = crf_decoding_layer(
    size=num_label_types,
    input=crf_input,
    label=chunk,
    param_attr=ParamAttr(name="crfw"),
    layer_attr=cpu_layer_attr,
)

sum_evaluator(
    name="error",
    input=crf_decoding,
)

chunk_evaluator(
    name="chunk_f1",
    input=[crf_decoding, chunk],
    chunk_scheme="IOB",
    num_chunk_types=11,
)

inputs(word, pos, chunk, features)
outputs(crf)

@ -0,0 +1,10 @@
#!/bin/bash
# Train the stacked bidirectional RNN + CRF model (rnn_crf.py) on GPU.
paddle train \
       --config rnn_crf.py \
       --parallel_nn=1 \
       --use_gpu=1 \
       --dot_period=10 \
       --log_period=1000 \
       --test_period=0 \
       --num_passes=10

@ -0,0 +1,9 @@
#!/bin/bash
# Train the linear CRF model (linear_crf.py) on CPU.
paddle train \
       --config linear_crf.py \
       --use_gpu=0 \
       --dot_period=100 \
       --log_period=10000 \
       --test_period=0 \
       --num_passes=10
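The two training scripts differ deliberately: the RNN-CRF run uses the GPU (--use_gpu=1) and passes --parallel_nn=1 so that the layers the config pins to the CPU via device=-1 (the CRF cost and decoding layers) can run alongside the GPU layers, while the linear CRF model trains entirely on the CPU (--use_gpu=0), where its sparse parameter updates are practical.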

@ -362,6 +362,13 @@ def __extends__(dict1, dict2):
default_factory=lambda _: BaseRegularization())
def settings(batch_size,
             learning_rate=1e-3,
             learning_rate_decay_a=0.,
             learning_rate_decay_b=0.,
             learning_rate_schedule='poly',
             learning_rate_args='',
             average_window=0,
             do_average_in_cpu=False,
             max_average_window=None,
             learning_method=None,
             regularization=None,
             is_async=False,
@ -408,10 +415,14 @@ def settings(batch_size,
    else:
        algorithm = 'owlqn'

    args = [
        'batch_size', 'learning_rate', 'learning_rate_decay_a',
        'learning_rate_decay_b', 'learning_rate_schedule',
        'learning_rate_args', 'average_window', 'do_average_in_cpu',
        'max_average_window'
    ]
    kwargs = dict()
    kwargs['algorithm'] = algorithm
    # Copy each tunable straight from the function arguments into the
    # trainer settings.
    for arg in args:
        kwargs[arg] = locals()[arg]

    kwargs = __extends__(kwargs, learning_method.to_setting_kwargs())
    learning_method.extra_settings()
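The arguments threaded through here feed the default 'poly' learning-rate schedule. As a rough reference, and treating the exact formula as an assumption since the trainer implements it elsewhere, the schedule decays the base rate polynomially in the number of samples processed:

```python
def poly_learning_rate(lr0, decay_a, decay_b, num_samples):
    # Assumed form of the 'poly' schedule:
    #   lr = lr0 * (1 + decay_a * num_samples) ** (-decay_b)
    return lr0 * (1.0 + decay_a * num_samples) ** (-decay_b)

# With rnn_crf.py's settings (lr0=2e-3, decay_a=5e-7, decay_b=0.5) the
# rate falls to about 1.63e-3 after one million samples.
print(poly_learning_rate(2e-3, 5e-7, 0.5, 1e6))
```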
