Refine `squared_l2_distance_grad` error message (#24409)

test=develop
release/2.0-alpha
Yang Zhang 5 years ago committed by GitHub
parent 100914ddbe
commit 7c17ed57e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -77,8 +77,16 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
auto* x_g = context.Output<Tensor>(framework::GradVarName("X"));
auto* y_g = context.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_NOT_NULL(x_g);
PADDLE_ENFORCE_NOT_NULL(y_g);
PADDLE_ENFORCE_NOT_NULL(
x_g, platform::errors::NotFound(
"variable(%s) cannot be found "
"in scope for operator 'squared_l2_distance_grad'.",
framework::GradVarName("X")));
PADDLE_ENFORCE_NOT_NULL(
y_g, platform::errors::NotFound(
"variable(%s) cannot be found "
"in scope for operator 'squared_l2_distance_grad'.",
framework::GradVarName("Y")));
auto sub_result = EigenMatrix<T>::From(*in0);
auto out_grad = EigenMatrix<T>::From(*in1);
@ -106,8 +114,11 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
y_g->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0],
"First dimension of gradient must be greater or "
"equal than first dimension of target.");
platform::errors::InvalidArgument(
"First dimension of gradient must be greater or "
"equal than first dimension of target. But received "
"gradient dimension = %d and target dimension is %d.",
sub_result.dimensions()[0], y_dims[0]));
if (sub_result.dimensions()[0] == y_dims[0]) {
auto y_grad =

Loading…
Cancel
Save