|
|
|
@ -21,17 +21,16 @@ namespace operators {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__host__ __device__ T clipping_log(const T x) {
|
|
|
|
|
__host__ __device__ T tolerable_value(const T x) {
|
|
|
|
|
PADDLE_ASSERT(std::is_floating_point<T>::value);
|
|
|
|
|
const T kApproInf = 1e20;
|
|
|
|
|
T v = log(x);
|
|
|
|
|
if (v == INFINITY) {
|
|
|
|
|
if (x == INFINITY) {
|
|
|
|
|
return kApproInf;
|
|
|
|
|
}
|
|
|
|
|
if (v == -INFINITY) {
|
|
|
|
|
if (x == -INFINITY) {
|
|
|
|
|
return -kApproInf;
|
|
|
|
|
}
|
|
|
|
|
return v;
|
|
|
|
|
return x;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -42,7 +41,20 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
|
|
|
|
|
Y[i] = -clipping_log(X[i * D + label[i]]);
|
|
|
|
|
Y[i] = -tolerable_value(log(X[i * D + label[i]]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
|
|
|
|
|
const int N, const int D) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
T sum = static_cast<T>(0);
|
|
|
|
|
for (int j = 0; j < D; j++) {
|
|
|
|
|
sum += label[i * D + j] * log(X[i * D + j]);
|
|
|
|
|
}
|
|
|
|
|
Y[i] = -tolerable_value(sum);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -69,57 +81,89 @@ __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
|
|
|
|
|
const T* label, const int N,
|
|
|
|
|
const int D) {
|
|
|
|
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
|
|
|
|
i += blockDim.x * gridDim.x) {
|
|
|
|
|
for (int j = 0; j < D; ++j) {
|
|
|
|
|
int idx = i * D + j;
|
|
|
|
|
dX[idx] = -label[idx] * dY[i] / X[idx];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use GPUPlace.");
|
|
|
|
|
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
const T* Xdata = X->data<T>();
|
|
|
|
|
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
|
|
|
|
|
auto Y = ctx.Output<Tensor>("Y");
|
|
|
|
|
Y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T* Ydata = Y->data<T>();
|
|
|
|
|
auto x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto y = ctx.Output<Tensor>("Y");
|
|
|
|
|
auto label = ctx.Input<Tensor>("Label");
|
|
|
|
|
|
|
|
|
|
auto* x_data = x->data<T>();
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* y_data = y->data<T>();
|
|
|
|
|
|
|
|
|
|
int N = X->dims()[0];
|
|
|
|
|
int D = X->dims()[1];
|
|
|
|
|
int n = x->dims()[0];
|
|
|
|
|
int d = x->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (N + block - 1) / block;
|
|
|
|
|
int grid = (n + block - 1) / block;
|
|
|
|
|
// TODO(qingqing) launch kernel on specified stream
|
|
|
|
|
// base on ExecutionContext.
|
|
|
|
|
CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
|
|
|
|
|
int label_rank = label->dims().size();
|
|
|
|
|
if (label_rank == 2) {
|
|
|
|
|
// soft cross entropy
|
|
|
|
|
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
|
|
|
|
|
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
|
|
|
|
|
d);
|
|
|
|
|
} else {
|
|
|
|
|
// normal cross entropy
|
|
|
|
|
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
|
|
|
|
|
CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use GPUPlace.");
|
|
|
|
|
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto label = ctx.Input<Tensor>("label");
|
|
|
|
|
auto x = ctx.Input<Tensor>("X");
|
|
|
|
|
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto label = ctx.Input<Tensor>("Label");
|
|
|
|
|
|
|
|
|
|
auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* dYdata = dY->template data<T>();
|
|
|
|
|
auto* Xdata = X->template data<T>();
|
|
|
|
|
auto* label_data = label->data<int>();
|
|
|
|
|
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* dy_data = dy->data<T>();
|
|
|
|
|
auto* x_data = x->data<T>();
|
|
|
|
|
|
|
|
|
|
int N = X->dims()[0];
|
|
|
|
|
int D = X->dims()[1];
|
|
|
|
|
int n = x->dims()[0];
|
|
|
|
|
int d = x->dims()[1];
|
|
|
|
|
int block = 512;
|
|
|
|
|
int grid = (N * D + block - 1) / block;
|
|
|
|
|
zero<T><<<grid, block>>>(dXdata, N * D);
|
|
|
|
|
|
|
|
|
|
grid = (N + block - 1) / block;
|
|
|
|
|
int grid = (n * d + block - 1) / block;
|
|
|
|
|
zero<T><<<grid, block>>>(dx_data, n * d);
|
|
|
|
|
grid = (n + block - 1) / block;
|
|
|
|
|
// TODO(qingqing): launch kernel on specified stream
|
|
|
|
|
// base on ExecutionContext.
|
|
|
|
|
CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
|
|
|
|
|
label_data, N, D);
|
|
|
|
|
int label_rank = label->dims().size();
|
|
|
|
|
if (label_rank == 2) {
|
|
|
|
|
// soft cross entropy
|
|
|
|
|
auto* label_data = label->data<T>();
|
|
|
|
|
SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
|
|
|
|
|
dx_data, dy_data, x_data, label_data, n, d);
|
|
|
|
|
} else {
|
|
|
|
|
// normal cross entropy
|
|
|
|
|
auto* label_data = label->data<int>();
|
|
|
|
|
CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
|
|
|
|
|
label_data, n, d);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -127,7 +171,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
|
|
|
|
|
ops::OnehotCrossEntropyOpCUDAKernel<float>);
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
|
|
|
|
|
ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>);
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
|
|
|
|
|
ops::CrossEntropyGradientOpCUDAKernel<float>);
|
|
|
|
|