parent
a585b585dd
commit
fcff9758ed
@ -0,0 +1,85 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/label_smooth_op.h"

namespace paddle {
namespace operators {

class LabelSmoothOp : public framework::OperatorWithKernel {
 public:
  LabelSmoothOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of LabelSmoothOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of LabelSmoothOp should not be null.");
    auto in_dims = ctx->GetInputDim("X");
    ctx->ShareLoD("X", /*->*/ "Out");
    ctx->SetOutputDim("Out", in_dims);
  }
};

class LabelSmoothOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LabelSmoothOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input label of LabelSmooth operator.");
    AddOutput("Out", "The smoothed label of LabelSmooth operator.");
    AddAttr<float>("epsilon",
                   "(float, default 0.0f) "
                   "The smoothing parameter of LabelSmooth operator.")
        .SetDefault(0.0f);
    AddComment(R"DOC(
LabelSmooth Operator.

Label smoothing regularizes one-hot training targets by mixing them with a
uniform distribution over the K labels:

    Out = (1 - epsilon) * X + epsilon / K

where epsilon is the smoothing parameter and K is the size of the label
dimension. With epsilon = 0 the labels pass through unchanged.

)DOC");
  }
};

class LabelSmoothGradOp : public framework::OperatorWithKernel {
 public:
  LabelSmoothGradOp(const std::string &type,
                    const framework::VariableNameMap &inputs,
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) shouldn't be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker,
            label_smooth_grad, ops::LabelSmoothGradOp);
REGISTER_OP_CPU_KERNEL(
    label_smooth,
    ops::LabelSmoothKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LabelSmoothKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    label_smooth_grad,
    ops::LabelSmoothGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::LabelSmoothGradKernel<paddle::platform::CPUDeviceContext, double>);
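
Note: as the DOC block above spells out, the forward op computes
Out = (1 - epsilon) * X + epsilon / K. A minimal NumPy sketch of that
transform (illustrative only, not part of this commit):

import numpy as np

def label_smooth(label, epsilon):
    # label: (batch_size, K) one-hot or soft labels; K is the label dimension.
    K = label.shape[1]
    return (1 - epsilon) * label + epsilon / K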
@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/label_smooth_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    label_smooth,
    ops::LabelSmoothKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LabelSmoothKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    label_smooth_grad,
    ops::LabelSmoothGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::LabelSmoothGradKernel<paddle::platform::CUDADeviceContext, double>);
@ -0,0 +1,58 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class LabelSmoothKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out_t = ctx.Output<framework::LoDTensor>("Out");
    auto* in_t = ctx.Input<framework::LoDTensor>("X");
    auto label_dim = in_t->dims()[1];
    out_t->mutable_data<T>(ctx.GetPlace());

    auto epsilon = ctx.Attr<float>("epsilon");
    auto out = framework::EigenVector<T>::Flatten(*out_t);
    auto in = framework::EigenVector<T>::Flatten(*in_t);
    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
    // Out = (1 - epsilon) * X + epsilon / K, with K the label dimension.
    out.device(dev) =
        static_cast<T>(1 - epsilon) * in + static_cast<T>(epsilon / label_dim);
  }
};

template <typename DeviceContext, typename T>
class LabelSmoothGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    d_in_t->mutable_data<T>(ctx.GetPlace());

    auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
    auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);

    auto epsilon = ctx.Attr<float>("epsilon");
    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
    // dX = (1 - epsilon) * dOut; the uniform epsilon / K term is constant in X.
    d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
  }
};

}  // namespace operators
}  // namespace paddle
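
Note: the backward kernel works because the forward transform is affine in X,
so dOut/dX = (1 - epsilon) everywhere and X@GRAD is a scaled copy of Out@GRAD.
A finite-difference sketch of that claim (illustrative only, same assumptions
as the sketch above):

import numpy as np

def numeric_grad(label, epsilon, h=1e-6):
    # Central difference of sum(label_smooth(label)) w.r.t. label[0, 0];
    # by symmetry every entry has the same derivative, 1 - epsilon.
    K = label.shape[1]
    f = lambda x: np.sum((1 - epsilon) * x + epsilon / K)
    e = np.zeros_like(label)
    e[0, 0] = h
    return (f(label + e) - f(label - e)) / (2 * h)  # ~= 1 - epsilon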
@ -0,0 +1,41 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestLabelSmoothOp(OpTest):
    def setUp(self):
        self.op_type = "label_smooth"
        epsilon = 0.1
        batch_size, label_dim = 5, 10
        label = np.zeros((batch_size, label_dim)).astype("float64")
        nonzero_index = np.random.randint(label_dim, size=(batch_size))
        label[np.arange(batch_size), nonzero_index] = 1
        smoothed_label = (1 - epsilon) * label + epsilon / label_dim
        self.inputs = {'X': label}
        self.attrs = {'epsilon': epsilon}
        self.outputs = {'Out': smoothed_label}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


if __name__ == '__main__':
    unittest.main()
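
Note: with the test's values (epsilon = 0.1, label_dim = 10), a one-hot 1
maps to (1 - 0.1) * 1 + 0.1 / 10 = 0.91 and a zero entry maps to 0.01, so
each smoothed row still sums to 0.91 + 9 * 0.01 = 1.0.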