|
|
|
@ -176,14 +176,26 @@ static void DistGradFunction(const framework::ExecutionContext& context) {
|
|
|
|
|
} else if (p == INFINITY || p == -INFINITY) {
|
|
|
|
|
// p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if
|
|
|
|
|
// j!=i, or equals sign(z_i) * dout if j=i.
|
|
|
|
|
grad_t.device(place) =
|
|
|
|
|
(x_minux_y_abs == out_t.broadcast(out_bcast_dims)).template cast<T>() *
|
|
|
|
|
sign * out_grad_t.broadcast(out_bcast_dims);
|
|
|
|
|
if (platform::is_cpu_place(context.GetPlace())) {
|
|
|
|
|
grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
|
|
|
|
|
.template cast<T>() *
|
|
|
|
|
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
|
|
|
|
|
} else {
|
|
|
|
|
grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims))
|
|
|
|
|
.template cast<T>() *
|
|
|
|
|
sign * out_grad_t.broadcast(out_bcast_dims);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout
|
|
|
|
|
grad_t.device(place) =
|
|
|
|
|
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
|
|
|
|
|
out_grad_t.broadcast(out_bcast_dims);
|
|
|
|
|
if (platform::is_cpu_place(context.GetPlace())) {
|
|
|
|
|
grad_t.device(place) =
|
|
|
|
|
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) *
|
|
|
|
|
sign.eval() * out_grad_t.broadcast(out_bcast_dims);
|
|
|
|
|
} else {
|
|
|
|
|
grad_t.device(place) =
|
|
|
|
|
(x_minux_y_abs / out_t.broadcast(out_bcast_dims)).pow(p - 1) * sign *
|
|
|
|
|
out_grad_t.broadcast(out_bcast_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Eigen::DSizes<int, Rank * 2> x_reshape_dims;
|
|
|
|
|