From b8f7fa97b6f2f8787c9fced40004a3cb45795a05 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Wed, 2 May 2018 20:13:59 +0800
Subject: [PATCH 1/3] replace __shfl with __shfl_sync

---
 paddle/cuda/src/hl_top_k.cu             | 9 +++++----
 paddle/fluid/operators/top_k_op.cu      | 7 ++++++-
 paddle/fluid/platform/cuda_primitives.h | 7 +++++++
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/paddle/cuda/src/hl_top_k.cu b/paddle/cuda/src/hl_top_k.cu
index 59ba552f56..4a737d5ba7 100644
--- a/paddle/cuda/src/hl_top_k.cu
+++ b/paddle/cuda/src/hl_top_k.cu
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "hl_base.h"
-#include "hl_sparse.ph"
-#include "hl_top_k.h"
+#include "paddle/cuda/include/hl_base.h"
+#include "paddle/cuda/include/hl_sparse.ph"
+#include "paddle/cuda/include/hl_top_k.h"
 #include "paddle/utils/Logging.h"
 
 // using namespace hppl;
@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
     if (--beamSize == 0) break;
     __syncthreads();
 
+    // temporary solution
     unsigned mask = 0u;
-    // CREATE_SHFL_MASK(mask, tid < len);
+    CREATE_SHFL_MASK(mask, true);
 
     if (tid == maxId[0]) {
       if (beam < maxLength) {
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index d7f4d383ce..a2e3973fe8 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/assert.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
 
 namespace paddle {
 namespace operators {
@@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
         sh_topk[tid] = topk[*beam];
       }
     }
+    // temporary solution
+    unsigned mask = 0u;
+    CREATE_SHFL_MASK(mask, true);
+
     if (maxid[0] / 32 == warp) {
-      if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break;
+      if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
     }
   }
 }
diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h
index 866ff30a8b..0f6e6159b6 100644
--- a/paddle/fluid/platform/cuda_primitives.h
+++ b/paddle/fluid/platform/cuda_primitives.h
@@ -72,6 +72,13 @@ template <typename T>
 __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
   return __shfl_down(val, delta);
 }
+
+template <typename T>
+__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
+                                         int width) {
+  return __shfl(val, src_line, width);
+}
+
 #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
 #else
 #define FULL_WARP_MASK 0xFFFFFFFF

From 9ab8faaf76057c21be368adb0e23999b3acc5028 Mon Sep 17 00:00:00 2001
From: xzl
Date: Thu, 3 May 2018 13:05:01 +0800
Subject: [PATCH 2/3] fix pool with mask layer bug

---
 paddle/math/Matrix.cpp | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 0e84cb3739..bcd6dfe1fd 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -2157,26 +2157,20 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
         int wend = wstart + sizeX;
         wstart = wstart < 0 ? 0 : wstart;
         wend = wend < (int)imgSizeW ? wend : (int)imgSizeW;
-        if (maskData == NULL) {
-          real tmp = -(real)FLT_MAX;
-          for (int h = hstart; h < hend; ++h) {
-            for (int w = wstart; w < wend; ++w) {
-              tmp = tmp < inputData[h * imgSizeW + w]
-                        ? inputData[h * imgSizeW + w]
-                        : tmp;
-            }
-          }
-          outData[ph * outputW + pw] = tmp;
-        } else {
-          for (int h = hstart; h < hend; ++h) {
-            for (int w = wstart; w < wend; ++w) {
-              if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
-                outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
-                maskData[ph * outputW + pw] = h * imgSizeW + w;
-              }
+
+        real maxval = -(real)FLT_MAX;
+        int max_index = -1;
+        for (int h = hstart; h < hend; ++h) {
+          for (int w = wstart; w < wend; ++w) {
+            if (maxval < inputData[h * imgSizeW + w]) {
+              maxval = inputData[h * imgSizeW + w];
+              max_index = h * imgSizeW + w;
             }
           }
         }
+
+        outData[ph * outputW + pw] = maxval;
+        if (maskData != NULL) maskData[ph * outputW + pw] = max_index;
       }
     }
     // compute offset

From 62fed4cbb33275d1fc4b02f1617b8b8efddd4b00 Mon Sep 17 00:00:00 2001
From: chengduo
Date: Thu, 3 May 2018 15:12:20 +0800
Subject: [PATCH 3/3] fix __shfl_down (#10362)

---
 paddle/cuda/include/hl_base.h         |  5 ++++
 paddle/fluid/operators/row_conv_op.cu | 12 +++++++--
 paddle/function/RowConvOpGpu.cu       | 35 +++++++++++++++------------
 3 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/paddle/cuda/include/hl_base.h b/paddle/cuda/include/hl_base.h
index 402302a5bf..77f5d82dbe 100644
--- a/paddle/cuda/include/hl_base.h
+++ b/paddle/cuda/include/hl_base.h
@@ -229,6 +229,11 @@ extern __thread cudaStream_t default_stream;
 
 // __shfl has been deprecated as of CUDA 9.0.
 #if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
+  return __shfl_down(val, delta);
+}
+
 template <typename T>
 __forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
                                          int width) {
diff --git a/paddle/fluid/operators/row_conv_op.cu b/paddle/fluid/operators/row_conv_op.cu
index 79d08cf3d1..082f761d37 100644
--- a/paddle/fluid/operators/row_conv_op.cu
+++ b/paddle/fluid/operators/row_conv_op.cu
@@ -189,6 +189,10 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
   }
   __syncthreads();
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
@@ -220,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
 
     for (int offset = 16; offset > 0;
          offset = offset / 2) {  // blockDim.x is 32.
-      val += platform::__shfl_down_sync(0, val, offset);
+      val += platform::__shfl_down_sync(mask, val, offset);
     }
     __syncthreads();
 
@@ -251,6 +255,10 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
   T *sh_in = mem;
   T *sh_dout = &mem[block_x * block_y];
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
@@ -276,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
 
     for (int offset = 16; offset > 0;
          offset = offset / 2) {  // blockDim.x is 32.
-      val += platform::__shfl_down_sync(0, val, offset);
+      val += platform::__shfl_down_sync(mask, val, offset);
     }
     __syncthreads();
 
diff --git a/paddle/function/RowConvOpGpu.cu b/paddle/function/RowConvOpGpu.cu
index 9d8a6d80bb..f820ee9a97 100644
--- a/paddle/function/RowConvOpGpu.cu
+++ b/paddle/function/RowConvOpGpu.cu
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "RowConvOp.h"
-#include "hl_base.h"
+#include "paddle/cuda/include/hl_base.h"
+#include "paddle/function/RowConvOp.h"
 
 namespace paddle {
 
@@ -94,7 +94,7 @@ __global__ void KeRowConv2(real* y,
 }
 
 template <>
-void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
+void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,  // NOLINT
                               const GpuMatrix& in,
                               const GpuMatrix& filter,
                               const GpuIVector& seq) {
@@ -144,6 +144,10 @@ __global__ void KeRowConvBwWeight(real* dw,
   }
   __syncthreads();
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < numSeq; ++i) {
     const int start = starts[i];
     const int end = starts[i + 1];
@@ -170,11 +174,10 @@ __global__ void KeRowConvBwWeight(real* dw,
       real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
       __syncthreads();
       // warp size and blockDim.x is 32.
-      val += __shfl_down(val, 16);
-      val += __shfl_down(val, 8);
-      val += __shfl_down(val, 4);
-      val += __shfl_down(val, 2);
-      val += __shfl_down(val, 1);
+
+      for (int offset = 16; offset > 0; offset /= 2)
+        val += __shfl_down_sync(mask, val, offset);
+
       __syncthreads();
       if (tidx == 0) {
         sh_dw[t][tidy] += val;
@@ -205,6 +208,10 @@ __global__ void KeRowConvBwWeight2(real* dw,
   __shared__ real sh_x[BLOCK_H][BLOCK_W];
   __shared__ real sh_dy[BLOCK_H][BLOCK_W];
 
+  // NOTE(zcd): temporary solution
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+
   for (int i = 0; i < numSeq; ++i) {
     const int start = starts[i];
     const int end = starts[i + 1];
@@ -230,11 +237,9 @@ __global__ void KeRowConvBwWeight2(real* dw,
       real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
      __syncthreads();
       // warp size and blockDim.x is 32.
-      val += __shfl_down(val, 16);
-      val += __shfl_down(val, 8);
-      val += __shfl_down(val, 4);
-      val += __shfl_down(val, 2);
-      val += __shfl_down(val, 1);
+      for (int offset = 16; offset > 0; offset /= 2)
+        val += __shfl_down_sync(mask, val, offset);
+
       __syncthreads();
       if (tidx == 0 && (gidx + tidy) < width) {
@@ -323,8 +328,8 @@ template <>
 void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
                                   const GpuMatrix& in,
                                   const GpuMatrix& filter,
-                                  GpuMatrix& inG,
-                                  GpuMatrix& filterG,
+                                  GpuMatrix& inG,      // NOLINT
+                                  GpuMatrix& filterG,  // NOLINT
                                   const GpuIVector& seq) {
   const size_t numSeq = seq.getSize() - 1;
   const size_t contextLength = filter.getHeight();
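
Note (not part of the patch series): the three commits above keep converging on one pattern — build a participation mask (CREATE_SHFL_MASK) and call the CUDA 9 *_sync shuffle intrinsics, falling back to the legacy __shfl/__shfl_down on CUDA < 9.0. Below is a minimal, self-contained sketch of that pattern as a warp-sum kernel. The file name and the identifiers ShflDownSync, DEMO_SHFL_MASK and WarpSum are invented for this illustration and do not appear in PaddlePaddle.

// warp_sum_demo.cu -- illustrative sketch only, compiled standalone with nvcc.
#include <cstdio>
#include <cuda.h>          // CUDA_VERSION
#include <cuda_runtime.h>

// Same idea as the shims in the patches: before CUDA 9.0 there is no
// __shfl_down_sync, so fall back to the legacy __shfl_down and ignore the mask.
#if CUDA_VERSION < 9000
__forceinline__ __device__ float ShflDownSync(unsigned, float val, int delta) {
  return __shfl_down(val, delta);
}
#define DEMO_SHFL_MASK(mask, predicate) (mask) = 0u;
#else
__forceinline__ __device__ float ShflDownSync(unsigned mask, float val,
                                              int delta) {
  return __shfl_down_sync(mask, val, delta);
}
#define DEMO_SHFL_MASK(mask, predicate) \
  (mask) = __ballot_sync(0xFFFFFFFFu, (predicate))
#endif

// One warp (32 threads) sums its elements; lane 0 writes the result.
__global__ void WarpSum(const float* in, float* out) {
  float val = in[threadIdx.x];
  unsigned mask = 0u;
  DEMO_SHFL_MASK(mask, true);  // every lane participates
  for (int offset = 16; offset > 0; offset /= 2) {
    val += ShflDownSync(mask, val, offset);
  }
  if (threadIdx.x == 0) *out = val;
}

int main() {
  float h_in[32], h_out = 0.f, *d_in = nullptr, *d_out = nullptr;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;  // expected sum: 32
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);  // should print 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Built with either a CUDA 8 or a CUDA 9+ toolchain, the kernel should produce the same result; that toolchain independence is the property the temporary CREATE_SHFL_MASK(mask, true) workaround in these patches relies on.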