|
|
|
@ -77,8 +77,16 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* x_g = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* y_g = context.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(x_g);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(y_g);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
x_g, platform::errors::NotFound(
|
|
|
|
|
"variable(%s) cannot be found "
|
|
|
|
|
"in scope for operator 'squared_l2_distance_grad'.",
|
|
|
|
|
framework::GradVarName("X")));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
y_g, platform::errors::NotFound(
|
|
|
|
|
"variable(%s) cannot be found "
|
|
|
|
|
"in scope for operator 'squared_l2_distance_grad'.",
|
|
|
|
|
framework::GradVarName("Y")));
|
|
|
|
|
|
|
|
|
|
auto sub_result = EigenMatrix<T>::From(*in0);
|
|
|
|
|
auto out_grad = EigenMatrix<T>::From(*in1);
|
|
|
|
@ -106,8 +114,11 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
y_g->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(sub_result.dimensions()[0], y_dims[0],
|
|
|
|
|
"First dimension of gradient must be greater or "
|
|
|
|
|
"equal than first dimension of target.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"First dimension of gradient must be greater or "
|
|
|
|
|
"equal than first dimension of target. But received "
|
|
|
|
|
"gradient dimension = %d and target dimension is %d.",
|
|
|
|
|
sub_result.dimensions()[0], y_dims[0]));
|
|
|
|
|
|
|
|
|
|
if (sub_result.dimensions()[0] == y_dims[0]) {
|
|
|
|
|
auto y_grad =
|
|
|
|
|