|
|
|
@ -53,10 +53,14 @@ class NLLLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims.size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Weight) should be a 1D tensor."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], w_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Weight) Tensor's size should match "
|
|
|
|
|
"to the the total number of classes."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[1], w_dims[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected input tensor Weight's size should equal "
|
|
|
|
|
"to the first dimension of the input tensor X. But received "
|
|
|
|
|
"Weight's "
|
|
|
|
|
"size is %d, the first dimension of input X is %d",
|
|
|
|
|
w_dims[0], x_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (x_dims.size() == 2) {
|
|
|
|
@ -68,7 +72,8 @@ class NLLLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
} else if (x_dims.size() == 4) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The tensor rank of Input(Label) must be 3."));
|
|
|
|
|
"Expected Input(Lable) dimensions=3, received %d.",
|
|
|
|
|
label_dims.size()));
|
|
|
|
|
auto input0 = x_dims[0];
|
|
|
|
|
auto input2 = x_dims[2];
|
|
|
|
|
auto input3 = x_dims[3];
|
|
|
|
|