Fix __shfl_down_sync_ of cross_entropy (#10345)

* fix __shfl_down_sync_ of cross_entropy

* use reduceSum

* "fix ci"
trainerSaveLoadParams
chengduo 7 years ago committed by dzhwinter
parent 6d5e582d20
commit 4fbde42cdf

@ -22,6 +22,7 @@ limitations under the License. */
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif
@ -336,43 +337,6 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
#ifdef __NVCC__
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const int warpSize = 32;
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
}
return val;
}
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int h, int w,
@ -395,7 +359,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = reduceSum(val, tid, h);
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
@ -472,7 +436,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = reduceSum(val, tid, h);
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
@ -30,66 +31,22 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
}
}
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
val += platform::__shfl_down_sync(0, val, 16);
val += platform::__shfl_down_sync(0, val, 8);
val += platform::__shfl_down_sync(0, val, 4);
val += platform::__shfl_down_sync(0, val, 2);
val += platform::__shfl_down_sync(0, val, 1);
return val;
}
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ T* GetPointer() { return NULL; }
};
template <>
struct SharedMemory<float> {
__device__ float* GetPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ double* GetPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
SharedMemory<T> d_sum_shared;
T* d_sum = d_sum_shared.GetPointer();
d_sum[tid] = 0;
T val = 0;
int cur_idx = tid;
int next_idx = blockIdx.x * class_num + tid;
while (cur_idx < class_num) {
d_sum[tid] +=
math::TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += blockDim.x;
cur_idx += blockDim.x;
int idx = blockIdx.x * class_num + tid;
int end = blockIdx.x * class_num + class_num;
for (; idx < end; idx += blockDim.x) {
val += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
}
__syncthreads();
for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
val = paddle::platform::reduceSum(val, tid, blockDim.x);
if (threadIdx.x == 0) {
Y[blockIdx.x] = -val;
}
T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
} // namespace
@ -113,9 +70,7 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
? 512
: pow(2, static_cast<int>(std::log2(class_num)));
SoftCrossEntropyKernel<T><<<
batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
loss_data, prob_data, label_data, class_num);
} else {
const int64_t* label_data = labels->data<int64_t>();

@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {

@ -0,0 +1,74 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace paddle {
namespace platform {
// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
int width) {
return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const int warpSize = 32;
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
}
return val;
}
} // namespace platform
} // namespace paddle

@ -66,18 +66,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
}
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
} // namespace platform
} // namespace paddle

Loading…
Cancel
Save