|
|
|
@ -23,21 +23,28 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("Emission",
|
|
|
|
|
"(LoDTensor, default LoDTensor<float>) "
|
|
|
|
|
"A 2-D LoDTensor with shape [N x D], where N is the size of the "
|
|
|
|
|
"(LoDTensor/Tensor<float>). When a LoDTensor input,A 2-D LoDTensor"
|
|
|
|
|
" with shape [N x D], where N is the size of the "
|
|
|
|
|
"mini-batch and D is the total tag number. The unscaled emission "
|
|
|
|
|
"weight matrix for the linear chain CRF. ");
|
|
|
|
|
"weight matrix for the linear chain CRF. When a Tensor input,"
|
|
|
|
|
"A Tensor with shape [N x S x D], where N is batch number,"
|
|
|
|
|
"S is max length of sequences, D is the total tag number.");
|
|
|
|
|
AddInput("Transition",
|
|
|
|
|
"(Tensor, default Tensor<float>) A 2-D Tensor with shape "
|
|
|
|
|
"[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
|
|
|
|
|
"operator. See more details in the operator's comments.");
|
|
|
|
|
AddInput("Label",
|
|
|
|
|
"(LoDTensor, default LoDTensor<int64_t>) A LoDTensor with shape "
|
|
|
|
|
"(LoDTensor/Tensor<int64_t>), when a LoDTensor input, "
|
|
|
|
|
"[N x 1], where N is the total element number in a mini-batch. "
|
|
|
|
|
"The ground truth.");
|
|
|
|
|
"when a Tensor input, [N x S], where N is batch number. "
|
|
|
|
|
"S is max length of sequences. The ground truth.");
|
|
|
|
|
AddInput("length",
|
|
|
|
|
"(Tensor, default Tensor<int64_t>) A Tensor with shape "
|
|
|
|
|
"[M x 1], where M is the sequence number in a mini-batch.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput(
|
|
|
|
|
"Alpha",
|
|
|
|
|
"(Tensor, default Tensor<float>) A 2-D Tensor with shape [N x D]. "
|
|
|
|
|
"(Tensor, default Tensor<float>), the same shape with Emission. "
|
|
|
|
|
"The forward vectors for the entire batch. Denote it as $\alpha$. "
|
|
|
|
|
"$\alpha$ is a memo table used to calculate the normalization "
|
|
|
|
|
"factor in CRF. $\alpha[k, v]$ stores the unnormalized "
|
|
|
|
@ -49,7 +56,7 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput(
|
|
|
|
|
"EmissionExps",
|
|
|
|
|
"(Tensor, default Tensor<float>) A 2-D Tensor with shape [N x D]. "
|
|
|
|
|
"(Tensor, default Tensor<float>), the same shape with Emission. "
|
|
|
|
|
"The exponentials of Input(Emission). This is an intermediate "
|
|
|
|
|
"computational result in forward computation, and will be reused in "
|
|
|
|
|
"backward computation.")
|
|
|
|
@ -145,11 +152,6 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
|
|
|
|
|
"Output(LogLikelihood) should be not null.");
|
|
|
|
|
|
|
|
|
|
auto emission_dims = ctx->GetInputDim("Emission");
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
|
|
|
|
|
"The Input(Emission) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed.");
|
|
|
|
|
|
|
|
|
|
auto transition_dims = ctx->GetInputDim("Transition");
|
|
|
|
|
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
|
|
|
|
|
"The Input(Transition) should be a 2-D tensor.");
|
|
|
|
@ -164,20 +166,40 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
"An invalid dimension for the Input(Transition), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[1], transition_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
|
|
|
|
|
"should be equal to the tag number.");
|
|
|
|
|
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[0], label_dims[0],
|
|
|
|
|
"The height of Input(Emission) and the height of Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
|
|
|
|
|
auto emission_dims = ctx->GetInputDim("Emission");
|
|
|
|
|
PADDLE_ENFORCE_NE(emission_dims[0], 0,
|
|
|
|
|
"An empty mini-batch is not allowed.");
|
|
|
|
|
if (ctx->HasInput("length")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims.size(), 3,
|
|
|
|
|
"The Input(Emission) should be a 3-D tensor.");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
|
|
|
|
|
"The Input(Label) should be a 3-D tensor");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[0], label_dims[0],
|
|
|
|
|
"The batch size of Input(Emission) and Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[1], label_dims[1],
|
|
|
|
|
"The max length of Input(Emission) and Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
|
|
|
|
|
"The Input(Emission) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[1], transition_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
|
|
|
|
|
"should be equal to the tag number.");
|
|
|
|
|
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[0], label_dims[0],
|
|
|
|
|
"The height of Input(Emission) and the height of Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Alpha", emission_dims);
|
|
|
|
|
ctx->SetOutputDim("EmissionExps", emission_dims);
|
|
|
|
|
ctx->SetOutputDim("TransitionExps", transition_dims);
|
|
|
|
@ -210,12 +232,6 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
|
|
|
|
|
"Input(LogLikelihood@GRAD) shoudl be not null.");
|
|
|
|
|
|
|
|
|
|
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2,
|
|
|
|
|
"The Input(EmissionExps) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE(emission_exps_dims[0],
|
|
|
|
|
"An empty mini-batch is not allowed.");
|
|
|
|
|
|
|
|
|
|
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
|
|
|
|
|
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
|
|
|
|
|
"The Input(TransitionExps) should be a 2-D tensor.");
|
|
|
|
@ -230,15 +246,34 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
"An invalid dimension for the Input(TransitionExps), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_exps_dims[1], transition_exps_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(EmissionExps) and the "
|
|
|
|
|
"Input(TransitionExps) should be equal to the tag number.");
|
|
|
|
|
|
|
|
|
|
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
if (ctx->HasInput("length")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 3,
|
|
|
|
|
"The Input(EmissionExps) should be a 3-D tensor.");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_exps_dims[2], transition_exps_dims[1],
|
|
|
|
|
"The 3nd dimension of the Input(EmissionExps) and the "
|
|
|
|
|
"Input(TransitionExps) should be equal to the tag number.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
|
|
|
|
|
"The Input(Label) should be a 3-D tensor with the 3nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2,
|
|
|
|
|
"The Input(EmissionExps) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_exps_dims[1], transition_exps_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(EmissionExps) and the "
|
|
|
|
|
"Input(TransitionExps) should be equal to the tag number.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[1], 1,
|
|
|
|
|
"The Input(Label) 2nd dimensions fixed to 1.");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_NE(emission_exps_dims[0], 0,
|
|
|
|
|
"An empty mini-batch is not allowed.");
|
|
|
|
|
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_exps_dims[0], label_dims[0],
|
|
|
|
|
"The height of Input(EmissionExps) and the height of Input(Label) "
|
|
|
|
@ -246,8 +281,12 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Emission"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
|
|
|
|
|
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
|
|
|
|
|
if (ctx->HasInput("length") == false) {
|
|
|
|
|
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// ctx->SetOutputDim(framework::GradVarName("Emission"),
|
|
|
|
|
// emission_exps_dims);
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Transition"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Transition"),
|
|
|
|
|
transition_exps_dims);
|
|
|
|
@ -275,15 +314,15 @@ class LinearChainCRFGradDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
|
|
|
|
op->SetType("linear_chain_crf_grad");
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
|
|
|
|
|
op->SetInput("Emission", Input("Emission"));
|
|
|
|
|
op->SetInput("Transition", Input("Transition"));
|
|
|
|
|
op->SetInput("Label", Input("Label"));
|
|
|
|
|
|
|
|
|
|
op->SetInput("Alpha", Output("Alpha"));
|
|
|
|
|
op->SetInput("EmissionExps", Output("EmissionExps"));
|
|
|
|
|
op->SetInput("TransitionExps", Output("TransitionExps"));
|
|
|
|
|
|
|
|
|
|
if (ForwardOp().Inputs().count("length") > 0) {
|
|
|
|
|
op->SetInput("length", Input("length"));
|
|
|
|
|
}
|
|
|
|
|
op->SetInput(framework::GradVarName("LogLikelihood"),
|
|
|
|
|
OutputGrad("LogLikelihood"));
|
|
|
|
|
|
|
|
|
|