revert-31562-mean
wangchaochaohu 4 years ago committed by GitHub
parent d0b789d27f
commit eab44e1f32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
size_t width_stride = gridDim.x * blockDim.x;
size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
#pragma unroll
for (size_t w = idx; w < full_width; w += width_stride) {
@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
__syncthreads();
size_t offset = w + threadIdx.y * width;
#pragma unroll
for (size_t h = threadIdx.y; h < height;
for (size_t h = threadIdx.y; h < full_height;
h += BLOCK_H) { // block-stride loop across matrix height
sdata[threadIdx.y][threadIdx.x] +=
(w < width) ? in[offset] : (static_cast<T>(0));
(w < width && h < height) ? in[offset] : (static_cast<T>(0));
offset += width * BLOCK_H;
}
__syncthreads();
@ -184,21 +186,24 @@ __global__ void FP16MatrixColReduce(
size_t width_stride = gridDim.x * blockDim.x;
size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
#pragma unroll
for (size_t w = idx; w < full_width; w += width_stride) {
for (int r = 0; r < repeats; r++) {
sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0;
}
__syncthreads();
#pragma unroll
for (int r = 0; r < repeats; r++) {
size_t offset = w + (r * BLOCK_W + threadIdx.y) * width;
#pragma unroll
for (size_t h = r * BLOCK_H + threadIdx.y; h < height;
for (size_t h = threadIdx.y + r * BLOCK_W; h < full_height;
h += BLOCK_H) { // block-stride loop across matrix height
sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] +=
(w < width) ? in[offset + r * BLOCK_W * width]
: (static_cast<paddle::platform::float16>(0));
(w < width && h < height)
? in[offset]
: (static_cast<paddle::platform::float16>(0));
offset += width * BLOCK_H;
}
}
@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
dout_data, out_data, nums, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
return;
}
constexpr int block_x = 32;

Loading…
Cancel
Save