|
|
|
@ -32,37 +32,71 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__device__ __forceinline__ T sum_single_warp(T val) {
|
|
|
|
|
val += __shfl_down(val, 16);
|
|
|
|
|
val += __shfl_down(val, 8);
|
|
|
|
|
val += __shfl_down(val, 4);
|
|
|
|
|
val += __shfl_down(val, 2);
|
|
|
|
|
val += __shfl_down(val, 1);
|
|
|
|
|
return val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This kernel is called when the class number is less than or equal to 512.
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
|
|
|
|
|
const int class_num) {
|
|
|
|
|
int tid = threadIdx.x;
|
|
|
|
|
extern __shared__ T d_sum[];
|
|
|
|
|
d_sum[tid] = 0;
|
|
|
|
|
|
|
|
|
|
int cur_idx = tid;
|
|
|
|
|
int next_idx = blockIdx.x * class_num + tid;
|
|
|
|
|
while (cur_idx < class_num) {
|
|
|
|
|
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
|
|
|
|
|
next_idx += blockDim.x;
|
|
|
|
|
cur_idx += blockDim.x;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
|
|
|
|
|
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T val = d_sum[tid];
|
|
|
|
|
val = sum_single_warp<T>(val);
|
|
|
|
|
if (tid == 0) Y[blockIdx.x] = -val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This kernel is called when the class number is larger than 512.
|
|
|
|
|
template <typename T, int BlockSize>
|
|
|
|
|
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
|
|
|
|
|
const int N, const int D) {
|
|
|
|
|
__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
|
|
|
|
|
const int class_num) {
|
|
|
|
|
int tid = threadIdx.x;
|
|
|
|
|
__shared__ T d_sum[BlockSize];
|
|
|
|
|
int next_idx = blockIdx.x * D + tid;
|
|
|
|
|
int next_idx = blockIdx.x * class_num + tid;
|
|
|
|
|
|
|
|
|
|
d_sum[tid] = 0;
|
|
|
|
|
int cur_idx = tid;
|
|
|
|
|
while (cur_idx < D) {
|
|
|
|
|
while (cur_idx < class_num) {
|
|
|
|
|
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
|
|
|
|
|
next_idx += BlockSize;
|
|
|
|
|
cur_idx += BlockSize;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
for (int stride = BlockSize >> 1; stride > 0; stride >>= 1) {
|
|
|
|
|
for (unsigned int stride = BlockSize >> 1; stride >= 32; stride >>= 1) {
|
|
|
|
|
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
if (tid < stride) {
|
|
|
|
|
next_idx = tid + stride;
|
|
|
|
|
d_sum[tid] += d_sum[next_idx];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (tid == 0) {
|
|
|
|
|
Y[blockIdx.x] = -d_sum[0];
|
|
|
|
|
}
|
|
|
|
|
T val = d_sum[tid];
|
|
|
|
|
val = sum_single_warp<T>(val);
|
|
|
|
|
if (tid == 0) Y[blockIdx.x] = -val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(qingqing): make zero setting an common function.
|
|
|
|
|
// TODO(qingqing): make zero setting a common function.
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void zero(T* X, const int N) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
@ -88,11 +122,9 @@ template <typename T>
|
|
|
|
|
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
|
|
|
|
|
const T* label, const int N,
|
|
|
|
|
const int D) {
|
|
|
|
|
int row_ids = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
int col_ids = blockIdx.y * blockDim.y + threadIdx.y;
|
|
|
|
|
int ids = row_ids * D + col_ids;
|
|
|
|
|
|
|
|
|
|
int ids = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (ids < N * D) {
|
|
|
|
|
int row_ids = ids / D;
|
|
|
|
|
dX[ids] = -label[ids] * dY[row_ids] / X[ids];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -112,20 +144,34 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* y_data = y->data<T>();
|
|
|
|
|
|
|
|
|
|
int n = x->dims()[0];
|
|
|
|
|
int d = x->dims()[1];
|
|
|
|
|
int batch_size = x->dims()[0];
|
|
|
|
|
int class_num = x->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (n + block - 1) / block;
|
|
|
|
|
// TODO(qingqing) launch kernel on specified stream
|
|
|
|
|
// base on ExecutionContext.
|
|
|
|
|
|
|
|
|
|
if (ctx.Attr<bool>("soft_label")) {
|
|
|
|
|
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
|
|
|
|
|
grid = d;
|
|
|
|
|
SoftCrossEntropyKernel<T, 512><<<grid, block>>>(y_data, x_data,
|
|
|
|
|
label_data, n, d);
|
|
|
|
|
if (class_num > 512) {
|
|
|
|
|
SoftCrossEntropyKernel2<
|
|
|
|
|
T, 512><<<batch_size, block, 0,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(y_data, x_data, label_data, class_num);
|
|
|
|
|
} else {
|
|
|
|
|
int block_size = pow(2, int(std::log2(class_num)));
|
|
|
|
|
SoftCrossEntropyKernel1<
|
|
|
|
|
T><<<batch_size, block_size, block_size * sizeof(T),
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(y_data, x_data, label_data, class_num);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
|
|
|
|
|
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
|
|
|
|
|
int grid = (batch_size + block - 1) / block;
|
|
|
|
|
CrossEntropyKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(y_data, x_data, label_data,
|
|
|
|
|
batch_size, class_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -148,25 +194,27 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
int n = x->dims()[0];
|
|
|
|
|
int d = x->dims()[1];
|
|
|
|
|
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (n * d + block - 1) / block;
|
|
|
|
|
zero<T><<<grid, block>>>(dx_data, n * d);
|
|
|
|
|
grid = (n + block - 1) / block;
|
|
|
|
|
// TODO(qingqing): launch kernel on specified stream
|
|
|
|
|
// base on ExecutionContext.
|
|
|
|
|
zero<T><<<grid, block, 0,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(dx_data, n * d);
|
|
|
|
|
if (ctx.Attr<bool>("soft_label")) {
|
|
|
|
|
int block_x = 32;
|
|
|
|
|
int block_y = 32;
|
|
|
|
|
dim3 block(block_x, block_y);
|
|
|
|
|
dim3 grid((n + block_x - 1) / block_x, (d + block_y - 1) / block_y);
|
|
|
|
|
|
|
|
|
|
auto* label_data = label->data<T>();
|
|
|
|
|
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
|
|
|
|
|
dx_data, dy_data, x_data, label_data, n, d);
|
|
|
|
|
SoftCrossEntropyGradientKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(dx_data, dy_data, x_data, label_data,
|
|
|
|
|
n, d);
|
|
|
|
|
} else {
|
|
|
|
|
auto* label_data = label->data<int>();
|
|
|
|
|
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
|
|
|
|
|
label_data, n, d);
|
|
|
|
|
CrossEntropyGradientKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(dx_data, dy_data, x_data, label_data,
|
|
|
|
|
n, d);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|