Optimization of elementwise CUDA kernel (#30801)

fix_imperative_dygraph_error
JamesLim authored 4 years ago, committed by GitHub
parent 0b3c229606
commit 45c7d90564

@@ -99,6 +99,7 @@ inline void get_mid_dims(const framework::DDim &x_dims,
     (*post) *= x_dims[i];
   }
 }
+
 inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
                                const int *index_array) {
   int index_ = 0;
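
For context, pre, n, and post factor x's shape around the broadcast axis: pre is the product of the dims before the axis, n the broadcast span (the product of y's dims), and post the product of the dims after it, so element i of x pairs with element (i / post) % n of y. A minimal host-side sketch of that decomposition (mid_dims is a hypothetical stand-in, not Paddle's get_mid_dims):

#include <cstdio>
#include <vector>

// Hypothetical mirror of the pre/n/post split: dims before the broadcast
// axis multiply into pre, the broadcast span into n, the rest into post.
void mid_dims(const std::vector<int> &x_dims, int axis, int y_ndim,
              int *pre, int *n, int *post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
  for (int i = axis; i < axis + y_ndim; ++i) *n *= x_dims[i];
  for (int i = axis + y_ndim; i < (int)x_dims.size(); ++i) *post *= x_dims[i];
}

int main() {
  int pre, n, post;
  mid_dims({2, 3, 4}, /*axis=*/1, /*y_ndim=*/1, &pre, &n, &post);
  printf("pre=%d n=%d post=%d\n", pre, n, post);  // pre=2 n=3 post=4
  // x[i] pairs with y[(i / post) % n], e.g. i=7 -> y[(7 / 4) % 3] = y[1].
}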
@@ -202,12 +203,16 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
 #if defined(__NVCC__) || defined(__HIPCC__)
 template <typename Functor, typename T, typename OutType>
-__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre,
-                                  int n, int post, int total, Functor func) {
+__global__ void ElementwiseKernel(const T *__restrict__ x_data,
+                                  const T *__restrict__ y_data,
+                                  OutType *__restrict__ out_data, int n,
+                                  int post, const size_t total, Functor func) {
   int tid = threadIdx.x + blockDim.x * blockIdx.x;
-  int idx = tid / post % n;
-  if (tid < total) {
-    out[tid] = func(x[tid], y[idx]);
+  int stride = blockDim.x * gridDim.x;
+
+  for (int i = tid; i < total; i += stride) {
+    int idx = i / post % n;
+    out_data[i] = func(x_data[i], y_data[idx]);
   }
 }
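
The rewrite replaces the one-thread-per-element bounds check with a grid-stride loop, so a grid of any size covers all total outputs, and marks the pointers __restrict__ so the compiler may assume they do not alias. A self-contained sketch of the same pattern, assuming a hypothetical plus_functor in place of Paddle's templated Functor:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical functor standing in for Paddle's elementwise Functor.
struct plus_functor {
  __device__ float operator()(float a, float b) const { return a + b; }
};

// Same grid-stride pattern as the patched ElementwiseKernel: each thread
// walks the output with stride = blockDim.x * gridDim.x, and the broadcast
// index into y is recovered as (i / post) % n.
template <typename Functor>
__global__ void broadcast_kernel(const float *__restrict__ x,
                                 const float *__restrict__ y,
                                 float *__restrict__ out, int n, int post,
                                 size_t total, Functor func) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (int i = tid; i < total; i += stride) {
    int idx = i / post % n;
    out[i] = func(x[i], y[idx]);
  }
}

int main() {
  const int pre = 2, n = 3, post = 4;
  const size_t total = size_t(pre) * n * post;  // x has shape [pre, n, post]
  float *x, *y, *out;
  cudaMallocManaged(&x, total * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  cudaMallocManaged(&out, total * sizeof(float));
  for (size_t i = 0; i < total; ++i) x[i] = 1.0f;
  for (int j = 0; j < n; ++j) y[j] = float(j);
  broadcast_kernel<<<2, 128>>>(x, y, out, n, post, total, plus_functor{});
  cudaDeviceSynchronize();
  printf("out[0]=%g out[post]=%g\n", out[0], out[post]);  // 1 and 2
  cudaFree(x); cudaFree(y); cudaFree(out);
}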
@@ -224,14 +229,16 @@ void ComputeElementwiseCUDA(const framework::Tensor *x,
   int numel = pre * n * post;
   int threads = 256;
   int blocks = (numel + threads - 1) / threads;
   if (is_xsize_larger) {
     ElementwiseKernel<Functor, T,
                       OutType><<<blocks, threads, 0, ctx.stream()>>>(
-        x_data, y_data, out_data, pre, n, post, numel, func);
+        x_data, y_data, out_data, n, post, numel, func);
   } else {
     ElementwiseKernel<Functor, T,
                       OutType><<<blocks, threads, 0, ctx.stream()>>>(
-        y_data, x_data, out_data, pre, n, post, numel, func);
+        y_data, x_data, out_data, n, post, numel, func);
   }
 }
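
Note that the launch site still sizes the grid to one thread per element; the grid-stride loop simply makes the kernel correct under any grid size, which also permits capping blocks for very large numel. A hedged sketch of such a capped launch, reusing broadcast_kernel and plus_functor from the sketch above (the 4096-block cap is an illustrative assumption, not part of this patch):

// Illustrative launch helper: because the kernel is grid-stride, the grid
// can be capped without losing correctness; excess elements are covered by
// later loop iterations of the same threads.
inline void launch_elementwise(const float *x, const float *y, float *out,
                               int n, int post, size_t total,
                               cudaStream_t stream) {
  const int threads = 256;
  int blocks = static_cast<int>((total + threads - 1) / threads);
  blocks = blocks > 4096 ? 4096 : blocks;  // assumed cap, not Paddle's choice
  broadcast_kernel<<<blocks, threads, 0, stream>>>(x, y, out, n, post, total,
                                                   plus_functor{});
}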
