|
|
|
@ -534,8 +534,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
|
|
|
|
|
const framework::Tensor& dout, int axis,
|
|
|
|
|
framework::Tensor* dx, framework::Tensor* dy,
|
|
|
|
|
DX_OP dx_op, DY_OP dy_op) {
|
|
|
|
|
const framework::DDim x_dim = x.dims();
|
|
|
|
|
const framework::DDim y_dim = y.dims();
|
|
|
|
|
const framework::DDim& x_dim = x.dims();
|
|
|
|
|
const framework::DDim& y_dim = y.dims();
|
|
|
|
|
if (x.dims() == y.dims()) {
|
|
|
|
|
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
|
|
|
|
|
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
|
|
|
|
@ -558,19 +558,19 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
|
|
|
|
|
framework::Tensor* dx, framework::Tensor* dy,
|
|
|
|
|
DX_OP dx_op, DY_OP dy_op) {
|
|
|
|
|
if (dy == nullptr) {
|
|
|
|
|
const framework::DDim dx_dims = dout.dims();
|
|
|
|
|
const framework::DDim& dx_dims = dout.dims();
|
|
|
|
|
auto dy_dims = dx_dims;
|
|
|
|
|
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
|
|
|
|
|
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
|
|
|
|
|
} else {
|
|
|
|
|
if (dout.dims() == dy->dims()) {
|
|
|
|
|
const framework::DDim dx_dims = dout.dims();
|
|
|
|
|
const framework::DDim dy_dims = dy->dims();
|
|
|
|
|
const framework::DDim& dx_dims = dout.dims();
|
|
|
|
|
const framework::DDim& dy_dims = dy->dims();
|
|
|
|
|
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
|
|
|
|
|
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
|
|
|
|
|
} else { // Y is a scalar
|
|
|
|
|
auto dx_dims = dout.dims();
|
|
|
|
|
const framework::DDim dy_dims = dy->dims();
|
|
|
|
|
const framework::DDim& dy_dims = dy->dims();
|
|
|
|
|
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
|
|
|
|
|
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
|
|
|
|
|
}
|
|
|
|
|