|
|
|
@ -105,14 +105,16 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
out0->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto x_grad = EigenVector<T>::Flatten(*out0);
|
|
|
|
|
x_grad.device(place) =
|
|
|
|
|
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
|
|
|
|
|
residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
|
|
|
|
|
x_grad.device(place) = out_grad * x_grad;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (out1) {
|
|
|
|
|
out1->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto y_grad = EigenVector<T>::Flatten(*out1);
|
|
|
|
|
y_grad.device(place) =
|
|
|
|
|
out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
|
|
|
|
|
residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
|
|
|
|
|
y_grad.device(place) = out_grad * y_grad;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|