@ -32,14 +32,23 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
int rank = x_dims . size ( ) ;
PADDLE_ENFORCE_EQ ( rank , label_dims . size ( ) ,
" Input(X) and Input(Label) shall have the same rank. " ) ;
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( label_dims , 0 , rank - 1 ) ,
" Input(X) and Input(Label) shall have the same shape "
" except the last dimension. " ) ;
bool check = true ;
if ( ( ! ctx - > IsRuntime ( ) ) & & ( framework : : product ( x_dims ) < = 0 | |
framework : : product ( label_dims ) < = 0 ) ) {
check = false ;
}
if ( check ) {
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( label_dims , 0 , rank - 1 ) ,
" Input(X) and Input(Label) shall have the same shape "
" except the last dimension. " ) ;
}
if ( ctx - > Attrs ( ) . Get < bool > ( " soft_label " ) ) {
PADDLE_ENFORCE_EQ ( x_dims [ rank - 1 ] , label_dims [ rank - 1 ] ,
" If Attr(soft_label) == true, the last dimension of "
" Input(X) and Input(Label) should be equal. " ) ;
if ( check ) {
PADDLE_ENFORCE_EQ ( x_dims [ rank - 1 ] , label_dims [ rank - 1 ] ,
" If Attr(soft_label) == true, the last dimension of "
" Input(X) and Input(Label) should be equal. " ) ;
}
} else {
PADDLE_ENFORCE_EQ ( label_dims [ rank - 1 ] , 1UL ,
" If Attr(softLabel) == false, the last dimension of "
@ -82,20 +91,32 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
" Input(Y@Grad) and Input(X) should have the same rank. " ) ;
PADDLE_ENFORCE_EQ ( label_dims . size ( ) , rank ,
" Input(Label) and Input(X) should have the same rank. " ) ;
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( label_dims , 0 , rank - 1 ) ,
" The Input(X) and Input(Label) should have the same "
" shape except the last dimension. " ) ;
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( dy_dims , 0 , rank - 1 ) ,
" The Input(X) and Input(Y@Grad) should have the same "
" shape except the last dimension. " ) ;
bool check = true ;
if ( ( ! ctx - > IsRuntime ( ) ) & & ( framework : : product ( x_dims ) < = 0 | |
framework : : product ( label_dims ) < = 0 ) ) {
check = false ;
}
if ( check ) {
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( label_dims , 0 , rank - 1 ) ,
" The Input(X) and Input(Label) should have the same "
" shape except the last dimension. " ) ;
PADDLE_ENFORCE_EQ ( framework : : slice_ddim ( x_dims , 0 , rank - 1 ) ,
framework : : slice_ddim ( dy_dims , 0 , rank - 1 ) ,
" The Input(X) and Input(Y@Grad) should have the same "
" shape except the last dimension. " ) ;
}
PADDLE_ENFORCE_EQ ( dy_dims [ rank - 1 ] , 1 ,
" The last dimension of Input(Y@Grad) should be 1. " ) ;
if ( ctx - > Attrs ( ) . Get < bool > ( " soft_label " ) ) {
PADDLE_ENFORCE_EQ ( x_dims [ rank - 1 ] , label_dims [ rank - 1 ] ,
" When Attr(soft_label) == true, the last dimension of "
" Input(X) and Input(Label) should be equal. " ) ;
if ( check ) {
PADDLE_ENFORCE_EQ (
x_dims [ rank - 1 ] , label_dims [ rank - 1 ] ,
" When Attr(soft_label) == true, the last dimension of "
" Input(X) and Input(Label) should be equal. " ) ;
}
} else {
PADDLE_ENFORCE_EQ ( label_dims [ rank - 1 ] , 1 ,
" When Attr(soft_label) == false, the last dimension of "