Revert "improve elementwise_add_grad perf (#29277)" (#29464)

This reverts commit befd6d5338.
revert-31562-mean
Authored by Zhang Ting 4 years ago; committed by GitHub.
parent 57a4f16d9e
commit 560b432349
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large. [Load Diff]

@ -22,10 +22,9 @@ namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
framework::Tensor *z) {
const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
@ -58,7 +57,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
same_dims_add(ctx, x, y, z);
} else {
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, z);
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
}
}
};
@ -69,12 +68,13 @@ struct IdentityGrad {
};
template <typename DeviceContext, typename T>
void DefaultElementwiseAddGrad(const framework::ExecutionContext &ctx,
void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy) {
framework::Tensor *dx,
framework::Tensor *dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
@ -87,10 +87,11 @@ template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy) {
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
@ -107,11 +108,12 @@ template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy) {
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
#ifdef PADDLE_WITH_CUDA
@ -119,10 +121,11 @@ ElementwiseAddGrad(const framework::ExecutionContext &ctx,
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseAddGrad(const framework::ExecutionContext &ctx,
elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor *out, const framework::Tensor *dout,
framework::Tensor *dx, framework::Tensor *dy);
const framework::Tensor *out,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy);
#endif
template <typename DeviceContext, typename T>
@ -155,9 +158,10 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dy);
} else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
ElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
dy);
}
}
};
@ -182,7 +186,7 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
ddout->mutable_data<T>(ctx.GetPlace());
DefaultElementwiseAddGrad<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
ddout);
}
}

Loading…
Cancel
Save