|
|
|
@ -26,9 +26,8 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"Emission",
|
|
|
|
|
"(LoDTensor, default: LoDTensor<float>). "
|
|
|
|
|
"The unscaled emission weight matrix for the linear chain CRF. "
|
|
|
|
|
"This input is a LoDTensor with shape [N x D] where N is the total "
|
|
|
|
|
"element number of all input squences in a mini-batch, "
|
|
|
|
|
"and D is the total tag number.");
|
|
|
|
|
"This input is a LoDTensor with shape [N x D] where N is the size of "
|
|
|
|
|
"the mini-batch and D is the total tag number.");
|
|
|
|
|
AddInput(
|
|
|
|
|
"Transition",
|
|
|
|
|
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
|
|
|
|
@ -36,7 +35,7 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"See more details in the operator's comments.");
|
|
|
|
|
AddInput(
|
|
|
|
|
"Label",
|
|
|
|
|
"(LoDTensor, default: LoDTensor<int>). The groundtruth which is a 2-D "
|
|
|
|
|
"(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D "
|
|
|
|
|
"LoDTensor with shape [N x 1], where N is the total element number in "
|
|
|
|
|
"a mini-batch.");
|
|
|
|
|
AddOutput(
|
|
|
|
@ -77,12 +76,13 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
|
|
|
|
|
|
|
|
|
|
Linear chain CRF is a special case of CRF that is useful for sequence labeling
|
|
|
|
|
task. Sequence labeling tasks do not assume a lot of conditional
|
|
|
|
|
independences among inputs. They only concern about the input and the output
|
|
|
|
|
being linear sequences. Thus, the graph model of such a CRF is a simple chain
|
|
|
|
|
or a line, which results in the linear chain CRF.
|
|
|
|
|
independences among inputs. The only constraint they impose is that the input
|
|
|
|
|
and output must be linear sequences. Thus, the graph of such a CRF is a simple
|
|
|
|
|
chain or a line, which results in the linear chain CRF.
|
|
|
|
|
|
|
|
|
|
This operator implements the Forward-Backward algorithm for the linear chain
|
|
|
|
|
CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
|
|
|
|
|
CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf and
|
|
|
|
|
http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for reference.
|
|
|
|
|
|
|
|
|
|
Equation:
|
|
|
|
|
|
|
|
|
@ -111,7 +111,7 @@ NOTE:
|
|
|
|
|
transition features. The emission feature weights are NOT computed in
|
|
|
|
|
this operator. They MUST be computed first before this operator is called.
|
|
|
|
|
|
|
|
|
|
2. Because this operator performs globally normaliztion over all possible
|
|
|
|
|
2. Because this operator performs global normalization over all possible
|
|
|
|
|
sequences internally, it expects UNSCALED emission feature weights.
|
|
|
|
|
Please do not call this op with the emission feature being output of any
|
|
|
|
|
nonlinear activation.
|
|
|
|
@ -171,9 +171,10 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim("Alpha", emission_dims);
|
|
|
|
|
ctx->SetOutputDim("EmissionExps", emission_dims);
|
|
|
|
|
ctx->SetOutputDim("TransitionExps", transition_dims);
|
|
|
|
|
// (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
|
|
|
|
|
// TODO(caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
|
|
|
|
|
// is the sequence number in a mini-batch. The dimension set here should be
|
|
|
|
|
// resized to its correct size in the function Compute.
|
|
|
|
|
// resized to its correct size in the function Compute. Fix this once we can
|
|
|
|
|
// get LoD information in the InferShape interface.
|
|
|
|
|
ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -236,7 +237,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Explicitly set that the data type of output of the linear_chain_crf_grad
|
|
|
|
|
// operator is determined by its input: graidents of LogLikelihood.
|
|
|
|
|
// operator is determined by its input: gradients of LogLikelihood.
|
|
|
|
|
framework::DataType IndicateDataType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::ToDataType(
|
|
|
|
|