@@ -85,6 +85,57 @@ struct IdentityGrad {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
 };
 
+template <typename DeviceContext, typename T>
+void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* x,
+                                  const framework::Tensor* y,
+                                  const framework::Tensor* out,
+                                  const framework::Tensor* dout,
+                                  framework::Tensor* dx,
+                                  framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+
+  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
+      IdentityGrad<T>());
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_floating_point<T>::value &&
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x,
+                     const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout,
+                     framework::Tensor* dx, framework::Tensor* dy) {
+  auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
+  if (dx) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dx->mutable_data<T>(ctx.GetPlace()));
+  }
+
+  if (dy) {
+    blas.VCOPY(dout->numel(), dout->data<T>(),
+               dy->mutable_data<T>(ctx.GetPlace()));
+  }
+}
+
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    !std::is_floating_point<T>::value ||
+    !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+elementwise_add_grad(const framework::ExecutionContext& ctx,
+                     const framework::Tensor* x,
+                     const framework::Tensor* y,
+                     const framework::Tensor* out,
+                     const framework::Tensor* dout,
+                     framework::Tensor* dx, framework::Tensor* dy) {
+  default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
+}
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddGradKernel : public framework::OpKernel<T> {
  public:
@@ -97,24 +148,12 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
 
     if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
-      auto blas = math::GetBlas<DeviceContext, T>(ctx);
-
-      if (dx) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dx->mutable_data<T>(ctx.GetPlace()));
-      }
-
-      if (dy) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dy->mutable_data<T>(ctx.GetPlace()));
-      }
+      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
-      ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-          IdentityGrad<T>());
+      default_elementwise_add_grad<DeviceContext, T>(
+          ctx, x, y, out, dout, dx, dy);
     }
   }
 };
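
For context, the two elementwise_add_grad overloads above rely on std::enable_if so that exactly one of them participates in overload resolution for a given <DeviceContext, T> pair: the BLAS VCOPY fast path when T is a floating-point type on CPUDeviceContext (valid because the gradient of x + y with respect to either input is just dout when no broadcast reduction is needed), and the generic default_elementwise_add_grad fallback otherwise. The sketch below illustrates only that dispatch pattern under simplified assumptions; the CPUDeviceContext/GPUDeviceContext stand-ins and the add_grad function are hypothetical placeholders, not part of the Paddle API.

#include <iostream>
#include <type_traits>

// Hypothetical stand-ins for the real device-context types.
struct CPUDeviceContext {};
struct GPUDeviceContext {};

// Enabled only for floating-point T on the CPU context: the "fast path".
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_floating_point<T>::value &&
    std::is_same<DeviceContext, CPUDeviceContext>::value>::type
add_grad() {
  std::cout << "fast path: copy dout into dx/dy with BLAS VCOPY\n";
}

// Enabled in every other case: the generic fallback.
template <typename DeviceContext, typename T>
typename std::enable_if<
    !std::is_floating_point<T>::value ||
    !std::is_same<DeviceContext, CPUDeviceContext>::value>::type
add_grad() {
  std::cout << "fallback: generic ElemwiseGradCompute-style path\n";
}

int main() {
  add_grad<CPUDeviceContext, float>();  // picks the fast path
  add_grad<CPUDeviceContext, int>();    // integral T -> fallback
  add_grad<GPUDeviceContext, float>();  // non-CPU context -> fallback
}

Because the two enable_if conditions are exact logical complements, any <DeviceContext, T> instantiation enables exactly one overload, so the unqualified call in Compute always resolves unambiguously.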