|
|
|
@ -26,9 +26,10 @@ T NormalizeL1(T* x, size_t len) {
|
|
|
|
|
// Right now, we just bet that sum won't be zero. If this really happens, we
|
|
|
|
|
// will figure out what should be done then.
|
|
|
|
|
PADDLE_ENFORCE(sum,
|
|
|
|
|
"The unnormalized probabilites of all possible unfinished "
|
|
|
|
|
"The unnormalized probabilities of all possible unfinished "
|
|
|
|
|
"sequences must be greater than 0.");
|
|
|
|
|
for (size_t i = 0; i < len; ++i) x[i] /= sum;
|
|
|
|
|
T s = 1. / sum;
|
|
|
|
|
for (size_t i = 0; i < len; ++i) x[i] *= s;
|
|
|
|
|
return sum;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
@ -36,9 +37,9 @@ T NormalizeL1(T* x, size_t len) {
|
|
|
|
|
using framework::LoDTensor;
|
|
|
|
|
using framework::LoD;
|
|
|
|
|
|
|
|
|
|
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
LinearChainCrfOpMaker(framework::OpProto* proto,
|
|
|
|
|
LinearChainCRFOpMaker(framework::OpProto* proto,
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(
|
|
|
|
@ -51,7 +52,7 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput(
|
|
|
|
|
"Transition",
|
|
|
|
|
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
|
|
|
|
|
"The learnable parameter for linear_chain_crf operator. "
|
|
|
|
|
"The learnable parameter for the linear_chain_crf operator. "
|
|
|
|
|
"See more details in the operator's comments.");
|
|
|
|
|
AddInput(
|
|
|
|
|
"Label",
|
|
|
|
@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput(
|
|
|
|
|
"LogLikelihood",
|
|
|
|
|
"(Tensor, default: Tensor<float>). The logarithm of the "
|
|
|
|
|
"conditional "
|
|
|
|
|
"(Tensor, default: Tensor<float>). The logarithm of the conditional "
|
|
|
|
|
"likelihood of each training sample in a mini-batch. This is a 2-D "
|
|
|
|
|
"tensor with shape [S x 1], where S is the sequence number in a "
|
|
|
|
|
"mini-batch. "
|
|
|
|
|
"Note: S is equal to the sequence number in a mini-batch. The "
|
|
|
|
|
"output "
|
|
|
|
|
"is no longer a LoDTensor.");
|
|
|
|
|
"mini-batch. Note: S is equal to the sequence number in a mini-batch. "
|
|
|
|
|
"The output is no longer a LoDTensor.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Conditional Random Field defines an undirected probabilistic graph with nodes
|
|
|
|
|
denoting random variables and edges denoting dependencies between these
|
|
|
|
@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
|
|
|
|
|
Linear chain CRF is a special case of CRF that is useful for sequence labeling
|
|
|
|
|
task. Sequence labeling tasks do not assume a lot of conditional
|
|
|
|
|
independences among inputs. They only concern about the input and the output
|
|
|
|
|
being linear sequences. Thus, the graph model of CRF is a simple chain or
|
|
|
|
|
a line, which results in a linear chain CRF.
|
|
|
|
|
being linear sequences. Thus, the graph model of such a CRF is a simple chain
|
|
|
|
|
or a line, which results in the linear chain CRF.
|
|
|
|
|
|
|
|
|
|
This operator implements the Forward-Backward algorithm for linear chain CRF.
|
|
|
|
|
Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
|
|
|
|
|
This operator implements the Forward-Backward algorithm for the linear chain
|
|
|
|
|
CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
|
|
|
|
|
|
|
|
|
|
Equation:
|
|
|
|
|
|
|
|
|
@ -144,7 +142,7 @@ nonlinear activation.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class LinearChainCrfOp : public framework::OperatorWithKernel {
|
|
|
|
|
class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LinearChainCrfOpKernel<platform::CPUPlace, T>
|
|
|
|
|
class LinearChainCRFOpKernel<platform::CPUPlace, T>
|
|
|
|
|
: public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
@ -262,11 +260,11 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
|
|
|
|
|
w_exps.device(place) = w.exp();
|
|
|
|
|
|
|
|
|
|
auto* alpha = ctx.Output<LoDTensor>("Alpha");
|
|
|
|
|
alpha->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
alpha->mutable_data<T>(platform::CPUPlace());
|
|
|
|
|
auto* ll = ctx.Output<LoDTensor>("LogLikelihood");
|
|
|
|
|
// resize the output tensor to the correct dimension.
|
|
|
|
|
ll->Resize({static_cast<int>(seq_num), 1});
|
|
|
|
|
T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* log_likelihood = ll->mutable_data<T>(platform::CPUPlace());
|
|
|
|
|
for (size_t i = 0; i < seq_num; ++i) {
|
|
|
|
|
int start_pos = static_cast<int>(in_lod[level][i]);
|
|
|
|
|
int end_pos = static_cast<int>(in_lod[level][i + 1]);
|
|
|
|
@ -322,6 +320,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
|
|
|
|
|
}
|
|
|
|
|
alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
|
|
|
|
|
}
|
|
|
|
|
// NormalizeL1 is to avoid underflow or overflow at (*).
|
|
|
|
|
ll -= x_row_max[k] +
|
|
|
|
|
std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
|
|
|
|
|
}
|
|
|
|
@ -330,6 +329,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
|
|
|
|
|
sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
|
|
|
|
|
}
|
|
|
|
|
ll -= std::log(sum);
|
|
|
|
|
// Now ll is equal to -log(Z).
|
|
|
|
|
|
|
|
|
|
const int* lbl = label->data<int>();
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
@ -347,7 +347,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
|
|
|
|
|
class LinearChainCRFGradOpKernel<platform::CPUPlace, T>
|
|
|
|
|
: public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()),
|
|
|
|
|
"This kernel only runs on CPU.");
|
|
|
|
|
auto* label = ctx.Input<LoDTensor>("Label");
|
|
|
|
|
auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
|
|
|
|
@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
|
|
|
|
|
}
|
|
|
|
|
beta_value[k * tag_num + i] = sum;
|
|
|
|
|
}
|
|
|
|
|
// NormalizeL1 is to avoid underflow or overflow at (**).
|
|
|
|
|
NormalizeL1<T>(beta_value + k * tag_num, tag_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
|
|
|
|
|
T sum = 0.;
|
|
|
|
|
for (size_t i = 0; i < tag_num; ++i) {
|
|
|
|
|
for (size_t j = 0; j < tag_num; ++j) {
|
|
|
|
|
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
|
|
|
|
|
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**)
|
|
|
|
|
alpha_mat(k - 1, i) * tmp_mat(k, j);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
|
|
|
|
|
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
|
|
|
|
|
REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker,
|
|
|
|
|
linear_chain_crf_grad, ops::LinearChainCRFGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
linear_chain_crf,
|
|
|
|
|
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::LinearChainCRFOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
linear_chain_crf_grad,
|
|
|
|
|
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::LinearChainCRFGradOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|