|
|
|
@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
// Op kernel that combines inputs "X" and "Y" through AddFunctor<T> and stores
// the result in output "Out".
//
// "Y" may have a lower rank than "X" (the opposite is rejected). When the two
// shapes differ, the integer attribute "axis" is consulted; -1 is resolved to
// rank(X) - rank(Y). Presumably this is the X dimension at which Y's shape
// starts to line up -- confirm against get_mid_dims.
class ElementwiseAddKernel : public framework::OpKernel<T> {
 public:
  // Reads "X", "Y" and (only when the shapes differ) attribute "axis" from
  // ctx, prepares "Out" for ctx.GetPlace(), and dispatches to one of the
  // TransformFunctor entry points.
  //
  // Enforced preconditions (via PADDLE_ENFORCE*):
  //   - rank(X) >= rank(Y)
  //   - resolved axis lies in [0, rank(X))
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* z = ctx.Output<Tensor>("Out");
    // Presumably makes sure the output buffer exists on the target place
    // before the functor writes into it -- confirm against Tensor.
    z->mutable_data<T>(ctx.GetPlace());
    TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
        x, y, z, ctx.template device_context<DeviceContext>(), AddFunctor<T>());

    auto x_dims = x->dims();
    auto y_dims = y->dims();
    PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                      "Rank of first input must >= rank of second input.");

    // Identical shapes: no axis handling is needed.
    if (x_dims == y_dims) {
      functor.Run();
      return;
    }

    // Shapes differ: resolve the axis, with -1 meaning rank(X) - rank(Y).
    int axis = ctx.Attr<int>("axis");
    axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
    PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                   "Axis should be in range [0, x_dims)");

    // pre / n / post are outputs of get_mid_dims. Presumably they are the
    // sizes of the X dimensions before, matching, and after Y's shape --
    // TODO confirm against get_mid_dims.
    int pre, n, post;
    get_mid_dims(x_dims, y_dims, axis, pre, n, post);
    if (post == 1) {
      functor.RunRowWise(n, pre);
    } else {
      functor.RunMidWise(n, pre, post);
    }
    // NOTE(review): a trailing call to
    // ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx) used to
    // follow this dispatch but could never execute, since every path above
    // had already returned. It was dropped as dead code. It looks like the
    // intended replacement for this hand-written body -- if so, delegate to
    // it instead of keeping both.
  }
};
|
|
|
|
|
|
|
|
|
|