Merge pull request #4144 from lcy-seso/softmax_with_cross_entropy_op
Softmax with cross entropy op.
update-doc-pybind
commit 29cb85634c
@@ -0,0 +1,59 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/cross_entropy.h"

namespace paddle {
namespace operators {
namespace math {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
class CrossEntropyFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const framework::ExecutionContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, const bool softLabel) {
    const int batch_size = prob->dims()[0];
    if (softLabel) {
      auto in = EigenMatrix<T>::From(*prob);
      auto lbl = EigenMatrix<T>::From(*labels);
      auto loss = EigenMatrix<T>::From(*out);

      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
                .sum(Eigen::DSizes<int, 1>(1))
                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
    } else {
      const int class_num = prob->dims()[1];
      const T* prob_data = prob->data<T>();
      T* loss_data = out->data<T>();

      const int* label_data = labels->data<int>();
      for (int i = 0; i < batch_size; ++i) {
        int index = i * class_num + label_data[i];
        loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
      }
    }
  }
};

template class CrossEntropyFunctor<platform::CPUPlace, float>;
}  // namespace math
}  // namespace operators
}  // namespace paddle
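
A minimal NumPy sketch (not part of this PR; hypothetical helper name) of what the CPU functor above computes from a probability matrix: with hard labels the per-row loss is -log(prob[i, label[i]]), and with soft labels it is -sum(label * log(prob)) per row.

import numpy as np

def cross_entropy_reference(prob, label, soft_label=False):
    # prob: [N, K] row-wise probabilities (e.g. a softmax output).
    # label: [N, K] float distribution if soft_label, else [N, 1] int class ids.
    if soft_label:
        return -(label * np.log(prob)).sum(axis=1, keepdims=True)
    rows = np.arange(prob.shape[0])
    return -np.log(prob[rows, label.reshape(-1)]).reshape(-1, 1)
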
@@ -0,0 +1,111 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/cross_entropy.h"

namespace paddle {
namespace operators {
namespace math {

namespace {
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                   const int N, const int D) {
  // TODO(qingqing): define CUDA_1D_KERNEL_LOOP macro in a common file.
  // CUDA_1D_KERNEL_LOOP(i, N) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
    Y[i] = -math::TolerableValue<T>()(log(X[i * D + label[i]]));
  }
}

template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
  val += __shfl_down(val, 16);
  val += __shfl_down(val, 8);
  val += __shfl_down(val, 4);
  val += __shfl_down(val, 2);
  val += __shfl_down(val, 1);
  return val;
}

template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                       const int class_num) {
  int tid = threadIdx.x;
  extern __shared__ T d_sum[];
  d_sum[tid] = 0;

  int cur_idx = tid;
  int next_idx = blockIdx.x * class_num + tid;
  while (cur_idx < class_num) {
    d_sum[tid] +=
        math::TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
    next_idx += blockDim.x;
    cur_idx += blockDim.x;
  }
  __syncthreads();

  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
    __syncthreads();
  }

  T val = d_sum[tid];
  val = sum_single_warp<T>(val);
  if (tid == 0) Y[blockIdx.x] = -val;
}
}  // namespace

using Tensor = framework::Tensor;

template <typename T>
class CrossEntropyFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const framework::ExecutionContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, bool softLabel) {
    const T* prob_data = prob->data<T>();
    T* loss_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = prob->dims()[0];
    int class_num = prob->dims()[1];

    if (softLabel) {
      const T* label_data = labels->data<T>();
      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));

      SoftCrossEntropyKernel<
          T><<<batch_size, block, block * sizeof(T),
               reinterpret_cast<const platform::CUDADeviceContext&>(
                   ctx.device_context())
                   .stream()>>>(loss_data, prob_data, label_data, class_num);
    } else {
      const int* label_data = labels->data<int>();
      int block = 512;
      int grid = (batch_size + block - 1) / block;
      CrossEntropyKernel<T><<<
          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
                              ctx.device_context())
                              .stream()>>>(loss_data, prob_data, label_data,
                                           batch_size, class_num);
    }
  }
};

template class CrossEntropyFunctor<platform::GPUPlace, float>;
}  // namespace math
}  // namespace operators
}  // namespace paddle
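
The soft-label path above launches one block per sample and picks a power-of-two block size capped at 512, so the shared-memory tree reduction (down to 32 threads) followed by the warp-shuffle sum lines up. A small sketch of that launch-size choice as read from the code (an assumed interpretation, not an API of this PR):

import math

def soft_label_block_size(class_num, max_threads=512):
    # Mirrors: class_num > 512 ? 512 : pow(2, int(std::log2(class_num)))
    return min(max_threads, 2 ** int(math.log2(class_num)))

# e.g. class_num = 37 -> block = 32 threads, shared memory = block * sizeof(T),
# grid = batch_size blocks (one block per sample).
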
@@ -0,0 +1,48 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
struct TolerableValue {
  HOSTDEVICE T operator()(const T& x) const {
    PADDLE_ASSERT(std::is_floating_point<T>::value);
    const T kApproInf = 1e20;

    if (x == INFINITY) return kApproInf;
    if (x == -INFINITY) return -kApproInf;
    return x;
  }
};

template <typename Place, typename T>
class CrossEntropyFunctor {
 public:
  // TODO(caoying): it would be much better to use DeviceContext as the first
  // parameter.
  void operator()(const framework::ExecutionContext& context,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, const bool softLabel);
};
}  // namespace math
}  // namespace operators
}  // namespace paddle
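
TolerableValue only matters where log() produces an infinity (for example log(0)); it swaps +/-inf for a large finite surrogate so downstream arithmetic stays finite. A NumPy sketch of the same clamping, assuming the 1e20 surrogate used above (hypothetical helper, not part of the PR):

import numpy as np

def tolerable_value(x, appro_inf=1e20):
    # Replace +/-inf with a large finite value, mirroring TolerableValue<T>.
    x = np.asarray(x, dtype=np.float64)
    return np.where(np.isposinf(x), appro_inf,
                    np.where(np.isneginf(x), -appro_inf, x))

# tolerable_value(np.log(1e-30)) is left unchanged; tolerable_value(-np.inf) -> -1e20.
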
@@ -0,0 +1,25 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/softmax.h"

namespace paddle {
namespace operators {
namespace math {

template class SoftmaxFunctor<platform::CPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,27 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/math/softmax.h"

namespace paddle {
namespace operators {
namespace math {

template class SoftmaxFunctor<platform::GPUPlace, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
struct ValueClip {
  HOSTDEVICE T operator()(const T& x) const {
    const T kThreshold = -64.;
    return x < kThreshold ? kThreshold : x;
  }
};

template <typename Place, typename T>
class SoftmaxFunctor {
 public:
  void operator()(const framework::ExecutionContext& context,
                  const framework::Tensor* X, framework::Tensor* Y) {
    auto logits = EigenMatrix<T>::From(*X);
    auto softmax = EigenMatrix<T>::From(*Y);

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);

    Eigen::DSizes<int, 1> along_class(kClassDim);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);

    auto shifted_logits = (logits -
                           logits.maximum(along_class)
                               .eval()
                               .reshape(batch_by_one)
                               .broadcast(one_by_class))
                              .unaryExpr(ValueClip<T>());

    softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
    softmax.device(context.GetEigenDevice<Place>()) =
        (softmax *
         softmax.sum(along_class)
             .inverse()
             .eval()
             .reshape(batch_by_one)
             .broadcast(one_by_class));
  }
};
}  // namespace math
}  // namespace operators
}  // namespace paddle
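
A NumPy sketch of the same numerically stable softmax as the Eigen expression above (hypothetical helper; assumes a row-major [N, K] input and the -64 threshold from ValueClip): shift each row by its maximum, clip the shifted logits so exp() cannot underflow a whole row to zero, then exponentiate and normalize.

import numpy as np

def stable_softmax_2d(logits, threshold=-64.0):
    # Mirrors SoftmaxFunctor: shift by the row max, clip (ValueClip), exp, normalize.
    shifted = logits - logits.max(axis=1, keepdims=True)
    shifted = np.maximum(shifted, threshold)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)
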
@@ -0,0 +1,169 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/softmax_with_cross_entropy_op.h"

namespace paddle {
namespace operators {

class SoftmaxWithCrossEntropyOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
                                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Logits",
             "(Tensor, default: Tensor<float>), The unscaled log probabilities, "
             "a 2-D tensor with shape [N x K]. N is the batch_size, "
             "and K is the class number.")
        .NotInGradient();
    AddInput(
        "Label",
        "(Tensor, default: Tensor<int>), The ground truth, which is a 2-D "
        "tensor. "
        "If softLabel is set to false, Label is a Tensor<int> with shape "
        "[N x 1]. If softLabel is set to true, Label is a Tensor<float/double> "
        "with shape [N x K].");
    AddOutput(
        "Softmax",
        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
        "The output values of the softmax activation for the given input "
        "batch, which will be used in the backward calculation.")
        .AsIntermediate();
    AddOutput("Loss",
              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
              "entropy loss with shape [N x 1].");
    AddAttr<bool>(
        "softLabel",
        "(bool, default: false), A flag to indicate whether to interpret "
        "the given labels as soft labels.")
        .SetDefault(false);
    AddComment(R"DOC(
Cross entropy loss with softmax is used extensively as the output layer. This
operator computes the softmax normalized values for each row of the input
tensor, after which the cross-entropy loss is computed. This provides a more
numerically stable gradient.

Because this operator performs a softmax on the logits internally, it expects
unscaled logits. Do not call this op with the output of the softmax operator,
as that will produce incorrect results.

When softLabel is false, this operator expects mutually exclusive hard labels:
each sample in the batch belongs to exactly one class with probability 1, i.e.,
each sample has one and only one label.

Equation:

1) hard label (one-hot label)

Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., N

2) soft label (a distribution over all classes)

Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i - \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1, ..., N

)DOC");
  }
};

class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
                            "Input(Logits) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                            "Input(Label) should not be null.");

    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
                            "Output(Softmax) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
                            "Output(Loss) should not be null.");

    const Tensor* logits = ctx.Input<Tensor>("Logits");
    const Tensor* labels = ctx.Input<Tensor>("Label");
    PADDLE_ENFORCE_EQ(
        logits->dims().size(), 2UL,
        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
                      "The labels should be a 2-D tensor.");

    if (ctx.Attr<bool>("softLabel")) {
      PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
                        "If Attr(softLabel) == true, the 2nd dimension of "
                        "Input(X) and Input(Label) should be equal.");
    } else {
      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
                        "If Attr(softLabel) == false, the 2nd dimension of "
                        "Input(Label) should be 1.");
    }

    ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
    ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});

    ctx.ShareLoD("Logits", /*->*/ "Softmax");
    ctx.ShareLoD("Logits", /*->*/ "Loss");
  }
};

class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext& ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
                            "Input(Loss@Grad) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
                            "Input(Softmax) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                            "Input(Label) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
                            "Output(Logits@Grad) should not be null.");

    const Tensor* softmax = ctx.Input<Tensor>("Softmax");
    const Tensor* labels = ctx.Input<Tensor>("Label");
    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
                      "The labels should be a 2-D tensor.");

    if (ctx.Attr<bool>("softLabel")) {
      PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
                        "When Attr(softLabel) == true, the 2nd dimension of "
                        "Input(X) and Input(Label) should be equal.");
    } else {
      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
                        "When Attr(softLabel) == false, the 2nd dimension of "
                        "Input(Label) should be 1.");
    }

    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
            ops::SoftmaxWithCrossEntropyOpMaker,
            softmax_with_cross_entropy_grad,
            ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyKernel<float>);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradKernel<float>);
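
The two equations in the DOC block can be spot-checked with a small NumPy computation over raw logits, using a max-shifted log-sum-exp for the normalizer; the helper below is a sketch for verification only, not part of the operator's API.

import numpy as np

def softmax_xent_reference(logits, label, soft_label=False):
    # logits: [N, K]; label: [N, 1] int class ids (hard) or [N, K] floats (soft).
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_z = np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_softmax = shifted - log_z   # Logit_i - log(sum_i exp(Logit_i)), shift-invariant
    if soft_label:
        return -(label * log_softmax).sum(axis=1, keepdims=True)
    rows = np.arange(logits.shape[0])
    return -log_softmax[rows, label.reshape(-1)].reshape(-1, 1)
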
@@ -0,0 +1,119 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/softmax_with_cross_entropy_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

namespace {
template <typename T>
__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
                                 const int* labels, const int batch_size,
                                 const int class_num) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int sample_idx = tid / class_num;

  if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx];
  __syncthreads();

  if (tid < batch_size) {
    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
    out_grad[tid * class_num + labels[tid]] -= 1.;
  }
}

template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels,
                                               const int batch_size,
                                               const int class_num) {
  int ids = blockIdx.x * blockDim.x + threadIdx.x;
  if (ids < batch_size * class_num) {
    int row_ids = ids / class_num;
    logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
  }
}
}  // namespace

template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");

    Tensor* loss = context.Output<Tensor>("Loss");
    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

    math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
    math::CrossEntropyFunctor<platform::GPUPlace, T>()(
        context, loss, softmax, labels, context.Attr<bool>("softLabel"));
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    const Tensor* labels = context.Input<Tensor>("Label");
    const T* loss_grad_data =
        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
    T* logit_grad_data = logit_grad->data<T>();

    const int batch_size = logit_grad->dims()[0];
    const int class_num = logit_grad->dims()[1];
    int block = 512;
    int grid = (batch_size * class_num + block - 1) / block;

    if (context.Attr<bool>("softLabel")) {
      const T* label_data = labels->data<T>();
      SoftCrossEntropyGradientKernel<T><<<
          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
                              context.device_context())
                              .stream()>>>(logit_grad_data, loss_grad_data,
                                           label_data, batch_size, class_num);
    } else {
      const int* label_data = labels->data<int>();
      CrossEntropyGrad<T><<<
          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
                              context.device_context())
                              .stream()>>>(logit_grad_data, loss_grad_data,
                                           label_data, batch_size, class_num);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy,
                       ops::SoftmaxWithCrossEntropyCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad,
                       ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>);
@@ -0,0 +1,86 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cross_entropy.h"
#include "paddle/operators/math/softmax.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
                   "This kernel only runs on CPU.");
    const Tensor* logits = context.Input<Tensor>("Logits");
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    softmax->mutable_data<T>(context.GetPlace());
    loss->mutable_data<T>(context.GetPlace());

    math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
    math::CrossEntropyFunctor<platform::CPUPlace, T>()(
        context, loss, softmax, labels, context.Attr<bool>("softLabel"));
  }
};

template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    const Tensor* labels = context.Input<Tensor>("Label");
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));

    const int class_num = logit_grad->dims()[1];
    if (context.Attr<bool>("softLabel")) {
      auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
      auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
      auto lbl_mat = EigenMatrix<T>::From(*labels);

      logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
          logit_grad_mat *
              out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) -
          lbl_mat;
    } else {
      const int batch_size = logit_grad->dims()[0];
      const int* label_data = labels->data<int>();
      const T* out_grad_data = out_grad->data<T>();
      T* logit_grad_data = logit_grad->data<T>();

      for (int i = 0; i < batch_size; ++i) {
        int index = i * class_num + label_data[i];
        logit_grad_data[index] =
            (out_grad_data[i] * logit_grad_data[index] - 1.);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
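
Mathematically, the gradient of the loss with respect to the logits is the saved Softmax output scaled by the incoming loss gradient, minus the label (a one-hot vector for hard labels, the given distribution for soft labels). The sketch below states that relation in NumPy as a reference, not a line-for-line transcription of the kernels above; it assumes dLoss has shape [N, 1] and the helper name is hypothetical.

import numpy as np

def softmax_xent_grad_reference(softmax, label, d_loss, soft_label=False):
    # softmax: [N, K] forward output saved by the op; d_loss: [N, 1].
    grad = softmax * d_loss                 # broadcast over the class dimension
    if soft_label:
        return grad - label                 # label: [N, K] distribution
    rows = np.arange(softmax.shape[0])
    grad[rows, label.reshape(-1)] -= 1.0    # subtract the one-hot at the true class
    return grad
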
@@ -0,0 +1,70 @@
import unittest
import numpy as np

from op_test import OpTest
from test_softmax_op import stable_softmax


class TestSoftmaxWithCrossEntropyOp(OpTest):
    """
    Test softmax with cross entropy operator with discrete one-hot labels.
    """

    def setUp(self):
        self.op_type = "softmax_with_cross_entropy"
        batch_size = 3
        class_num = 37

        logits = np.random.uniform(0.1, 1.0,
                                   [batch_size, class_num]).astype("float32")
        softmax = np.apply_along_axis(stable_softmax, 1, logits)
        labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int32")

        cross_entropy = np.asmatrix(
            [[-np.log(softmax[i][labels[i][0]])]
             for i in range(softmax.shape[0])],
            dtype="float32")

        self.inputs = {"Logits": logits, "Label": labels}
        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["Logits"], "Loss", max_relative_error=0.05)


class TestSoftmaxWithCrossEntropyOp2(OpTest):
    """
    Test softmax with cross entropy operator with soft labels.
    """

    def setUp(self):
        self.op_type = "softmax_with_cross_entropy"
        batch_size = 2
        class_num = 17

        logits = np.random.uniform(0.1, 1.0,
                                   [batch_size, class_num]).astype("float32")
        softmax = np.apply_along_axis(stable_softmax, 1, logits)
        labels = np.random.uniform(0.1, 1.0,
                                   [batch_size, class_num]).astype("float32")
        labels /= np.sum(labels, axis=1, keepdims=True)

        cross_entropy = (-labels * np.log(softmax)).sum(
            axis=1, keepdims=True).astype("float32")

        self.inputs = {"Logits": logits, "Label": labels}
        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
        self.attrs = {"softLabel": True}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["Logits"], "Loss", max_relative_error=0.05)


if __name__ == "__main__":
    unittest.main()