|
|
|
@ -142,24 +142,27 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Emission"),
|
|
|
|
|
"Input(Emission) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Transition"),
|
|
|
|
|
"Input(Transition) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Alpha"),
|
|
|
|
|
"Output(Alpha) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("EmissionExps"),
|
|
|
|
|
"Output(EmissionExps) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("TransitionExps"),
|
|
|
|
|
"Output(TransitionExps) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("LogLikelihood"),
|
|
|
|
|
"Output(LogLikelihood) should be not null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Emission"), "Input", "Emission",
|
|
|
|
|
"LinearChainCRF");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition",
|
|
|
|
|
"LinearChainCRF");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "LinearChainCRF");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Alpha"), "Output", "Alpha",
|
|
|
|
|
"LinearChainCRF");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("EmissionExps"), "Output", "EmissionExps",
|
|
|
|
|
"LinearChainCRF");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("TransitionExps"), "Output", "TransitionExps",
|
|
|
|
|
"LinearChainCRF");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("LogLikelihood"), "Output", "LogLikelihood",
|
|
|
|
|
"LinearChainCRF");
|
|
|
|
|
|
|
|
|
|
auto transition_dims = ctx->GetInputDim("Transition");
|
|
|
|
|
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
|
|
|
|
|
"The Input(Transition) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(Transition) should be a 2-D tensor. But "
|
|
|
|
|
"received: input rank %u, input shape [%s].",
|
|
|
|
|
transition_dims.size(), transition_dims));
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
|
|
|
|
@ -168,49 +171,88 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
transition_dims[0] - 2, transition_dims[1],
|
|
|
|
|
"An invalid dimension for the Input(Transition), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"An invalid dimension for the Input(Transition), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D]. But received: input "
|
|
|
|
|
"rank %u, "
|
|
|
|
|
"input shape [%s].",
|
|
|
|
|
transition_dims.size(), transition_dims));
|
|
|
|
|
}
|
|
|
|
|
auto emission_dims = ctx->GetInputDim("Emission");
|
|
|
|
|
PADDLE_ENFORCE_NE(emission_dims[0], 0,
|
|
|
|
|
"An empty mini-batch is not allowed.");
|
|
|
|
|
if (ctx->HasInput("Length")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims.size(), 3,
|
|
|
|
|
"The Input(Emission) should be a 3-D tensor.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(Emission) should be a 3-D tensor. But "
|
|
|
|
|
"received: input rank %u, input shape [%s].",
|
|
|
|
|
emission_dims.size(), emission_dims));
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
(label_dims.size() == 3UL && label_dims[2] == 1) ||
|
|
|
|
|
(label_dims.size() == 2UL),
|
|
|
|
|
true,
|
|
|
|
|
"The Input(Label) should be a 3-D tensor with last "
|
|
|
|
|
"dimension fixed to 1 or a 2-D tensor in padding mode.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(Label) should be a 3-D tensor with last dimension "
|
|
|
|
|
"fixed to 1 or a 2-D tensor in padding mode. But received: input "
|
|
|
|
|
"rank %u, input shape [%s].",
|
|
|
|
|
label_dims.size(), label_dims));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims[0], label_dims[0],
|
|
|
|
|
"The batch size of Input(Emission) and Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The batch size of Input(Emission) "
|
|
|
|
|
"and Input(Label) should be the same. But "
|
|
|
|
|
"received Input(Emission): "
|
|
|
|
|
"rank %u, shape [%s]; received Input(Label): "
|
|
|
|
|
"rank %u, shape [%s].",
|
|
|
|
|
emission_dims.size(), emission_dims,
|
|
|
|
|
label_dims.size(), label_dims));
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims[1], label_dims[1],
|
|
|
|
|
"The max length of Input(Emission) and Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The max length of Input(Emission) "
|
|
|
|
|
"and Input(Label) should be the same. But "
|
|
|
|
|
"received Input(Emission): "
|
|
|
|
|
"rank %u, shape [%s]; received Input(Label): "
|
|
|
|
|
"rank %u, shape [%s].",
|
|
|
|
|
emission_dims.size(), emission_dims,
|
|
|
|
|
label_dims.size(), label_dims));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims.size(), 2,
|
|
|
|
|
"The Input(Emission) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
emission_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(Emission) should be a 2-D tensor. But received: "
|
|
|
|
|
"input rank %u, input shape [%s].",
|
|
|
|
|
emission_dims.size(), emission_dims));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(emission_dims[1], transition_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(Emission) and the "
|
|
|
|
|
"Input(Transition) "
|
|
|
|
|
"should be equal to the tag number.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The 2nd dimension of the Input(Emission) and "
|
|
|
|
|
"the Input(Transition) "
|
|
|
|
|
"should be equal to the tag number. But received "
|
|
|
|
|
"Input(Emission): rank "
|
|
|
|
|
"%u, shape [%s]; received Input(Transition): "
|
|
|
|
|
"rank %u, shape [%s].",
|
|
|
|
|
emission_dims.size(), emission_dims,
|
|
|
|
|
transition_dims.size(), transition_dims));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1. But received: input rank %u, "
|
|
|
|
|
"input shape [%s].",
|
|
|
|
|
label_dims.size(), label_dims));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
emission_dims[0], label_dims[0],
|
|
|
|
|
"The height of Input(Emission) and the height of Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(Emission) and Input(Label) "
|
|
|
|
|
"should be the same. But received Input(Emission): rank %u, "
|
|
|
|
|
"shape "
|
|
|
|
|
"[%s]; received Input(Label): rank %u, shape [%s].",
|
|
|
|
|
emission_dims.size(), emission_dims, label_dims.size(),
|
|
|
|
|
label_dims));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Alpha", emission_dims);
|
|
|
|
@ -239,12 +281,13 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("EmissionExps"),
|
|
|
|
|
"Input(EmissionExps) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("TransitionExps"),
|
|
|
|
|
"Input(TransitionExps) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
|
|
|
|
|
"Input(LogLikelihood@GRAD) shoudl be not null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("EmissionExps"), "Input", "EmissionExps",
|
|
|
|
|
"LinearChainCRFGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("TransitionExps"), "Input", "TransitionExps",
|
|
|
|
|
"LinearChainCRFGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("LogLikelihood")),
|
|
|
|
|
"Input", framework::GradVarName("LogLikelihood"),
|
|
|
|
|
"LinearChainCRFGrad");
|
|
|
|
|
|
|
|
|
|
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
|
|
|
|
|
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
|
|
|
|
|