@@ -54,13 +54,12 @@ class KLDivLossKernel : public framework::OpKernel<T> {
     auto input_t = EigenVector<T>::Flatten(*input);
     auto target_t = EigenVector<T>::Flatten(*target);
     auto loss_t = EigenVector<T>::Flatten(*loss);
-    // auto target_mask = (target_t > target_t.constant(0)).template cast<T>();
-    // auto output = (target_t * (target_t.log() - input_t)) * target_mask;
     auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
     if ("none" == reduction) {
       loss_t.device(place) = output;
     } else if ("batchmean" == reduction) {
-      loss_t.device(place) = output.sum() / static_cast<T>(n);
+      auto output_sum = output.sum().eval();
+      loss_t.device(place) = output_sum / output_sum.constant(n);
     } else if ("mean" == reduction) {
       loss_t.device(place) = output.mean();
     } else if ("sum" == reduction) {
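
Two things worth spelling out in this hunk. First, the batchmean branch now materializes the sum once via `.eval()` and divides by `output_sum.constant(n)` rather than a bare `static_cast<T>(n)`, presumably because an Eigen tensor expression wants a matching constant expression on the right-hand side, not a raw scalar. Second, `binaryExpr` applies an elementwise functor to each `(target, input)` pair, subsuming the masked formula kept in the two deleted comment lines. A minimal sketch of what `KLDivLossForward<T>` is assumed to look like (its real definition lives elsewhere in kldiv_loss_op.h, outside this hunk):

#include <cmath>

#ifndef HOSTDEVICE
#define HOSTDEVICE  // normally supplied by paddle/fluid/platform/hostdevice.h
#endif

// Elementwise KL-divergence term: target * (log(target) - input) where
// target > 0, and 0 elsewhere (matching the deleted target_mask comment).
template <typename T>
struct KLDivLossForward {
  HOSTDEVICE KLDivLossForward() {}

  HOSTDEVICE T operator()(const T& target, const T& input) const {
    if (target <= 0) {
      return 0;
    }
    return target * (std::log(target) - input);
  }
};
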
@@ -74,19 +73,17 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-    auto* input = ctx.Input<Tensor>("X");
     auto* target = ctx.Input<Tensor>("Target");
     auto reduction = ctx.Attr<std::string>("reduction");
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
 
-    const int n = input->dims()[0];
-    const int numel = input->numel();
+    const int n = input_grad->dims()[0];
+    const int numel = input_grad->numel();
     const int expand = numel / loss_grad->numel();
 
     input_grad->mutable_data<T>(ctx.GetPlace());
 
-    auto input_t = EigenVector<T>::Flatten(*input);
     auto target_t = EigenVector<T>::Flatten(*target);
 
     auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
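
The grad kernel no longer reads X at all, and it doesn't need to: the elementwise loss is target * (log(target) - input), so its derivative with respect to input is simply -target, independent of input itself; the shape previously taken from X is equally available on the X@GRAD output, which is where n and numel now come from. A hypothetical per-element helper (illustration only, not in the patch) summarizing what the Eigen expression in the next hunk computes:

// Hypothetical helper: gradient of target * (log(target) - input) w.r.t.
// input, masked to target > 0 and scaled by the upstream loss gradient.
template <typename T>
T KLDivLossGradElem(const T& target, const T& loss_grad) {
  const T mask = target > static_cast<T>(0) ? static_cast<T>(1)
                                            : static_cast<T>(0);
  return -target * loss_grad * mask;
}
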
@@ -96,14 +93,6 @@ class KLDivLossGradKernel : public framework::OpKernel<T> {
     auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
     input_grad_t.device(place) =
         target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;
-    // if (reduction == "none") {
-    //   input_grad_t.device(place) =
-    //       target_t * loss_grad_t * target_t.constant(-1.0);
-    // } else {
-    //   auto loss_grad_expand = loss_grad_t.broadcast(Array1(numel));
-    //   input_grad_t.device(place) =
-    //       target_t * loss_grad_expand * target_t.constant(-1.0);
-    // }
 
     if ("mean" == reduction) {
       input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
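
When reduction is anything other than "none", Loss@GRAD holds a single element, so `expand = numel / loss_grad->numel()` broadcasts it across every input element before the reduction-specific rescaling. A sketch of the scale factor each mode implies, assuming the "batchmean" branch (not shown in this hunk) divides by the batch size n just as the forward pass does:

#include <string>

// Sketch: per-element gradient scale implied by each reduction mode.
template <typename T>
T GradScale(const std::string& reduction, int n, int numel) {
  if (reduction == "mean") return static_cast<T>(1) / static_cast<T>(numel);
  if (reduction == "batchmean") return static_cast<T>(1) / static_cast<T>(n);
  return static_cast<T>(1);  // "none" and "sum": gradient passes unscaled
}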