|
|
@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
|
|
|
|
};
|
|
|
|
};
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename T, typename DeviceContext>
|
|
|
|
template <typename Functor, typename T, typename DeviceContext,
|
|
|
|
|
|
|
|
typename OutType = T>
|
|
|
|
class TransformFunctor {
|
|
|
|
class TransformFunctor {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
|
|
|
|
framework::Tensor* z, const DeviceContext& ctx, Functor func)
|
|
|
|
framework::Tensor* z, const DeviceContext& ctx, Functor func)
|
|
|
|
: x_(x->data<T>()),
|
|
|
|
: x_(x->data<T>()),
|
|
|
|
y_(y->data<T>()),
|
|
|
|
y_(y->data<T>()),
|
|
|
|
z_(z->mutable_data<T>(ctx.GetPlace())),
|
|
|
|
z_(z->mutable_data<OutType>(ctx.GetPlace())),
|
|
|
|
nx_(x->numel()),
|
|
|
|
nx_(x->numel()),
|
|
|
|
ctx_(ctx),
|
|
|
|
ctx_(ctx),
|
|
|
|
func_(func) {}
|
|
|
|
func_(func) {}
|
|
|
@ -208,7 +209,7 @@ class TransformFunctor {
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
const T* x_;
|
|
|
|
const T* x_;
|
|
|
|
const T* y_;
|
|
|
|
const T* y_;
|
|
|
|
T* z_;
|
|
|
|
OutType* z_;
|
|
|
|
int64_t nx_;
|
|
|
|
int64_t nx_;
|
|
|
|
const DeviceContext& ctx_;
|
|
|
|
const DeviceContext& ctx_;
|
|
|
|
Functor func_;
|
|
|
|
Functor func_;
|
|
|
@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Functor, typename DeviceContext, typename T>
|
|
|
|
template <typename Functor, typename DeviceContext, typename T,
|
|
|
|
|
|
|
|
typename OutType = T>
|
|
|
|
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
|
|
|
|
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
|
|
auto* x = ctx.Input<Tensor>("X");
|
|
|
|
auto* x = ctx.Input<Tensor>("X");
|
|
|
|
auto* y = ctx.Input<Tensor>("Y");
|
|
|
|
auto* y = ctx.Input<Tensor>("Y");
|
|
|
|
auto* z = ctx.Output<Tensor>("Out");
|
|
|
|
auto* z = ctx.Output<Tensor>("Out");
|
|
|
|
z->mutable_data<T>(ctx.GetPlace());
|
|
|
|
z->mutable_data<OutType>(ctx.GetPlace());
|
|
|
|
TransformFunctor<Functor, T, DeviceContext> functor(
|
|
|
|
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
|
|
|
|
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
|
|
|
|
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
|
|
|
|
|
|
|
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
auto x_dims = x->dims();
|
|
|
|