Add LAMB Optimizer support (#17489)

* Add LAMB optimizer

* Expose LAMB Optimizer's APIs

test=develop, test=document_preview

* Cleanup code & doc

test=develop, test=document_preview

* Update lamb optimizer's formula

test=develop
ema_dev
Yibing Liu 6 years ago committed by GitHub
parent 99ab57123c
commit f9796b1249
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -516,6 +516,12 @@ paddle.fluid.optimizer.DGCMomentumOptimizer.apply_optimize (ArgSpec(args=['self'
paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LambOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))
paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
paddle.fluid.optimizer.LambOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.LambOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '08a5dd9f6f376ff3d55e0b1d92115cbd'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))

@ -18,67 +18,64 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment1"),
"Input(Moment1) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment2"),
"Input(Moment2) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
"Input(Beta2Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
"Output(Moment1Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
"Output(Moment2Out) of AdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
"Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
}
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"),
"Param and Moment2 input of AdamOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("Param")->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment1"),
"Input(Moment1) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment2"),
"Input(Moment2) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
"Input(Beta2Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
"Output(Moment1Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
"Output(Moment2Out) of AdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
"Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
}
};
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"),
"Param and Moment2 input of AdamOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
}
framework::OpKernelType AdamOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type = ctx.Input<framework::Tensor>("Param")->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:

@ -29,6 +29,15 @@ namespace operators {
namespace scatter = paddle::operators::math::scatter;
class AdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
struct GPUAdam;
struct CPUAdam;

@ -0,0 +1,95 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lamb_op.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
namespace paddle {
namespace operators {
class LambOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated.");
AddInput("Grad",
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter.");
AddInput("LearningRate", "(Tensor) Learning rate.");
AddInput("Moment1", "(Tensor) Input first moment.");
AddInput("Moment2", "(Tensor) Input second moment.");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator.");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator.");
AddOutput("ParamOut", "(Tensor) Output parameter.");
AddOutput("Moment1Out", "(Tensor) Output first moment.");
AddOutput("Moment2Out", "(Tensor) Output second moment.");
AddAttr<float>("weight_decay", "(float) Weight decay rate.");
AddAttr<float>("beta1",
"(float, default 0.9) The exponential decay rate for the "
"1st moment estimates.")
.SetDefault(0.9);
AddAttr<float>("beta2",
"(float, default 0.999) The exponential decay rate for the "
"2nd moment estimates.")
.SetDefault(0.999);
AddAttr<float>("epsilon",
"(float, default 1.0e-6) "
"Constant for numerical stability.")
.SetDefault(1.0e-6f);
AddComment(R"DOC(
LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.
LAMB Optimizer is designed to scale up the batch size of training without losing
accuracy, which supports adaptive element-wise updating and accurate layer-wise
correction. For more information, please refer to https://arxiv.org/abs/1904.00962.
The updating of parameters follows:
$$
m_t^l &= \beta_1 m_{t - 1}^l + (1 - \beta_1)g_t^l \\
v_t^l &= \beta_2 v_{t - 1}^l + (1 - \beta_2)g_t^l \odot g_t^l \\
\widehat{m}_t^l &= m_t^l/(1 - \beta_1^t) \\
\widehat{v}_t^l &= v_t^l/(1 - \beta_2^t) \\
r_1 &= \left \| w_{t-1}^l \right \|_2 \\
r_2 &= \left \| \frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l \right \|_2 \\
r &= r_1 / r_2 \\
\eta^l &= r \times \eta \\
w_t^l &= w_{t-1}^l -\eta ^l \times (\frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l)
$$
where $m$ is the 1st moment, and $v$ the 2nd moment, $\eta$ the
learning rate, $\lambda$ the weight decay rate.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::AdamOp, ops::LambOpMaker);
REGISTER_OP_CPU_KERNEL(
lamb, ops::LambOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LambOpKernel<paddle::platform::CPUDeviceContext, double>);

@ -0,0 +1,20 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lamb_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lamb, ops::LambOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LambOpKernel<paddle::platform::CUDADeviceContext, double>);

File diff suppressed because it is too large Load Diff

@ -132,7 +132,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
# Test program
test_program = train_program.clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer = fluid.optimizer.Lamb(learning_rate=0.001)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer,

@ -42,7 +42,7 @@ __all__ = [
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer', 'DGCMomentumOptimizer'
'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer'
]
@ -1851,6 +1851,133 @@ class FtrlOptimizer(Optimizer):
return ftrl_op
class LambOptimizer(AdamOptimizer):
"""
LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer.
LAMB Optimizer is designed to scale up the batch size of training without losing
accuracy, which supports adaptive element-wise updating and accurate layer-wise
correction. For more information, please refer to `Reducing BERT Pre-Training
Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_ .
The updating of parameters follows:
.. math::
m_t^l & = \\beta_1 m_{t - 1}^l + (1 - \\beta_1)g_t^l
v_t^l & = \\beta_2 v_{t - 1}^l + (1 - \\beta_2)g_t^l \odot g_t^l
\\widehat{m}_t^l & = m_t^l/(1 - \\beta_1^t)
\\widehat{v}_t^l & = v_t^l/(1 - \\beta_2^t)
r_1 & = \\left \| w_{t-1}^l \\right \|_2
r_2 & = \\left \| \\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l \\right \|_2
r & = r_1 / r_2
\\eta^l & = r \\times \\eta
w_t^l & = w_{t-1}^l -\\eta ^l \\times (\\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l)
where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\\eta` the
learning rate, :math:`\\lambda` the LAMB weight decay rate.
Args:
learning_rate (float|Variable): the learning rate used to update parameters. \
Can be a float value or a Variable with one \
float value as data element.
lamb_weight_decay (float): The LAMB weight decay rate.
beta1 (float): The exponential decay rate for the 1st moment estimates.
beta2 (float): The exponential decay rate for the 2nd moment estimates.
epsilon (float): A small float value for numerical stability.
regularization: A Regularizer, such as
fluid.regularizer.L1DecayRegularizer.
name (str|None): An optional name prefix.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name='x', shape=[5], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
cost = fluid.layers.mean(hidden)
optimizer = fluid.optimizer.Lamb(learning_rate=0.002)
optimizer.minimize(cost)
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
def __init__(self,
learning_rate=0.001,
lamb_weight_decay=0.01,
beta1=0.9,
beta2=0.999,
epsilon=1e-6,
regularization=None,
name=None):
assert learning_rate is not None
assert lamb_weight_decay is not None
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
super(LambOptimizer, self).__init__(
learning_rate=learning_rate,
regularization=regularization,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
name=name)
self.type = "lamb"
self._weight_decay = lamb_weight_decay
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
moment1 = self._get_accumulator(self._moment1_acc_str,
param_and_grad[0])
moment2 = self._get_accumulator(self._moment2_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
# create the lamb optimize op
lamb_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": self._create_param_lr(param_and_grad),
"Moment1": moment1,
"Moment2": moment2,
"Beta1Pow": beta1_pow_acc,
"Beta2Pow": beta2_pow_acc
},
outputs={
"ParamOut": param_and_grad[0],
"Moment1Out": moment1,
"Moment2Out": moment2
},
attrs={
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon,
"weight_decay": self._weight_decay
},
stop_gradient=True)
return lamb_op
# We short the class name, since users will use the optimizer with the package
# name. The sample code:
#
@ -1869,6 +1996,7 @@ Adadelta = AdadeltaOptimizer
RMSProp = RMSPropOptimizer
Ftrl = FtrlOptimizer
LarsMomentum = LarsMomentumOptimizer
Lamb = LambOptimizer
class ModelAverage(Optimizer):

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save