limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
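
// Forward kernel: delegates to math::CrossEntropyFunctor, which computes
// Y[i] = -sum_j Label[i][j] * log(X[i][j]) for soft labels and
// Y[i] = -log(X[i][Label[i]]) for hard (index) labels.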
template <typename DeviceContext, typename T>
class CrossEntropyOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* labels = ctx.Input<Tensor>("Label");
    auto* y = ctx.Output<Tensor>("Y");
    y->mutable_data<T>(ctx.GetPlace());
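
    // CrossEntropyFunctor is specialized for CPU and CUDA device contexts,
    // so the same kernel body runs on either device.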
    math::CrossEntropyFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), y, x, labels,
        ctx.Attr<bool>("soft_label"));
  }
};
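
// Gradient of cross entropy with soft labels:
//   dX[i][j] = -Label[i][j] * dY[i] / X[i][j]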
template <typename T>
class XeSoftlabelGradFunctor {
 public:
  XeSoftlabelGradFunctor(T* dx,
                         const T* dy,     // NOLINT
                         const T* x,      // NOLINT
                         const T* label,  // NOLINT
                         size_t num_classes)
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
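
  // Each call handles one element of the flattened batch x class matrix;
  // i / num_classes_ recovers the batch row that dY is broadcast from.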
  HOSTDEVICE void operator()(size_t i) {
    auto row_ids = i / num_classes_;
    dx_[i] = -label_[i] * dy_[row_ids] / x_[i];
  }

 private:
  T* dx_;
  const T* dy_;
  const T* x_;
  const T* label_;
  size_t num_classes_;
};
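
// Gradient of cross entropy with hard (index) labels: only the true-class
// entry of each row receives gradient,
//   dX[i][j] = -dY[i] / X[i][j]  if j == Label[i],  0 otherwise.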
template <typename T>
class XeGradFunctor {
 public:
  XeGradFunctor(T* dx,
                const T* dy,           // NOLINT
                const T* x,            // NOLINT
                const int64_t* label,  // NOLINT
                size_t num_classes)
: dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
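
  // Each call fills one row of dX: zero everywhere except at the offset of
  // the true class, which receives -dY[row] / X[row][true_class].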
  HOSTDEVICE void operator()(size_t label_id) {
    auto x_is_true_offset = label_id * num_classes_ + label_[label_id];
    for (size_t x_offset = label_id * num_classes_;
         x_offset < (label_id + 1) * num_classes_; ++x_offset) {
      dx_[x_offset] = x_offset != x_is_true_offset
                          ? static_cast<T>(0)
                          : -dy_[label_id] / x_[x_offset];
    }
  }

 private:
  T* dx_;
  const T* dy_;
  const T* x_;
  const int64_t* label_;
  size_t num_classes_;
};
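
// Backward kernel: dispatches on the soft_label attribute and applies the
// matching functor through platform::ForRange, which runs a serial loop on
// CPU and a parallel kernel on CUDA.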
template <typename DeviceContext, typename T>
class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* label = ctx.Input<Tensor>("Label");
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());

    int64_t class_num = x->dims()[1];
if (ctx.Attr<bool>("soft_label")) {
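      // Soft labels are real-valued, so Label is read with element type T.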
XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
                                        label->data<T>(),
                                        static_cast<size_t>(class_num));
platform::ForRange<DeviceContext> for_range(
          ctx.template device_context<DeviceContext>(),
          static_cast<size_t>(dx->numel()));
for_range(functor);
} else {
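      // Hard labels are int64 class indices; one functor call fills one
      // row of dX, so ForRange iterates dy->numel() == batch_size times.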
XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
                               label->data<int64_t>(),
                               static_cast<size_t>(class_num));
platform::ForRange<DeviceContext> for_range(
          ctx.template device_context<DeviceContext>(),
          static_cast<size_t>(dy->numel()));
for_range(functor);
}
}
};