|
|
|
@ -313,21 +313,18 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV);
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename functor,
|
|
|
|
|
typename broadcastfunctor, typename broadcast2functor>
|
|
|
|
|
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
auto* x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* y = ctx.Input<Tensor>("Y");
|
|
|
|
|
auto* out = ctx.Input<Tensor>("Out");
|
|
|
|
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
|
|
|
|
|
|
|
|
|
|
const framework::Tensor* x,
|
|
|
|
|
const framework::Tensor* y,
|
|
|
|
|
const framework::Tensor* out,
|
|
|
|
|
const framework::Tensor* dout, int axis,
|
|
|
|
|
framework::Tensor* dx, framework::Tensor* dy) {
|
|
|
|
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims = y->dims();
|
|
|
|
|
|
|
|
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
if (dx) {
|
|
|
|
|
dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
@ -348,7 +345,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
|
x_dims = framework::make_ddim(extended_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
|
|
|
|
|
int pre, n, post;
|
|
|
|
@ -367,13 +363,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename DeviceContext, typename T,
|
|
|
|
|
typename OutType = T>
|
|
|
|
|
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
auto* x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* y = ctx.Input<Tensor>("Y");
|
|
|
|
|
auto* z = ctx.Output<Tensor>("Out");
|
|
|
|
|
z->mutable_data<OutType>(ctx.GetPlace());
|
|
|
|
|
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
|
|
|
|
|
const framework::Tensor* x,
|
|
|
|
|
const framework::Tensor* y, int axis,
|
|
|
|
|
framework::Tensor* z) {
|
|
|
|
|
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
|
|
|
|
|
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
|
|
|
|
|
|
|
|
|
@ -394,7 +387,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
|
|
|
|
|
x_dims = framework::make_ddim(extended_dims);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
|
|
|
|
|
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
|
|
|
|
|
"Axis should be in range [0, x_dims)");
|
|
|
|
|