Optimization of elementwise CUDA kernel (#30801)

fix_imperative_dygraph_error
JamesLim 4 years ago committed by GitHub
parent 0b3c229606
commit 45c7d90564
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -99,6 +99,7 @@ inline void get_mid_dims(const framework::DDim &x_dims,
(*post) *= x_dims[i];
}
}
inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
const int *index_array) {
int index_ = 0;
@ -202,12 +203,16 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Functor, typename T, typename OutType>
__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre,
int n, int post, int total, Functor func) {
__global__ void ElementwiseKernel(const T *__restrict__ x_data,
const T *__restrict__ y_data,
OutType *__restrict__ out_data, int n,
int post, const size_t total, Functor func) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int idx = tid / post % n;
if (tid < total) {
out[tid] = func(x[tid], y[idx]);
int stride = blockDim.x * gridDim.x;
for (int i = tid; i < total; i += stride) {
int idx = i / post % n;
out_data[i] = func(x_data[i], y_data[idx]);
}
}
@ -224,14 +229,16 @@ void ComputeElementwiseCUDA(const framework::Tensor *x,
int numel = pre * n * post;
int threads = 256;
int blocks = (numel + threads - 1) / threads;
if (is_xsize_larger) {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
x_data, y_data, out_data, pre, n, post, numel, func);
x_data, y_data, out_data, n, post, numel, func);
} else {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
y_data, x_data, out_data, pre, n, post, numel, func);
y_data, x_data, out_data, n, post, numel, func);
}
}

Loading…
Cancel
Save