|
|
|
@ -138,15 +138,48 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct HardLabelCrossEntropyForwardFunctor {
|
|
|
|
|
HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x,
|
|
|
|
|
const int64_t* label,
|
|
|
|
|
int64_t ignore_index,
|
|
|
|
|
int64_t feature_size)
|
|
|
|
|
: x_(x),
|
|
|
|
|
y_(y),
|
|
|
|
|
match_x_(match_x),
|
|
|
|
|
label_(label),
|
|
|
|
|
ignore_index_(ignore_index),
|
|
|
|
|
feature_size_(feature_size) {}
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(int64_t idx) const {
|
|
|
|
|
auto label = label_[idx];
|
|
|
|
|
if (label != ignore_index_) {
|
|
|
|
|
auto match_x = x_[idx * feature_size_ + label];
|
|
|
|
|
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
|
|
|
|
|
match_x_[idx] = match_x;
|
|
|
|
|
} else {
|
|
|
|
|
y_[idx] = 0;
|
|
|
|
|
match_x_[idx] = 0; // any value is ok
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T* x_;
|
|
|
|
|
T* y_;
|
|
|
|
|
T* match_x_;
|
|
|
|
|
const int64_t* label_;
|
|
|
|
|
int64_t ignore_index_;
|
|
|
|
|
int64_t feature_size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct HardLabelCrossEntropyBackwardFunctor {
|
|
|
|
|
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* y, const T* dy,
|
|
|
|
|
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x,
|
|
|
|
|
const int64_t* label,
|
|
|
|
|
int64_t ignore_index,
|
|
|
|
|
int64_t feature_size)
|
|
|
|
|
: dx_(dx),
|
|
|
|
|
y_(y),
|
|
|
|
|
dy_(dy),
|
|
|
|
|
match_x_(match_x),
|
|
|
|
|
label_(label),
|
|
|
|
|
ignore_index_(ignore_index),
|
|
|
|
|
feature_size_(feature_size) {}
|
|
|
|
@ -156,15 +189,15 @@ struct HardLabelCrossEntropyBackwardFunctor {
|
|
|
|
|
auto col_idx = idx % feature_size_;
|
|
|
|
|
auto label = label_[row_idx];
|
|
|
|
|
if (label == col_idx && label != ignore_index_) {
|
|
|
|
|
dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
|
|
|
|
|
dx_[idx] = -dy_[row_idx] / match_x_[row_idx];
|
|
|
|
|
} else {
|
|
|
|
|
dx_[idx] = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T* dx_;
|
|
|
|
|
const T* y_;
|
|
|
|
|
const T* dy_;
|
|
|
|
|
const T* match_x_;
|
|
|
|
|
const int64_t* label_;
|
|
|
|
|
int64_t ignore_index_;
|
|
|
|
|
int64_t feature_size_;
|
|
|
|
@ -174,20 +207,26 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
/**
 * Forward kernel for hard-label cross entropy that also records the matched
 * input entry for each row.
 *
 * Inputs  : "X" (values, element type T), "Label" (int64 labels).
 * Outputs : "Y" and "MatchX", both allocated on ctx.GetPlace().
 * Attr    : "ignore_index" (read as int).
 *
 * The template parameters DeviceContext and T are supplied by the template
 * header that precedes this class.
 */
class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
 public:
  /**
   * Runs HardLabelCrossEntropyForwardFunctor once per row of X through
   * platform::ForRange on the kernel's device context.
   */
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* label = ctx.Input<Tensor>("Label");
    auto* y = ctx.Output<Tensor>("Y");
    auto* match_x = ctx.Output<Tensor>("MatchX");

    // The last dimension of X is the per-row feature count; every leading
    // dimension is folded into the row count.
    // NOTE(review): assumes the last dimension is non-zero (it is used as a
    // divisor) -- presumably guaranteed by shape inference; confirm.
    const auto& x_dims = x->dims();
    auto feature_size = x_dims[x_dims.size() - 1];
    auto batch_size = framework::product(x_dims) / feature_size;

    auto* p_x = x->data<T>();
    auto* p_label = label->data<int64_t>();
    auto* p_y = y->mutable_data<T>(ctx.GetPlace());
    auto* p_match_x = match_x->mutable_data<T>(ctx.GetPlace());

    // The attribute is stored as int; it is widened to int64_t because it is
    // compared against int64 labels inside the functor.
    int64_t ignore_index = ctx.Attr<int>("ignore_index");

    platform::ForRange<DeviceContext> for_range(
        ctx.template device_context<DeviceContext>(), batch_size);
    for_range(HardLabelCrossEntropyForwardFunctor<T>(
        p_x, p_y, p_match_x, p_label, ignore_index, feature_size));
  }
};
|
|
|
|
|
|
|
|
|
@ -196,13 +235,13 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto* y = ctx.Input<Tensor>("Y");
|
|
|
|
|
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto* match_x = ctx.Input<Tensor>("MatchX");
|
|
|
|
|
auto* label = ctx.Input<Tensor>("Label");
|
|
|
|
|
|
|
|
|
|
auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto* p_y = y->data<T>();
|
|
|
|
|
auto* p_dy = dy->data<T>();
|
|
|
|
|
auto* p_match_x = match_x->data<T>();
|
|
|
|
|
auto* p_label = label->data<int64_t>();
|
|
|
|
|
|
|
|
|
|
int64_t ignore_index = ctx.Attr<int>("ignore_index");
|
|
|
|
@ -214,7 +253,7 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
|
|
|
|
|
ctx.template device_context<DeviceContext>(),
|
|
|
|
|
batch_size * feature_size);
|
|
|
|
|
for_range(HardLabelCrossEntropyBackwardFunctor<T>(
|
|
|
|
|
p_dx, p_y, p_dy, p_label, ignore_index, feature_size));
|
|
|
|
|
p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|