parent 6375fe45d7
commit 3e3a983a69
paddle/fluid/operators/kldiv_loss_op.cc
@@ -0,0 +1,150 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/kldiv_loss_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class KLDivLossOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of KLDivLossOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Target"),
                   "Input(Target) of KLDivLossOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of KLDivLossOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");
    auto dim_target = ctx->GetInputDim("Target");
    PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
                      "Input(X) rank and Input(Target) rank should be the same.");
    for (int i = 0; i < dim_x.size(); i++) {
      PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
                        "Input(X) and Input(Target) should have the same shape.");
    }

    auto reduction = ctx->Attrs().Get<std::string>("reduction");

    PADDLE_ENFORCE(
        "mean" == reduction || "sum" == reduction || "batchmean" == reduction ||
            "none" == reduction,
        "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'.");

    if ("none" == reduction) {
      ctx->SetOutputDim("Loss", dim_x);
    } else {
      ctx->SetOutputDim("Loss", framework::make_ddim({1}));
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};
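
The shape contract enforced by InferShape above is simple: Loss matches X for reduction 'none' and collapses to [1] otherwise. A minimal Python sketch of that contract (illustrative names, not part of the op):

def kldiv_loss_shape(x_shape, reduction):
    # mirrors KLDivLossOp::InferShape: 'none' keeps the input shape,
    # every other reduction produces a single-element Loss
    assert reduction in ('none', 'batchmean', 'mean', 'sum')
    return list(x_shape) if reduction == 'none' else [1]

assert kldiv_loss_shape((2, 3, 5, 5), 'none') == [2, 3, 5, 5]
assert kldiv_loss_shape((2, 3, 5, 5), 'batchmean') == [1]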

class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of KL divergence loss operator. "
             "This is a tensor with shape of [N, *], where N is the "
             "batch size and * means any number of additional dimensions.");
    AddInput("Target",
             "The target tensor of KL divergence loss operator. "
             "This is a tensor with the same shape as Input(X).");
    AddOutput(
        "Loss",
        "The output KL divergence loss tensor. If Attr(reduction) is "
        "'none', this tensor has the same shape as Input(X); otherwise "
        "it is a tensor with shape [1].");

    AddAttr<std::string>(
        "reduction",
        "The reduction type to apply to the output, available types "
        "are 'none' | 'batchmean' | 'mean' | 'sum'. 'none' for no "
        "reduction, 'batchmean' for the sum of output divided by "
        "batch size, 'mean' for the average value of all output, "
        "'sum' for the sum of the output.")
        .SetDefault("mean");

    AddComment(R"DOC(
This operator calculates the Kullback-Leibler divergence loss
between Input(X) and Input(Target). Note that Input(X) holds
log-probabilities and Input(Target) holds probabilities.

The loss is computed element-wise as

$$l(x, y) = y * (\log y - x)$$

and then reduced according to Attr(reduction).
)DOC");
  }
};

class KLDivLossOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Target"),
                   "Input(Target) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
                   "Input(Loss@GRAD) should not be null.");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

class KLDivLossOpGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType("kldiv_loss_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("Target", Input("Target"));
    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));

    op->SetAttrMap(Attrs());

    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    return std::unique_ptr<framework::OpDesc>(op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
                  ops::KLDivLossOpGradMaker);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad);
REGISTER_OP_CPU_KERNEL(
    kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
    ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    kldiv_loss_grad,
    ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, double>);
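
With these registrations, the op is reachable from Python purely by its type name. This diff adds no Python wrapper, so the following is only a sketch under the assumption that fluid 1.x's generic LayerHelper plumbing is used to append the op by hand (all names below are illustrative):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper

x = fluid.layers.data(name='x', shape=[3, 5], dtype='float32')
target = fluid.layers.data(name='target', shape=[3, 5], dtype='float32')

# append the freshly registered 'kldiv_loss' op to the current block
helper = LayerHelper('kldiv_loss')
loss = helper.create_variable_for_type_inference(dtype='float32')
helper.append_op(
    type='kldiv_loss',
    inputs={'X': x, 'Target': target},
    outputs={'Loss': loss},
    attrs={'reduction': 'batchmean'})

exe = fluid.Executor(fluid.CPUPlace())
out, = exe.run(feed={'x': np.random.rand(2, 3, 5).astype('float32'),
                     'target': np.random.rand(2, 3, 5).astype('float32')},
               fetch_list=[loss])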

paddle/fluid/operators/kldiv_loss_op.cu
@@ -0,0 +1,21 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/kldiv_loss_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    kldiv_loss, ops::KLDivLossKernel<plat::CUDADeviceContext, float>,
    ops::KLDivLossKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    kldiv_loss_grad,
    ops::KLDivLossGradKernel<plat::CUDADeviceContext, float>,
    ops::KLDivLossGradKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/kldiv_loss_op.h
@@ -0,0 +1,117 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

using Array1 = Eigen::DSizes<int64_t, 1>;

template <typename T>
struct KLDivLossForward {
  HOSTDEVICE KLDivLossForward() {}

  HOSTDEVICE T operator()(const T& target, const T& input) const {
    // Non-positive targets contribute no loss, which also keeps log()
    // from being evaluated at zero or negative values.
    if (target <= 0) {
      return 0;
    } else {
      return target * (std::log(target) - input);
    }
  }
};
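
For intuition, a NumPy mirror of what KLDivLossForward computes per element (a sketch only; the kernel itself evaluates the functor through the Eigen expression below):

import numpy as np

def kldiv_forward(target, x):
    # match the functor: zero loss where target <= 0, so log() is never
    # evaluated at a non-positive value
    safe_target = np.where(target > 0, target, 1.0)
    return np.where(target > 0, target * (np.log(safe_target) - x), 0.0)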

template <typename DeviceContext, typename T>
class KLDivLossKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto* input = ctx.Input<Tensor>("X");
    auto* target = ctx.Input<Tensor>("Target");
    auto* loss = ctx.Output<Tensor>("Loss");
    auto reduction = ctx.Attr<std::string>("reduction");

    const int n = input->dims()[0];

    loss->mutable_data<T>(ctx.GetPlace());
    auto input_t = EigenVector<T>::Flatten(*input);
    auto target_t = EigenVector<T>::Flatten(*target);
    auto loss_t = EigenVector<T>::Flatten(*loss);
    auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
    if ("none" == reduction) {
      loss_t.device(place) = output;
    } else if ("batchmean" == reduction) {
      loss_t.device(place) = output.sum() / static_cast<T>(n);
    } else if ("mean" == reduction) {
      loss_t.device(place) = output.mean();
    } else if ("sum" == reduction) {
      loss_t.device(place) = output.sum();
    }
  }
};

template <typename DeviceContext, typename T>
class KLDivLossGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto* input = ctx.Input<Tensor>("X");
    auto* target = ctx.Input<Tensor>("Target");
    auto reduction = ctx.Attr<std::string>("reduction");
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));

    const int n = input->dims()[0];
    const int numel = input->numel();
    const int expand = numel / loss_grad->numel();

    input_grad->mutable_data<T>(ctx.GetPlace());

    auto target_t = EigenVector<T>::Flatten(*target);

    auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
    auto loss_grad_t = EigenVector<T>::Flatten(*loss_grad);
    auto target_mask = (target_t > target_t.constant(0)).template cast<T>();

    auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
    input_grad_t.device(place) =
        target_t * target_t.constant(-1.0) * loss_grad_expand * target_mask;

    if ("mean" == reduction) {
      input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
    } else if ("batchmean" == reduction) {
      input_grad_t.device(place) = input_grad_t / static_cast<T>(n);
    }
  }
};
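
The backward kernel above applies the analytic gradient of the forward functor. Since the element-wise loss is $l(x, y) = y(\log y - x)$ for $y > 0$ and $0$ otherwise (with $x$ = Input(X) and $y$ = Input(Target)), the derivative is

$$\frac{\partial l}{\partial x} = -y \cdot \mathbf{1}[y > 0],$$

which is exactly the target_t * target_t.constant(-1.0) * target_mask product, multiplied by the broadcast upstream gradient of Loss and then rescaled by $1/\mathrm{numel}$ for reduction 'mean' or by $1/N$ (batch size) for 'batchmean' in the two trailing branches.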

}  // namespace operators
}  // namespace paddle

python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
@@ -0,0 +1,82 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import unittest
import numpy as np
from op_test import OpTest


def kldiv_loss(x, target, reduction):
    # element-wise loss; positions with non-positive target contribute zero
    output = target * (np.log(target) - x)
    loss = np.where(target > 0, output, np.zeros_like(x))

    if reduction == "batchmean":
        return loss.sum() / x.shape[0]
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()

    return loss
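
As a quick check of the reference above, with X holding log-probabilities (values approximate):

x = np.log(np.array([0.25, 0.75], dtype='float32'))
target = np.array([0.5, 0.5], dtype='float32')
kldiv_loss(x, target, 'none')       # ~[0.3466, -0.2027]
kldiv_loss(x, target, 'sum')        # ~0.1438
kldiv_loss(x, target, 'batchmean')  # sum / batch size 2, ~0.0719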


class TestKLDivLossOp(OpTest):
    def setUp(self):
        self.initTestCase()
        self.op_type = 'kldiv_loss'
        x = np.random.uniform(-10, 10, self.x_shape).astype('float32')
        target = np.random.uniform(-10, 10, self.x_shape).astype('float32')

        self.attrs = {"reduction": self.reduction}

        self.inputs = {
            'X': x,
            'Target': target,
        }
        loss = kldiv_loss(x, target, self.reduction)
        self.outputs = {'Loss': loss}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.1)

    def initTestCase(self):
        self.x_shape = (2, 3, 5, 5)
        self.reduction = 'batchmean'


class TestKLDivLossOp2(TestKLDivLossOp):
    def initTestCase(self):
        self.x_shape = (3, 7, 7)
        self.reduction = 'batchmean'


class TestKLDivLossOp3(TestKLDivLossOp):
    def initTestCase(self):
        self.x_shape = (2, 3, 5, 7, 9)
        self.reduction = 'mean'


class TestKLDivLossOp4(TestKLDivLossOp):
    def initTestCase(self):
        self.x_shape = (5, 7)
        self.reduction = 'sum'


if __name__ == "__main__":
    unittest.main()