improve elementwise_add_grad perf (#29277)

* improve performance of elementwise_sum_grad
revert-31562-mean
Authored by Zhang Ting; committed 4 years ago via GitHub.
parent ebf689197d
commit befd6d5338
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large. (Load Diff)

@@ -22,9 +22,10 @@ namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext &ctx,
void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
@@ -57,7 +58,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
same_dims_add(ctx, x, y, z);
} else {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, z);
}
}
};
@@ -68,13 +69,12 @@ struct IdentityGrad {
};
template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout,
framework::Tensor *dx,
framework::Tensor *dy) {
framework::Tensor *dx, framework::Tensor *dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
@@ -87,11 +87,10 @@ template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy) {
const framework::Tensor *out, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
@@ -108,12 +107,11 @@ template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
const framework::Tensor *out, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy) {
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
#ifdef PADDLE_WITH_CUDA
@@ -121,11 +119,10 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx,
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy);
const framework::Tensor *out, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy);
#endif
template <typename DeviceContext, typename T>
@@ -158,10 +155,9 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
ElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
dy);
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
}
};
@@ -186,7 +182,7 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
ddout->mutable_data<T>(ctx.GetPlace());
default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
ddout);
}
}

Loading…
Cancel
Save