[Paddle-TRT] slice kernel optimization (#24783)

* parallel move shared data test=develop

* test=develop
fix-sync_batch_norm-hang-in-fleet
zlsh80826 5 years ago committed by GitHub
parent 1a7fbb734c
commit 49e4ee27e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -35,10 +35,8 @@ __global__ void SliceKernel(int num, int dims, const T *input,
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
extern __shared__ int shared_data[];
if (threadIdx.x == 0) {
for (int i = 0; i < dims * 3; i++) {
shared_data[i] = offsets_info[i];
}
for (int i = threadIdx.x; i < dims * 3; i += blockDim.x) {
shared_data[i] = offsets_info[i];
}
__syncthreads();

Loading…
Cancel
Save