Adding the FTRL optimizer. (#5785)

* Adding the FTRL optimizer
* Fixed the python test case

Branch: release/0.11.0
parent 32b10d3bc4
commit d883547bf0
@@ -0,0 +1,139 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/ftrl_op.h"

namespace paddle {
namespace operators {

class FTRLOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("SquaredAccumulator"),
                   "Input(SquaredAccumulator) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LinearAccumulator"),
                   "Input(LinearAccumulator) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of FTRL should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("SquaredAccumOut"),
                   "Output(SquaredAccumOut) of FTRL should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("LinearAccumOut"),
                   "Output(LinearAccumOut) of FTRL should not be null.");

    auto param_dim = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
                      "Input(Param) and Input(Grad) of FTRL Op should have "
                      "the same dimension.");

    auto lr_dim = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
                      "Learning Rate should be a scalar.");

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("SquaredAccumOut", param_dim);
    ctx->SetOutputDim("LinearAccumOut", param_dim);
  }
};

class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Param",
             "(Tensor, default Tensor<float>) "
             "Input parameter value that has to be updated.");
    AddInput("SquaredAccumulator",
             "(Tensor, default Tensor<float>) "
             "Accumulator that accumulates squared gradients.");
    AddInput("LinearAccumulator",
             "(Tensor, default Tensor<float>) "
             "Accumulator that accumulates linear gradients.");
    AddInput("Grad",
             "(Tensor, default Tensor<float>) "
             "Input gradient of the parameter.");
    AddInput("LearningRate",
             "(Tensor, default Tensor<float>) "
             "The learning rate should be a tensor of size 1.");

    AddOutput("ParamOut", "(Tensor) Output updated parameter value.");
    AddOutput("SquaredAccumOut",
              "(Tensor) Output accumulated squared"
              " gradients.");
    AddOutput("LinearAccumOut",
              "(Tensor) Output accumulated linear"
              " gradients.");

    AddAttr<float>("l1",
                   "(float, default 0.0) "
                   "L1 regularization strength.")
        .SetDefault(0.0f);
    AddAttr<float>("l2",
                   "(float, default 0.0) "
                   "L2 regularization strength.")
        .SetDefault(0.0f);
    AddAttr<float>("lr_power",
                   "(float, default -0.5f) "
                   "Learning Rate Power.")
        .SetDefault(-0.5f);
    AddComment(R"DOC(
FTRL (Follow The Regularized Leader) Operator.

Optimizer that implements the FTRL algorithm:

$$
new\_accum = squared\_accum + grad^2 \\
if (lr\_power == -0.5) {
  linear\_accum += grad -
      \frac{\surd(new\_accum) - \surd(squared\_accum)}{learning\_rate} * param \\
} else {
  linear\_accum += grad -
      \frac{new\_accum^{-lr\_power} - squared\_accum^{-lr\_power}}{learning\_rate} * param \\
}

x = l1 * sign(linear\_accum) - linear\_accum \\
if (lr\_power == -0.5) {
  y = \frac{\surd(new\_accum)}{learning\_rate} + (2 * l2) \\
  pre\_shrink = \frac{x}{y} \\
  param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\
} else {
  y = \frac{new\_accum^{-lr\_power}}{learning\_rate} + (2 * l2) \\
  pre\_shrink = \frac{x}{y} \\
  param = (abs(linear\_accum) > l1).select(pre\_shrink, 0.0) \\
}
squared\_accum += grad^2;
$$

The paper that proposed Follow The Regularized Leader (FTRL):
(https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker);
REGISTER_OP_CPU_KERNEL(ftrl,
                       ops::FTRLOpKernel<paddle::platform::CPUPlace, float>);
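Since the DOC string above documents the update only as pseudocode, here is a minimal NumPy sketch of a single FTRL step that mirrors those formulas for both lr_power cases; the `ftrl_update` helper and its argument names are illustrative and not part of the operator's interface (the unit test at the end of this commit performs the same reference computation for the lr_power == -0.5 case).

# Illustrative sketch only, not part of the patch.
import numpy as np


def ftrl_update(param, sq_accum, lin_accum, grad, lr, l1=0.0, l2=0.0,
                lr_power=-0.5):
    """One FTRL step, following the formulas in the operator's DOC string."""
    new_accum = sq_accum + grad * grad
    # Update the linear accumulator.
    if lr_power == -0.5:
        lin_accum = lin_accum + grad - (
            (np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * param
    else:
        lin_accum = lin_accum + grad - (
            (np.power(new_accum, -lr_power) -
             np.power(sq_accum, -lr_power)) / lr) * param

    # Solve the per-element proximal step.
    x = l1 * np.sign(lin_accum) - lin_accum
    if lr_power == -0.5:
        y = np.sqrt(new_accum) / lr + 2.0 * l2
    else:
        y = np.power(new_accum, -lr_power) / lr + 2.0 * l2
    # Entries whose |linear accumulator| does not exceed l1 become exactly
    # zero, which is what gives FTRL its sparsity.
    param = np.where(np.abs(lin_accum) > l1, x / y, 0.0)

    # Update the squared-gradient accumulator.
    sq_accum = sq_accum + grad * grad
    return param, sq_accum, lin_accum


# Example: a single update on a 3-element parameter vector.
p, s, l = ftrl_update(np.ones(3), np.full(3, 0.1), np.full(3, 0.1),
                      np.array([0.2, -0.4, 0.05]), lr=0.01, l1=0.1, l2=0.2)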
@@ -0,0 +1,19 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/ftrl_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ftrl,
                       ops::FTRLOpKernel<paddle::platform::GPUPlace, float>);
@@ -0,0 +1,96 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename Place, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* param_out = ctx.Output<Tensor>("ParamOut");
    auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
    auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");

    param_out->mutable_data<T>(ctx.GetPlace());
    sq_accum_out->mutable_data<T>(ctx.GetPlace());
    lin_accum_out->mutable_data<T>(ctx.GetPlace());

    auto grad = ctx.Input<Tensor>("Grad");

    auto l1 = static_cast<T>(ctx.Attr<float>("l1"));
    auto l2 = static_cast<T>(ctx.Attr<float>("l2"));
    auto lr_power = static_cast<T>(ctx.Attr<float>("lr_power"));

    auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
    auto sq_accum =
        EigenVector<T>::Flatten(*ctx.Input<Tensor>("SquaredAccumulator"));
    auto lin_accum =
        EigenVector<T>::Flatten(*ctx.Input<Tensor>("LinearAccumulator"));
    auto g = EigenVector<T>::Flatten(*grad);
    auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));

    auto p_out = EigenVector<T>::Flatten(*param_out);
    auto s_acc_out = EigenVector<T>::Flatten(*sq_accum_out);
    auto l_acc_out = EigenVector<T>::Flatten(*lin_accum_out);
    auto place = ctx.GetEigenDevice<Place>();

    Eigen::DSizes<int, 1> grad_dsize(grad->numel());

    auto new_accum = sq_accum + g * g;
    // Special case for lr_power = -0.5
    if (lr_power == static_cast<T>(-0.5)) {
      l_acc_out.device(place) =
          lin_accum + g -
          ((new_accum.sqrt() - sq_accum.sqrt()) / lr.broadcast(grad_dsize)) * p;
    } else {
      l_acc_out.device(place) =
          lin_accum + g -
          ((new_accum.pow(-lr_power) - sq_accum.pow(-lr_power)) /
           lr.broadcast(grad_dsize)) *
              p;
    }

    auto x = (l_acc_out.constant(l1) * l_acc_out.sign() - l_acc_out);
    if (lr_power == static_cast<T>(-0.5)) {
      auto y = (new_accum.sqrt() / lr.broadcast(grad_dsize)) +
               l_acc_out.constant(static_cast<T>(2) * l2);
      auto pre_shrink = x / y;
      p_out.device(place) =
          (l_acc_out.abs() > l_acc_out.constant(l1))
              .select(pre_shrink, p.constant(static_cast<T>(0)));
    } else {
      auto y = (new_accum.pow(-lr_power) / lr.broadcast(grad_dsize)) +
               l_acc_out.constant(static_cast<T>(2) * l2);
      auto pre_shrink = x / y;
      p_out.device(place) =
          (l_acc_out.abs() > l_acc_out.constant(l1))
              .select(pre_shrink, p.constant(static_cast<T>(0)));
    }

    s_acc_out.device(place) = sq_accum + g * g;
  }
};

}  // namespace operators
}  // namespace paddle
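The kernel in ftrl_op.h leans on two Eigen idioms: `lr.broadcast(grad_dsize)` replicates the size-1 LearningRate tensor to the gradient's length, and `(condition).select(a, b)` picks elementwise from `a` where the condition holds and from `b` otherwise. A small NumPy sketch of the equivalent behavior (array values here are arbitrary examples, not taken from the patch):

# Illustrative sketch only, not part of the patch.
import numpy as np

lr = np.array([0.01])                 # size-1 LearningRate tensor
grad = np.array([0.2, -0.4, 0.05])

# lr.broadcast(grad_dsize): replicate the scalar learning rate to grad's shape.
lr_b = np.broadcast_to(lr, grad.shape)

lin = np.array([0.3, -0.05, 0.12])    # linear accumulator after the update
pre_shrink = np.array([-1.0, -2.0, -3.0])
l1 = 0.1

# (l_acc_out.abs() > l1).select(pre_shrink, 0): keep pre_shrink where the
# |linear accumulator| exceeds l1, otherwise write an exact zero.
param_out = np.where(np.abs(lin) > l1, pre_shrink, 0.0)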
@@ -0,0 +1,62 @@
import unittest
import numpy as np
from op_test import OpTest


class TestFTRLOp(OpTest):
    def setUp(self):
        self.op_type = "ftrl"
        w = np.random.random((102, 105)).astype("float32")
        g = np.random.random((102, 105)).astype("float32")
        sq_accum = np.full((102, 105), 0.1).astype("float32")
        linear_accum = np.full((102, 105), 0.1).astype("float32")
        lr = np.array([0.01]).astype("float32")
        l1 = 0.1
        l2 = 0.2
        lr_power = -0.5

        self.inputs = {
            'Param': w,
            'SquaredAccumulator': sq_accum,
            'LinearAccumulator': linear_accum,
            'Grad': g,
            'LearningRate': lr
        }
        self.attrs = {
            'l1': l1,
            'l2': l2,
            'lr_power': lr_power,
            'learning_rate': lr
        }
        new_accum = sq_accum + g * g
        if lr_power == -0.5:
            linear_out = linear_accum + g - (
                (np.sqrt(new_accum) - np.sqrt(sq_accum)) / lr) * w
        else:
            linear_out = linear_accum + g - ((np.power(
                new_accum, -lr_power) - np.power(sq_accum, -lr_power)) / lr) * w

        x = (l1 * np.sign(linear_out) - linear_out)
        if lr_power == -0.5:
            y = (np.sqrt(new_accum) / lr) + (2 * l2)
            pre_shrink = x / y
            param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0)
        else:
            y = (np.power(new_accum, -lr_power) / lr) + (2 * l2)
            pre_shrink = x / y
            param_out = np.where(np.abs(linear_out) > l1, pre_shrink, 0.0)

        sq_accum_out = sq_accum + g * g

        self.outputs = {
            'ParamOut': param_out,
            'SquaredAccumOut': sq_accum_out,
            'LinearAccumOut': linear_out
        }

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()