|
|
|
@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
|
|
|
|
|
Y[i] = -tolerable_value(log(X[i * D + label[i]]));
|
|
|
|
|
Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__device__ __forceinline__ T sum_single_warp(T val) {
|
|
|
|
|
val += __shfl_down(val, 16);
|
|
|
|
|
val += __shfl_down(val, 8);
|
|
|
|
|
val += __shfl_down(val, 4);
|
|
|
|
|
val += __shfl_down(val, 2);
|
|
|
|
|
val += __shfl_down(val, 1);
|
|
|
|
|
return val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
|
|
|
|
|
const int N, const int D) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
T sum = static_cast<T>(0);
|
|
|
|
|
for (int j = 0; j < D; j++) {
|
|
|
|
|
sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
|
|
|
|
|
}
|
|
|
|
|
Y[i] = -sum;
|
|
|
|
|
const int class_num) {
|
|
|
|
|
int tid = threadIdx.x;
|
|
|
|
|
extern __shared__ T d_sum[];
|
|
|
|
|
d_sum[tid] = 0;
|
|
|
|
|
|
|
|
|
|
int cur_idx = tid;
|
|
|
|
|
int next_idx = blockIdx.x * class_num + tid;
|
|
|
|
|
while (cur_idx < class_num) {
|
|
|
|
|
d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
|
|
|
|
|
next_idx += blockDim.x;
|
|
|
|
|
cur_idx += blockDim.x;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
|
|
|
|
|
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T val = d_sum[tid];
|
|
|
|
|
val = sum_single_warp<T>(val);
|
|
|
|
|
if (tid == 0) Y[blockIdx.x] = -val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(qingqing): make zero setting an common function.
|
|
|
|
|
// TODO(qingqing): make zero setting a common function.
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void zero(T* X, const int N) {
|
|
|
|
|
__global__ void Zero(T* X, const int N) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
X[i] = 0.0;
|
|
|
|
@ -71,13 +94,10 @@ template <typename T>
|
|
|
|
|
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
|
|
|
|
|
const T* label, const int N,
|
|
|
|
|
const int D) {
|
|
|
|
|
// TOOD(qingqing): optimize for this kernel
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
for (int j = 0; j < D; ++j) {
|
|
|
|
|
int idx = i * D + j;
|
|
|
|
|
dX[idx] = -label[idx] * dY[i] / X[idx];
|
|
|
|
|
}
|
|
|
|
|
int ids = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
|
if (ids < N * D) {
|
|
|
|
|
int row_ids = ids / D;
|
|
|
|
|
dX[ids] = -label[ids] * dY[row_ids] / X[ids];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use GPUPlace.");
|
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
|
|
|
|
|
|
auto x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto y = ctx.Output<Tensor>("Y");
|
|
|
|
|
auto label = ctx.Input<Tensor>("Label");
|
|
|
|
|
const Tensor* x = ctx.Input<Tensor>("X");
|
|
|
|
|
const Tensor* label = ctx.Input<Tensor>("Label");
|
|
|
|
|
Tensor* y = ctx.Output<Tensor>("Y");
|
|
|
|
|
|
|
|
|
|
auto* x_data = x->data<T>();
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* y_data = y->data<T>();
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
T* y_data = y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
int n = x->dims()[0];
|
|
|
|
|
int d = x->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (n + block - 1) / block;
|
|
|
|
|
// TODO(qingqing) launch kernel on specified stream
|
|
|
|
|
// base on ExecutionContext.
|
|
|
|
|
if (ctx.Attr<bool>("soft_label")) {
|
|
|
|
|
int batch_size = x->dims()[0];
|
|
|
|
|
int class_num = x->dims()[1];
|
|
|
|
|
|
|
|
|
|
if (ctx.Attr<bool>("softLabel")) {
|
|
|
|
|
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
|
|
|
|
|
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
|
|
|
|
|
d);
|
|
|
|
|
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
|
|
|
|
|
|
|
|
|
|
SoftCrossEntropyKernel<
|
|
|
|
|
T><<<batch_size, block, block * sizeof(T),
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(y_data, x_data, label_data, class_num);
|
|
|
|
|
} else {
|
|
|
|
|
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
|
|
|
|
|
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (batch_size + block - 1) / block;
|
|
|
|
|
CrossEntropyKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(y_data, x_data, label_data,
|
|
|
|
|
batch_size, class_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use GPUPlace.");
|
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
|
|
|
|
|
|
const Tensor* x = ctx.Input<Tensor>("X");
|
|
|
|
|
const Tensor* label = ctx.Input<Tensor>("Label");
|
|
|
|
|
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
|
|
|
|
|
auto x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto label = ctx.Input<Tensor>("Label");
|
|
|
|
|
const T* dy_data =
|
|
|
|
|
ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
|
|
|
|
|
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
|
|
|
|
|
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* dy_data = dy->data<T>();
|
|
|
|
|
auto* x_data = x->data<T>();
|
|
|
|
|
int batch_size = x->dims()[0];
|
|
|
|
|
int class_num = x->dims()[1];
|
|
|
|
|
|
|
|
|
|
int n = x->dims()[0];
|
|
|
|
|
int d = x->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (n * d + block - 1) / block;
|
|
|
|
|
zero<T><<<grid, block>>>(dx_data, n * d);
|
|
|
|
|
grid = (n + block - 1) / block;
|
|
|
|
|
// TODO(qingqing): launch kernel on specified stream
|
|
|
|
|
// base on ExecutionContext.
|
|
|
|
|
if (ctx.Attr<bool>("soft_label")) {
|
|
|
|
|
int grid = (batch_size * class_num + block - 1) / block;
|
|
|
|
|
|
|
|
|
|
if (ctx.Attr<bool>("softLabel")) {
|
|
|
|
|
auto* label_data = label->data<T>();
|
|
|
|
|
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
|
|
|
|
|
dx_data, dy_data, x_data, label_data, n, d);
|
|
|
|
|
SoftCrossEntropyGradientKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(dx_data, dy_data, x_data, label_data,
|
|
|
|
|
batch_size, class_num);
|
|
|
|
|
} else {
|
|
|
|
|
Zero<T><<<grid, block, 0,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(dx_data, batch_size * class_num);
|
|
|
|
|
|
|
|
|
|
auto* label_data = label->data<int>();
|
|
|
|
|
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
|
|
|
|
|
label_data, n, d);
|
|
|
|
|
grid = (batch_size + block - 1) / block;
|
|
|
|
|
CrossEntropyGradientKernel<T><<<
|
|
|
|
|
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(dx_data, dy_data, x_data, label_data,
|
|
|
|
|
batch_size, class_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|