parent
b0b26dabe7
commit
45eabb8cf2
@ -0,0 +1,136 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/crf_decoding_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
// Declares the inputs, the output and the user-facing documentation of the
// crf_decoding operator:
//   - Input(Emission):    [N x D] emission weights (LoDTensor).
//   - Input(Transition):  [(D + 2) x D] transition weights (Tensor).
//   - Input(Label):       optional [N x 1] ground truth (LoDTensor).
//   - Output(ViterbiPath): the decoding result.
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CRFDecodingOpMaker(framework::OpProto* proto,
                     framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Emission",
             "(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
             "[N x D] where N is the size of the mini-batch and D is the total "
             "tag number. This input is the unscaled emission weight matrix of "
             "the linear_chain_crf operator.");
    AddInput(
        "Transition",
        "(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
        "This input is the transition weights learned by the linear_chain_crf "
        "operator, denoted as w. The 1st row of w are transition weights for "
        "the start mask. The 2nd row of w are transition weights for the end "
        "mask. Transition weights between other tags begin from the 3rd row of "
        "w. See more details in comments of the linear_chain_crf operator.");
    // Label is optional; when it is absent the operator returns tag indices.
    AddInput(
        "Label",
        "(LoDTensor, LoDTensor<int>). The ground truth with shape "
        "[N x 1]. This input is optional. See more details in the operator's "
        "comments.")
        .AsDispensable();
    AddOutput("ViterbiPath",
              "(LoDTensor, LoDTensor<int>). The decoding results. What to "
              "return changes depending on whether the Input(Label) (the ground "
              "truth) is given. See more details in the operator's comment.");
    AddComment(R"DOC(
The crf_decoding operator reads the emission feature weights and the transition
feature weights learned by the linear_chain_crf operator. It implements the
Viterbi algorithm which is a dynamic programming algorithm for finding the most
likely sequence of hidden states, called the Viterbi path, that results in a
sequence of observed tags.

The output of this operator changes according to whether Input(Label) is given:

1. Input(Label) is given:

This happens in training. This operator is used to co-work with the chunk_eval
operator.

When Input(Label) is given, the crf_decoding operator returns a column vector
with shape [N x 1] whose values are fixed to be 0, indicating an incorrect
prediction, or 1 indicating a tag is correctly predicted. Such an output is the
input to chunk_eval operator.

2. Input(Label) is not given:

This is the standard decoding process.

The crf_decoding operator returns a column vector with shape [N x 1] whose values
range from 0 to maximum tag number - 1. Each element indicates an index of a
predicted tag.
)DOC");
  }
};
|
||||
|
||||
class CRFDecodingOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Emission"),
|
||||
"Input(Emission) should be not null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Transition"),
|
||||
"Input(Transition) should be not null.");
|
||||
|
||||
PADDLE_ENFORCE(ctx->HasOutput("ViterbiPath"),
|
||||
"Output(ViterbiPath) should be not null.");
|
||||
|
||||
auto emission_dims = ctx->GetInputDim("Emission");
|
||||
PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL,
|
||||
"The Input(Emission) should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed.");
|
||||
|
||||
auto transition_dims = ctx->GetInputDim("Transition");
|
||||
PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
|
||||
"The Input(Transition) should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
transition_dims[0] - 2, transition_dims[1],
|
||||
"An invalid dimension for the Input(Transition), which should "
|
||||
"be a 2-D tensor with shape [(D + 2) x D].");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
emission_dims[1], transition_dims[1],
|
||||
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
|
||||
"should be equal to the tag number.");
|
||||
|
||||
if (ctx->HasInput("Label")) {
|
||||
auto label_dims = ctx->GetInputDim("Label");
|
||||
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
|
||||
"The Input(Label) should be a 2-D tensor with the 2nd "
|
||||
"dimensions fixed to 1.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
emission_dims[0], label_dims[0],
|
||||
"The height of Input(Emission) and the height of Input(Label) "
|
||||
"should be the same.");
|
||||
}
|
||||
|
||||
ctx->ShareLoD("Emission", /*->*/ "ViterbiPath");
|
||||
ctx->SetOutputDim("ViterbiPath", {emission_dims[0], 1});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::DataType IndicateDataType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(crf_decoding, ops::CRFDecodingOp,
|
||||
ops::CRFDecodingOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
crf_decoding, ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, double>);
|
@ -0,0 +1,127 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <limits>

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::LoDTensor;
|
||||
using framework::LoD;
|
||||
using framework::Tensor;
|
||||
|
||||
template <typename Place, typename T>
|
||||
class CRFDecodingOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
||||
"The crf_decoding operator can only run on CPU.");
|
||||
|
||||
auto* emission_weights = ctx.Input<LoDTensor>("Emission");
|
||||
auto* transition_weights = ctx.Input<Tensor>("Transition");
|
||||
auto* label = ctx.Input<LoDTensor>("Label");
|
||||
auto* decoded_path = ctx.Output<Tensor>("ViterbiPath");
|
||||
|
||||
PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
|
||||
"The Input(Emission) should be a sequence.");
|
||||
auto lod = emission_weights->lod();
|
||||
PADDLE_ENFORCE(lod.size(), "Input(Emission) must be a sequence.");
|
||||
const size_t level = 0;
|
||||
const size_t seq_num = lod[level].size() - 1;
|
||||
|
||||
int* path = decoded_path->mutable_data<int>(platform::CPUPlace());
|
||||
math::SetConstant<platform::CPUPlace, int>()(ctx.device_context(),
|
||||
decoded_path, 0);
|
||||
for (size_t i = 0; i < seq_num; ++i) {
|
||||
int start_pos = static_cast<int>(lod[level][i]);
|
||||
int end_pos = static_cast<int>(lod[level][i + 1]);
|
||||
Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos);
|
||||
Decode(emission_weights->Slice(start_pos, end_pos), *transition_weights,
|
||||
&decoded_path_one_seq);
|
||||
}
|
||||
|
||||
if (label) {
|
||||
PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
|
||||
"The Input(Label) should be a sequence.");
|
||||
const int* label_value = label->data<int>();
|
||||
size_t batch_size = emission_weights->dims()[0];
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
path[i] = label_value[i] == path[i] ? 1 : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void Decode(const Tensor& emission_weights, const Tensor& transition_weights,
|
||||
Tensor* decoded_path) const {
|
||||
auto emission_dims = emission_weights.dims();
|
||||
const size_t seq_len = emission_dims[0];
|
||||
const size_t tag_num = emission_dims[1];
|
||||
|
||||
const size_t state_trans_base_idx = 2;
|
||||
|
||||
const T* x = emission_weights.data<T>();
|
||||
const T* w = transition_weights.data<T>();
|
||||
int* path = decoded_path->data<int>();
|
||||
|
||||
// alpha is a memo table. An element alpha(k, v) records the score of the
|
||||
// best sequence of tags from position 1 to position k with v being the end
|
||||
// tag.
|
||||
Tensor alpha;
|
||||
T* alpha_value = alpha.mutable_data<T>(emission_dims, platform::CPUPlace());
|
||||
Tensor track;
|
||||
int* track_value =
|
||||
track.mutable_data<int>(emission_dims, platform::CPUPlace());
|
||||
|
||||
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
|
||||
|
||||
for (size_t k = 1; k < seq_len; ++k) {
|
||||
for (size_t i = 0; i < tag_num; ++i) {
|
||||
T max_score = -std::numeric_limits<T>::max();
|
||||
int max_j = 0;
|
||||
for (size_t j = 0; j < tag_num; ++j) {
|
||||
T score = alpha_value[(k - 1) * tag_num + j] +
|
||||
w[(j + state_trans_base_idx) * tag_num + i];
|
||||
if (score > max_score) {
|
||||
max_score = score;
|
||||
max_j = j;
|
||||
}
|
||||
}
|
||||
|
||||
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
|
||||
track_value[k * tag_num + i] = max_j;
|
||||
}
|
||||
}
|
||||
|
||||
T max_score = -std::numeric_limits<T>::max();
|
||||
int max_i = 0;
|
||||
for (size_t i = 0; i < tag_num; ++i) {
|
||||
T score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i];
|
||||
if (score > max_score) {
|
||||
max_score = score;
|
||||
max_i = i;
|
||||
}
|
||||
}
|
||||
path[seq_len - 1] = max_i;
|
||||
for (int k = seq_len - 1; k >= 1; --k) {
|
||||
path[k - 1] = max_i = track_value[k * tag_num + max_i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,146 @@
|
||||
import unittest
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class CRFDecoding(object):
    """Pure NumPy reference implementation of Viterbi decoding.

    Args:
        emission_weights: 2-D array of shape [total_len, tag_num] holding the
            emission scores of all sequences stacked along axis 0.
        transition_weights: 2-D array whose row 0 holds the start weights,
            row 1 the end weights and rows 2.. the tag-to-tag transition
            weights.
        seq_start_positions: offsets of the sequences inside
            ``emission_weights``; the last entry must equal ``total_len``.
    """

    def __init__(self, emission_weights, transition_weights,
                 seq_start_positions):
        assert (emission_weights.shape[0] == seq_start_positions[-1])
        self.tag_num = emission_weights.shape[1]
        self.seq_num = len(seq_start_positions) - 1

        self.seq_start_positions = seq_start_positions
        self.x = emission_weights

        self.a = transition_weights[0, :]  # start weights
        self.b = transition_weights[1, :]  # end weights
        self.w = transition_weights[2:, :]  # tag-to-tag transition weights

        # Not used by decode(), which keeps a per-sequence table instead; the
        # attribute is retained so that existing users of it keep working.
        self.track = np.zeros(
            (seq_start_positions[-1], self.tag_num), dtype="int32")
        self.decoded_path = np.zeros(
            (seq_start_positions[-1], 1), dtype="int32")

    def _decode_one_sequence(self, decoded_path, x):
        """Decode one sequence in place.

        Args:
            decoded_path: writable [seq_len, 1] view that receives the best
                tag index of every position.
            x: [seq_len, tag_num] emission scores of the sequence.
        """
        seq_len, tag_num = x.shape
        if seq_len == 0:
            # Nothing to decode; indexing row 0 below would raise IndexError.
            return

        alpha = np.zeros((seq_len, tag_num), dtype="float64")
        track = np.zeros((seq_len, tag_num), dtype="int32")

        # First position: start weight plus emission score.
        for i in range(tag_num):
            alpha[0, i] = self.a[i] + x[0, i]

        # Explicit loops with a strict ">" keep the tie-breaking rule obvious:
        # the smallest tag index wins.
        for k in range(1, seq_len):
            for i in range(tag_num):
                max_score = -np.finfo("float64").max
                max_idx = 0
                for j in range(tag_num):
                    score = alpha[k - 1, j] + self.w[j, i]
                    if score > max_score:
                        max_score = score
                        max_idx = j
                alpha[k, i] = max_score + x[k, i]
                track[k, i] = max_idx

        # Last position: add the end weights and pick the best final tag.
        max_score = -np.finfo("float64").max
        max_idx = 0
        for i in range(tag_num):
            score = alpha[seq_len - 1, i] + self.b[i]
            if score > max_score:
                max_score = score
                max_idx = i

        # Follow the back pointers from the last position to the first one.
        decoded_path[-1] = max_idx
        for i in range(seq_len - 1, 0, -1):
            decoded_path[i - 1] = max_idx = track[i, max_idx]

    def decode(self):
        """Decode every sequence and return the [total_len, 1] int32 path."""
        for i in range(self.seq_num):
            start = self.seq_start_positions[i]
            end = self.seq_start_positions[i + 1]
            self._decode_one_sequence(self.decoded_path[start:end, :],
                                      self.x[start:end, :])
        return self.decoded_path
|
||||
|
||||
|
||||
class TestCRFDecodingOp1(OpTest):
    """
    Check the operator output against the CRFDecoding reference decoder on
    randomly generated emission and transition weights, without ground truth.
    """

    def set_test_data(self):
        seq_count, tag_count, max_seq_len = 3, 17, 10

        # Build the sequence offsets from random sequence lengths.
        offsets = [0]
        for _ in range(seq_count):
            offsets.append(offsets[-1] + random.randint(1, max_seq_len))
        lod = [offsets]
        total_len = offsets[-1]

        emission = np.random.uniform(
            -1, 1, [total_len, tag_count]).astype("float64")
        transition = np.random.uniform(
            -0.5, 0.5, [tag_count + 2, tag_count]).astype("float64")

        self.inputs = {
            "Emission": (emission, lod),
            "Transition": transition,
        }

        # The expected path comes from the reference decoder.
        reference = CRFDecoding(emission, transition, offsets)
        self.outputs = {"ViterbiPath": reference.decode()}

    def setUp(self):
        self.op_type = "crf_decoding"
        self.set_test_data()

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
class TestCRFDecodingOp2(OpTest):
    """
    Check the operator when the ground truth is given.

    Every row of both the emission and the transition matrix is
    [0, 1, ..., TAG_NUM - 1], so the expected decoded tag at every position is
    TAG_NUM - 1 and the expected output marks where the random labels equal
    that tag.
    """

    def setUp(self):
        self.op_type = "crf_decoding"
        TAG_NUM = 5

        lod = [[0, 1, 3, 6, 10]]
        total_len = lod[-1][-1]

        # One row of increasing scores, repeated for every matrix row.
        tag_row = np.arange(TAG_NUM, dtype="float64").reshape(1, TAG_NUM)
        transition = np.repeat(tag_row, TAG_NUM + 2, axis=0)
        emission = np.repeat(tag_row, total_len, axis=0)

        labels = np.random.randint(
            low=0, high=TAG_NUM, size=(total_len, 1), dtype="int32")
        predicted_labels = np.full((total_len, 1), TAG_NUM - 1, dtype="int32")
        expected_output = (labels == predicted_labels).astype("int32")

        self.inputs = {
            "Emission": (emission, lod),
            "Transition": transition,
            "Label": (labels, lod)
        }

        self.outputs = {"ViterbiPath": expected_output}

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in new issue