@@ -197,6 +197,40 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
}

#ifdef __NVCC__
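// Flat elementwise kernel with broadcast over the middle axis of a
// [pre, n, post] layout: each thread writes one output element, and the
// broadcast operand y (of length n) is read at index (tid / post) % n.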
template <typename Functor, typename T, typename OutType>
__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre,
                                  int n, int post, int total, Functor func) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;
  int idx = tid / post % n;
  if (tid < total) {
    out[tid] = func(x[tid], y[idx]);
  }
}

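// Host-side launcher: flattens the output to pre * n * post elements and
// launches ElementwiseKernel on the device context's stream.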
template <typename Functor, typename T, typename OutType>
void ComputeElementwiseCUDA(const framework::Tensor *x,
                            const framework::Tensor *y, framework::Tensor *z,
                            int pre, int n, int post,
                            const platform::CUDADeviceContext &ctx,
                            Functor func, const bool is_xsize_larger = true) {
  const T *x_data = x->data<T>();
  const T *y_data = y->data<T>();
  OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());

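  // One thread per output element, grouped into 256-thread blocks.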
  int numel = pre * n * post;
  int threads = 256;
  int blocks = (numel + threads - 1) / threads;
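  // The kernel always broadcasts its second argument, so the operand
  // pointers are swapped when y is the larger tensor.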
  if (is_xsize_larger) {
    ElementwiseKernel<Functor, T,
                      OutType><<<blocks, threads, 0, ctx.stream()>>>(
        x_data, y_data, out_data, pre, n, post, numel, func);
  } else {
    ElementwiseKernel<Functor, T,
                      OutType><<<blocks, threads, 0, ctx.stream()>>>(
        y_data, x_data, out_data, pre, n, post, numel, func);
  }
}

template <typename Functor, typename T, typename OutType = T>
__global__ void CommonForwardBroadcastCUDAKernel(
    const int *x_strides_array, const int *y_strides_array,
@@ -1908,6 +1942,16 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
        ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
    return;
  }

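  // On CUDA builds, handle the remaining (non-common-broadcast) case with
  // the flat ComputeElementwiseCUDA launcher added above.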
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    ComputeElementwiseCUDA<Functor, T, OutType>(
        x, y, z, pre, n, post,
        ctx.template device_context<platform::CUDADeviceContext>(), func,
        is_xsize_larger);
#endif
    return;
  }
  if (post == 1) {
    functor.RunRowWise(n, pre);
    return;