improve elementwise_add_grad perf (#29277)

* improve performance of elementwise_sum_grad
revert-31562-mean
Zhang Ting 5 years ago committed by GitHub
parent ebf689197d
commit befd6d5338
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -22,9 +22,10 @@ namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext &ctx, void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) { const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
@ -57,7 +58,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
SameDimsElemwiseAdd<DeviceContext, T> same_dims_add; SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
same_dims_add(ctx, x, y, z); same_dims_add(ctx, x, y, z);
} else { } else {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z); DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, z);
} }
} }
}; };
@ -68,13 +69,12 @@ struct IdentityGrad {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext &ctx, void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *out,
const framework::Tensor *dout, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dx, framework::Tensor *dy) {
framework::Tensor *dy) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>, ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
@ -87,11 +87,10 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx, ElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *out, const framework::Tensor *dout,
const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dx, framework::Tensor *dy) {
framework::Tensor *dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) { if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(), blas.VCOPY(dout->numel(), dout->data<T>(),
@ -108,12 +107,11 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value && !std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx, ElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *out, const framework::Tensor *dout,
const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dx, framework::Tensor *dy) {
framework::Tensor *dy) { DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
@ -121,11 +119,10 @@ elementwise_add_grad(const framework::ExecutionContext &ctx,
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext &ctx, ElementwiseAddGrad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *out, const framework::Tensor *dout,
const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dx, framework::Tensor *dy);
framework::Tensor *dy);
#endif #endif
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
@ -158,10 +155,9 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
*dout, ctx.GetPlace(), *dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy); ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); ElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else { } else {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
dy);
} }
} }
}; };
@ -186,8 +182,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe); GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
ddout->mutable_data<T>(ctx.GetPlace()); ddout->mutable_data<T>(ctx.GetPlace());
default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe, DefaultElementwiseAddGrad<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
ddout); ddout);
} }
} }
}; };

Loading…
Cancel
Save