parent
dcfbbd3f1d
commit
d92c671d5f
@ -1,48 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/crf_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class CrfOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
CrfOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {}
|
||||
};
|
||||
|
||||
class CrfOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContextBase* ctx) const override {}
|
||||
};
|
||||
|
||||
class CrfGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContextBase* ctx) const override {}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP(crf, ops::CrfOp, ops::CrfOpMaker, crf_grad, ops::CrfGradOp);
|
||||
REGISTER_OP_CPU_KERNEL(crf, ops::CrfOpKernel<float>);
|
||||
REGISTER_OP_CPU_KERNEL(crf_grad, ops::CrfGradOpKernel<float>);
|
@ -0,0 +1,141 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/linear_chain_crf_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
LinearChainCrfOpMaker(framework::OpProto* proto,
|
||||
framework::OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput(
|
||||
"Emission",
|
||||
"(LoDTensor, default: LoDTensor<float>). "
|
||||
"The unscaled emission weight matrix for the linear chain CRF. "
|
||||
"This input is a LoDTensor with shape [N x D] where N is the total "
|
||||
"element number of all input squences in a mini-batch, "
|
||||
"and D is the total tag number.");
|
||||
AddInput(
|
||||
"Transition",
|
||||
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
|
||||
"The learnable parameter for linear_chain_crf operator. "
|
||||
"See more details in the operator's comments.");
|
||||
AddInput(
|
||||
"Label",
|
||||
"(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D "
|
||||
"LoDTensor with shape [N x 1], where N is the total element number in "
|
||||
"a mini-batch.");
|
||||
AddOutput(
|
||||
"Alpha",
|
||||
"Tensor, default: Tensor<float>. The forward vectors for the entire "
|
||||
"batch. A two dimensional tensor with shape [N x D], "
|
||||
"denoted as \f$\alpha\f$. \f$\alpha$\f is a memo table used to "
|
||||
"calculate the normalization factor in CRF. \f$\alpha[k, v]$\f stores "
|
||||
"the unnormalized probabilites of all possible unfinished sequences of "
|
||||
"tags that end at position \f$k$\f with tag \f$v$\f. For each \f$k$\f, "
|
||||
"\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for "
|
||||
"each tag value \f$v$\f. This vector is called a forward vecotr and "
|
||||
"will also be used in backward computations.")
|
||||
.AsIntermediate();
|
||||
AddOutput(
|
||||
"LogLikelihood",
|
||||
"(Tensor, default: Tensor<float>). The logarithm of the conditional "
|
||||
"likelihood of each training sample in a mini-batch. This is a 2-D "
|
||||
"tensor with shape [S x 1], where S is the sequence number in a "
|
||||
"mini-batch. "
|
||||
"Note: S is equal to the sequence number in a mini-batch. The output "
|
||||
"is no longer a LoDTensor.");
|
||||
AddComment(R"DOC(
|
||||
Conditional Random Field defines an undirected probabilistic graph with nodes
|
||||
denoting random variables and edges denoting dependencies between these
|
||||
variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
|
||||
\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and
|
||||
\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs.
|
||||
|
||||
Linear chain CRF is a special case of CRF that is useful for sequence labeling
|
||||
task. Sequence labeling tasks do not assume a lot of conditional
|
||||
independences among inputs. They only concern about the input and the output
|
||||
being linear sequences. Thus, the graph model of CRF is a simple chain or
|
||||
a line, which results in a linear chain CRF.
|
||||
|
||||
This operator implements the Forward-Backward algorithm for linear chain CRF.
|
||||
Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
|
||||
|
||||
Equation:
|
||||
|
||||
- Denote the first input of this operator (Emission) as \f$x\f$ here.
|
||||
- The first D values of the second input (Transition) of this operator are for
|
||||
starting weights, denoted as \f$a\f$ here.
|
||||
- The next D values of the second input (Transition) of this operator are for
|
||||
ending weights, denoted as \f$b\f$ here.
|
||||
- The remaning values of the second input (Transition) are for transition
|
||||
weights, denoted as \f$w\f$ here.
|
||||
- Denote the third input of this operator (Label) as \f$s\f$ here.
|
||||
|
||||
The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as:
|
||||
\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
|
||||
+ \sum_{l=1}^L x_{s_l}
|
||||
+ \sum_{l=2}^L w_{s_{l-1},s_l})\f$
|
||||
where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over
|
||||
all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight
|
||||
to the linear chain CRF.
|
||||
|
||||
Finaly, the linear chain CRF operator outputs the logarithm of the conditional
|
||||
likelihood of each training sample in a mini-batch.
|
||||
|
||||
NOTE:
|
||||
1. The feature function for a CRF is made up of the emission features and the
|
||||
transition features. The emission feature weights are NOT computed in
|
||||
this operator. They MUST be computed first before this operator is called.
|
||||
|
||||
2. Because this operator performs globally normaliztion over all possible
|
||||
sequences internally, it expects UNSCALED emission feature weights.
|
||||
Please do not call this op with the emission feature being output of any
|
||||
nonlinear activation.
|
||||
|
||||
3. The 2nd dimension of the first input of this operator (Emission) MUST be
|
||||
equal to the tag number.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class LinearChainCrfOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContextBase* ctx) const override {}
|
||||
};
|
||||
|
||||
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContextBase* ctx) const override {}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
|
||||
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
|
||||
REGISTER_OP_CPU_KERNEL(linear_chain_crf, ops::LinearChainCrfOpKernel<float>);
|
||||
REGISTER_OP_CPU_KERNEL(linear_chain_crf_grad,
|
||||
ops::LinearChainCrfGradOpKernel<float>);
|
@ -1,13 +0,0 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestCrfOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "crf"
|
||||
batch_size = 3
|
||||
class_num = 37
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,122 @@
|
||||
import unittest
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class LinearChainCrfForward(object):
|
||||
def __init__(self, seq_start_positions, emission_weights,
|
||||
transition_weights, labels):
|
||||
self.tag_num = emission_weights.shape[1]
|
||||
self.seq_num = len(seq_start_positions) - 1
|
||||
|
||||
self.seq_start_positions = seq_start_positions
|
||||
self.labels = labels
|
||||
self.x = emission_weights
|
||||
|
||||
self.x_row_max = np.amax(self.x, axis=1, keepdims=True)
|
||||
self.x_exps = np.exp(self.x - self.x_row_max)
|
||||
|
||||
# unnormalized logits of the transition weights for the start mark.
|
||||
self.a = transition_weights[0, :]
|
||||
self.a_exps = np.exp(self.a)
|
||||
# unnormalized logits of the transition weights for the end mark.
|
||||
self.b = transition_weights[1, :]
|
||||
self.b_exps = np.exp(self.b)
|
||||
# unnormalized logits of the transition weights for all the other tags.
|
||||
self.w = transition_weights[2:, :]
|
||||
self.w_exps = np.exp(self.w)
|
||||
|
||||
# The output of linear chain crf operator.
|
||||
# alpha is a memo table in dynamic programming to caculate
|
||||
# nomalization factor.
|
||||
self.alpha = np.zeros(
|
||||
(seq_start_positions[-1], self.tag_num), dtype="float32")
|
||||
self.log_likelihood = np.zeros((self.tag_num, 1))
|
||||
|
||||
def _l1_norm(self, x):
|
||||
s = np.sum(x)
|
||||
x /= s
|
||||
return s
|
||||
|
||||
def _forward_a_sequence(self, x, x_row_max, x_exps, label, alpha):
|
||||
seq_len = x_row_max.shape[0]
|
||||
log_likelihood = 0.
|
||||
|
||||
for i in range(self.tag_num):
|
||||
alpha[0, i] = self.a_exps[i] * x_exps[0, i]
|
||||
log_likelihood = -x_row_max[0] - np.log(self._l1_norm(alpha[0, :]))
|
||||
|
||||
# calculate the unnormalized logits of the normalization factor.
|
||||
for k in range(1, seq_len):
|
||||
for i in range(self.tag_num):
|
||||
s = 0.
|
||||
for j in range(self.tag_num):
|
||||
s += alpha[k - 1, j] * self.w_exps[j, i]
|
||||
alpha[k, i] = x_exps[k, i] * s
|
||||
log_likelihood -= x_row_max[k] + np.log(self._l1_norm(alpha[k, :]))
|
||||
s = 0.
|
||||
for i in range(self.tag_num):
|
||||
s += alpha[-1, i] * self.b_exps[i]
|
||||
log_likelihood -= np.log(s)
|
||||
|
||||
# calculate the noninator part.
|
||||
log_likelihood += (
|
||||
self.a[label[0]] + self.x[0, label[0]] + self.b[label[-1]])
|
||||
for k in range(1, seq_len):
|
||||
log_likelihood += (
|
||||
self.x[k, label[k]] + self.w[label[k - 1], label[k]])
|
||||
return log_likelihood
|
||||
|
||||
def crf_forward_compute(self):
|
||||
for i in range(self.seq_num):
|
||||
start = self.seq_start_positions[i]
|
||||
end = self.seq_start_positions[i + 1]
|
||||
|
||||
self.log_likelihood[i] = self._forward_a_sequence(
|
||||
self.x[start:end], self.x_row_max[start:end, :],
|
||||
self.x_exps[start:end, :], self.labels[start:end, :],
|
||||
self.alpha[start:end, :])
|
||||
return self.alpha, self.log_likelihood
|
||||
|
||||
|
||||
class TestLinearChainCrfOp(OpTest):
|
||||
def set_test_data(self):
|
||||
SEQ_NUM = 3
|
||||
TAG_NUM = 17
|
||||
MAX_SEQ_LEN = 13
|
||||
|
||||
# the linear_chain_crf operator only supports sequence (LoD level = 1)
|
||||
lod = [[0]]
|
||||
for i in range(SEQ_NUM):
|
||||
lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
|
||||
|
||||
emission = np.random.uniform(-1, 1,
|
||||
[lod[-1][-1], TAG_NUM]).astype("float32")
|
||||
transition = np.random.uniform(-0.5, 0.5,
|
||||
[TAG_NUM + 2, TAG_NUM]).astype("float32")
|
||||
labels = np.random.randint(
|
||||
low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
|
||||
|
||||
self.inputs = {
|
||||
"Emission": (emission, lod),
|
||||
"Transition": transition,
|
||||
"label": (labels, lod)
|
||||
}
|
||||
|
||||
crf = LinearChainCrfForward(lod[0], emission, transition, labels)
|
||||
alpha, log_likelihood = crf.crf_forward_compute()
|
||||
|
||||
self.outputs = {"Alpha": alpha, "LogLikelihood": log_likelihood}
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "linear_chain_crf"
|
||||
self.set_test_data()
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue