@ -25,12 +25,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " X " ) , true , " Input(X) should be not null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " Label " ) , true ,
" Input(Label) should be not null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasOutput ( " Y " ) , true ,
" Output(Y) should be not null. " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " X " ) , " Input " , " X " , " CrossEntropy " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " Label " ) , " Input " , " Label " , " CrossEntropy " ) ;
OP_INOUT_CHECK ( ctx - > HasOutput ( " Y " ) , " Output " , " Y " , " CrossEntropy " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
auto label_dims = ctx - > GetInputDim ( " Label " ) ;
@ -44,53 +41,61 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ (
framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( label_dims , 0 , rank - 1 ) ,
" ShapeError: Input(X) and Input(Label) shall have the same shape "
" except the last dimension. But received: the shape of Input(X) is "
" [%s], "
" the shape of Input(Label) is [%s]. " ,
x_dims , label_dims ) ;
platform : : errors : : InvalidArgument (
" Input(X) and Input(Label) shall have the same shape "
" except the last dimension. But received: the shape of Input(X) "
" is "
" [%s], the shape of Input(Label) is [%s]. " ,
x_dims , label_dims ) ) ;
}
if ( IsSoftLabel ( ctx ) ) {
PADDLE_ENFORCE_EQ (
rank , label_dims . size ( ) ,
" ShapeError: If Attr(soft_label) == true, Input(X) and Input(Label) "
" shall have the same dimensions. But received: the dimensions of "
" Input(X) is [%d], "
" the shape of Input(X) is [%s], the dimensions of Input(Label) is "
" [%d], the shape of "
" Input(Label) is [%s] " ,
rank , x_dims , label_dims . size ( ) , label_dims ) ;
platform : : errors : : InvalidArgument (
" If Attr(soft_label) == true, Input(X) and Input(Label) "
" shall have the same dimensions. But received: the dimensions of "
" Input(X) is [%d], "
" the shape of Input(X) is [%s], the dimensions of Input(Label) "
" is "
" [%d], the shape of "
" Input(Label) is [%s] " ,
rank , x_dims , label_dims . size ( ) , label_dims ) ) ;
if ( check ) {
PADDLE_ENFORCE_EQ (
x_dims [ rank - 1 ] , label_dims [ rank - 1 ] ,
" ShapeError: If Attr(soft_label) == true, the last dimension of "
" Input(X) and Input(Label) should be equal. But received: the "
" last dimension of Input(X) is [%d], the shape of Input(X) is [%s], "
" the last dimension of Input(Label) is [%d], the shape of "
" Input(Label) "
" is [%s], the last dimension is [%d]. " ,
x_dims [ rank - 1 ] , x_dims , label_dims [ rank - 1 ] , label_dims ,
rank - 1 ) ;
platform : : errors : : InvalidArgument (
" If Attr(soft_label) == true, the last dimension of "
" Input(X) and Input(Label) should be equal. But received: the "
" last dimension of Input(X) is [%d], the shape of Input(X) is "
" [%s], "
" the last dimension of Input(Label) is [%d], the shape of "
" Input(Label) "
" is [%s], the last dimension is [%d]. " ,
x_dims [ rank - 1 ] , x_dims , label_dims [ rank - 1 ] , label_dims ,
rank - 1 ) ) ;
}
} else {
if ( rank = = label_dims . size ( ) ) {
PADDLE_ENFORCE_EQ (
label_dims [ rank - 1 ] , 1UL ,
" ShapeError: the last dimension of Input(Label) should be 1. "
" But received: the last dimension of Input(Label) is [%d], "
" the last dimension is [%d] " ,
label_dims [ rank - 1 ] , rank - 1 ) ;
platform : : errors : : InvalidArgument (
" the last dimension of Input(Label) should be 1. "
" But received: the last dimension of Input(Label) is [%d], "
" the last dimension is [%d] " ,
label_dims [ rank - 1 ] , rank - 1 ) ) ;
} else {
PADDLE_ENFORCE_EQ ( rank , label_dims . size ( ) + 1 ,
" ShapeError: The rank of Input(X) should be equal to "
" Input(Label) plus 1. "
" But received: The dimension of Input(X) is [%d], "
" the shape of Input(X) is [%s], "
" the dimension of Input(Label) is [%d], the shape of "
" Input(Label) is [%s] " ,
rank , x_dims , label_dims . size ( ) , label_dims ) ;
PADDLE_ENFORCE_EQ (
rank , label_dims . size ( ) + 1 ,
platform : : errors : : InvalidArgument (
" ShapeError: The rank of Input(X) should be equal to "
" Input(Label) plus 1. "
" But received: The dimension of Input(X) is [%d], "
" the shape of Input(X) is [%s], "
" the dimension of Input(Label) is [%d], the shape of "
" Input(Label) is [%s] " ,
rank , x_dims , label_dims . size ( ) , label_dims ) ) ;
}
}
@ -122,19 +127,23 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const {
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " Label " ) , true ,
" Input(Label) should be not null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( framework : : GradVarName ( " Y " ) ) , true ,
" Input(Y@GRAD) shoudl be not null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasOutput ( framework : : GradVarName ( " X " ) ) , true ,
" Output(X@GRAD) should be not null. " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " Label " ) , " Input " , " Label " ,
" CrossEntropyGradientOpBase " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( framework : : GradVarName ( " Y " ) ) , " Input " ,
framework : : GradVarName ( " Y " ) , " CrossEntropyGradientOpBase " ) ;
OP_INOUT_CHECK ( ctx - > HasOutput ( framework : : GradVarName ( " X " ) ) , " Output " ,
framework : : GradVarName ( " X " ) , " CrossEntropyGradientOpBase " ) ;
auto x_dims = GetXDim ( ctx ) ;
auto label_dims = ctx - > GetInputDim ( " Label " ) ;
auto dy_dims = ctx - > GetInputDim ( framework : : GradVarName ( " Y " ) ) ;
int rank = x_dims . size ( ) ;
PADDLE_ENFORCE_EQ ( dy_dims . size ( ) , label_dims . size ( ) ,
" Input(Y@Grad) and Input(Y) should have the same rank. " ) ;
PADDLE_ENFORCE_EQ (
dy_dims . size ( ) , label_dims . size ( ) ,
platform : : errors : : InvalidArgument (
" Input(Y@Grad) and Input(Y) should have the same rank. "
" But received: Y@Grad's rank is [%d], Y's rank is [%d] " ,
dy_dims . size ( ) , label_dims . size ( ) ) ) ;
bool check = true ;
if ( ( ! ctx - > IsRuntime ( ) ) & &
@ -143,10 +152,15 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
}
if ( check ) {
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( dy_dims , 0 , rank - 1 ) ,
" The Input(X) and Input(Y@Grad) should have the same "
" shape except the last dimension. " ) ;
PADDLE_ENFORCE_EQ (
framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( dy_dims , 0 , rank - 1 ) ,
platform : : errors : : InvalidArgument (
" The Input(X) and Input(Y@Grad) should have the same "
" shape except the last dimension. but received: "
" the shape of Input(X) is [%s], "
" the shape of Input(Y@Grad) is [%s]. " ,
x_dims , dy_dims ) ) ;
}
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , x_dims ) ;
@ -253,7 +267,7 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
using CrossEntropyGradientOpBase : : CrossEntropyGradientOpBase ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " X " ) , true , " Input(X) should be not null. " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " X " ) , " Input " , " X " , " CrossEntropyGradientOp " ) ;
CrossEntropyGradientOpBase : : InferShape ( ctx ) ;
}
} ;
@ -281,11 +295,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
CrossEntropyOpBase : : InferShape ( ctx ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasOutput ( " XShape " ) , true ,
" Output(XShape) should be not null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasOutput ( " MatchX " ) , true ,
" Output(MatchX) should be not null. " ) ;
OP_INOUT_CHECK ( ctx - > HasOutput ( " XShape " ) , " Output " , " XShape " ,
" CrossEntropyOp2 " ) ;
OP_INOUT_CHECK ( ctx - > HasOutput ( " MatchX " ) , " Output " , " MatchX " ,
" CrossEntropyOp2 " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
auto x_dims_vec = framework : : vectorize ( x_dims ) ;
x_dims_vec . push_back ( 0 ) ;
@ -305,8 +318,8 @@ class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
public :
using CrossEntropyGradientOpBase : : CrossEntropyGradientOpBase ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " MatchX " ) , true ,
" Input(MatchX) must exist " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " MatchX " ) , " Input " , " MatchX " ,
" CrossEntropyGradientOp2 " ) ;
CrossEntropyGradientOpBase : : InferShape ( ctx ) ;
}