|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/linear_chain_crf_op.h"
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -152,12 +153,19 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto transition_dims = ctx->GetInputDim("Transition");
|
|
|
|
|
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
|
|
|
|
|
"The Input(Transition) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
transition_dims[0] - 2, transition_dims[1],
|
|
|
|
|
"An invalid dimension for the Input(Transition), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
emission_dims[1], transition_dims[1],
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
transition_dims[0] - 2, transition_dims[1],
|
|
|
|
|
"An invalid dimension for the Input(Transition), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[1], transition_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
|
|
|
|
|
"should be equal to the tag number.");
|
|
|
|
|
|
|
|
|
@ -165,8 +173,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
emission_dims[0], label_dims[0],
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_dims[0], label_dims[0],
|
|
|
|
|
"The height of Input(Emission) and the height of Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
|
|
|
|
@ -211,12 +219,19 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
|
|
|
|
|
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
|
|
|
|
|
"The Input(TransitionExps) should be a 2-D tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
transition_exps_dims[0] - 2, transition_exps_dims[1],
|
|
|
|
|
"An invalid dimension for the Input(TransitionExps), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
emission_exps_dims[1], transition_exps_dims[1],
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) &&
|
|
|
|
|
(transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
|
|
|
|
|
check = false;
|
|
|
|
|
}
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
transition_exps_dims[0] - 2, transition_exps_dims[1],
|
|
|
|
|
"An invalid dimension for the Input(TransitionExps), which should "
|
|
|
|
|
"be a 2-D tensor with shape [(D + 2) x D].");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_exps_dims[1], transition_exps_dims[1],
|
|
|
|
|
"The 2nd dimension of the Input(EmissionExps) and the "
|
|
|
|
|
"Input(TransitionExps) should be equal to the tag number.");
|
|
|
|
|
|
|
|
|
@ -224,8 +239,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
|
|
|
|
|
"The Input(Label) should be a 2-D tensor with the 2nd "
|
|
|
|
|
"dimensions fixed to 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
emission_exps_dims[0], label_dims[0],
|
|
|
|
|
PADDLE_INFERSHAPE_ENFORCE_EQ(
|
|
|
|
|
ctx, emission_exps_dims[0], label_dims[0],
|
|
|
|
|
"The height of Input(EmissionExps) and the height of Input(Label) "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
|
|
|
|
|