|
|
|
@ -22,10 +22,9 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
|
|
|
|
|
void default_elementwise_add(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x,
|
|
|
|
|
const framework::Tensor *y,
|
|
|
|
|
framework::Tensor *z) {
|
|
|
|
|
const framework::Tensor *y, framework::Tensor *z) {
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims = y->dims();
|
|
|
|
@ -58,7 +57,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
|
|
|
|
|
SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
|
|
|
|
|
same_dims_add(ctx, x, y, z);
|
|
|
|
|
} else {
|
|
|
|
|
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, z);
|
|
|
|
|
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -69,12 +68,13 @@ struct IdentityGrad {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
|
|
|
|
|
void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x,
|
|
|
|
|
const framework::Tensor *y,
|
|
|
|
|
const framework::Tensor *out,
|
|
|
|
|
const framework::Tensor *dout,
|
|
|
|
|
framework::Tensor *dx, framework::Tensor *dy) {
|
|
|
|
|
framework::Tensor *dx,
|
|
|
|
|
framework::Tensor *dy) {
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
|
|
|
|
|
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
|
|
|
|
@ -87,10 +87,11 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
typename std::enable_if<
|
|
|
|
|
std::is_floating_point<T>::value &&
|
|
|
|
|
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
|
|
|
|
|
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
|
|
|
|
|
elementwise_add_grad(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x, const framework::Tensor *y,
|
|
|
|
|
const framework::Tensor *out, const framework::Tensor *dout,
|
|
|
|
|
framework::Tensor *dx, framework::Tensor *dy) {
|
|
|
|
|
const framework::Tensor *out,
|
|
|
|
|
const framework::Tensor *dout, framework::Tensor *dx,
|
|
|
|
|
framework::Tensor *dy) {
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(ctx);
|
|
|
|
|
if (dx) {
|
|
|
|
|
blas.VCOPY(dout->numel(), dout->data<T>(),
|
|
|
|
@ -107,11 +108,12 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
typename std::enable_if<
|
|
|
|
|
!std::is_floating_point<T>::value &&
|
|
|
|
|
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
|
|
|
|
|
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
|
|
|
|
|
elementwise_add_grad(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x, const framework::Tensor *y,
|
|
|
|
|
const framework::Tensor *out, const framework::Tensor *dout,
|
|
|
|
|
framework::Tensor *dx, framework::Tensor *dy) {
|
|
|
|
|
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
|
|
|
|
|
const framework::Tensor *out,
|
|
|
|
|
const framework::Tensor *dout, framework::Tensor *dx,
|
|
|
|
|
framework::Tensor *dy) {
|
|
|
|
|
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -119,10 +121,11 @@ ElementwiseAddGrad(const framework::ExecutionContext &ctx,
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
typename std::enable_if<
|
|
|
|
|
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
|
|
|
|
|
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
|
|
|
|
|
elementwise_add_grad(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x, const framework::Tensor *y,
|
|
|
|
|
const framework::Tensor *out, const framework::Tensor *dout,
|
|
|
|
|
framework::Tensor *dx, framework::Tensor *dy);
|
|
|
|
|
const framework::Tensor *out,
|
|
|
|
|
const framework::Tensor *dout, framework::Tensor *dx,
|
|
|
|
|
framework::Tensor *dy);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
@ -155,9 +158,10 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
|
|
|
|
|
*dout, ctx.GetPlace(),
|
|
|
|
|
ctx.template device_context<platform::DeviceContext>(), dy);
|
|
|
|
|
} else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
|
|
|
|
|
ElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
|
|
|
|
|
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
|
|
|
|
|
} else {
|
|
|
|
|
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
|
|
|
|
|
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
|
|
|
|
|
dy);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -182,7 +186,7 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
|
|
|
|
|
|
|
|
|
|
ddout->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
|
|
|
|
|
default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
|
|
|
|
|
ddout);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|