Adadelta Optimizer (#26590)
* add doc; notest * fix doc; notest * update doc; notest * refine optimizer && adam * refine optimizer; notest * add adam * fix doc * fix doc && add adamw; notest * add error message * bug fix * refine rmsprop && adamax * fix ci * buf fix * update comment * unify arguments place; notest * fix ut, test=develop * bug fix * fix conflicts, test=develop * add examples code * bug fix * fix comments * fix sample code * add sample code for Optimizer * add adamax ut, test=develop * fix rmsprop ut, test=develop * add ut for optimizer.py and adamw.py * first commit of adadelta optimizer * fix learning rate * fix adadelta doc and add sgd momentum * remove unused fluid * fix codestyle * Update test_adam_op.py * Update test_adam_op.py * fix SGD in 2 unittests * fix SGD in 2 unittests * fix ci * fix ut Co-authored-by: MRXLT <xlt2024@gmail.com> Co-authored-by: mapingshuo <mps2012@yeah.net>revert-26856-strategy_example2
parent
346689c6f1
commit
a1b99fae07
@ -0,0 +1,144 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizer import Optimizer
|
||||
from ..fluid import core
|
||||
from ..fluid import framework
|
||||
from ..fluid.framework import Variable, name_scope
|
||||
|
||||
__all__ = ["Adadelta"]
|
||||
|
||||
|
||||
class Adadelta(Optimizer):
|
||||
"""
|
||||
**Notes: This API does not support sparse parameter optimization.**
|
||||
|
||||
Adadelta Optimizer. Please refer to this for details:
|
||||
`ADADELTA: AN ADAPTIVE LEARNING RATE METHOD <https://arxiv.org/abs/1212.5701>`_.
|
||||
|
||||
The update is done as follows:
|
||||
|
||||
.. math::
|
||||
|
||||
E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2
|
||||
|
||||
learning\_rate &= \sqrt{ ( E(dx_{t-1}^2) + \\epsilon ) / ( E(g_t^2) + \\epsilon ) }
|
||||
|
||||
E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\_rate)^2
|
||||
|
||||
Args:
|
||||
learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
|
||||
It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
|
||||
epsilon (float): a small float number for numeric stability. Default 1.0e-6.
|
||||
rho (float): a floating point value indicating the decay rate. Default 0.95.
|
||||
parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
|
||||
This parameter is required in dygraph mode. \
|
||||
The default value is None in static mode, at this time all parameters will be updated.
|
||||
weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
|
||||
It canbe a float value as coeff of L2 regularization or \
|
||||
:ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
|
||||
If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
|
||||
the regularization setting here in optimizer will be ignored for this parameter. \
|
||||
Otherwise, the regularization setting here in optimizer will take effect. \
|
||||
Default None, meaning there is no regularization.
|
||||
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
|
||||
some derived class of ``GradientClipBase`` . There are three cliping strategies
|
||||
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
|
||||
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
|
||||
name (str, optional): The default value is None. Normally there is no need for user
|
||||
to set this property. For more information, please refer to
|
||||
:ref:`api_guide_Name` .
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
import paddle
|
||||
import numpy as np
|
||||
paddle.disable_static()
|
||||
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
|
||||
linear = paddle.nn.Linear(10, 10)
|
||||
inp = paddle.to_tensor(inp)
|
||||
out = linear(inp)
|
||||
loss = paddle.mean(out)
|
||||
beta1 = paddle.to_tensor([0.9], dtype="float32")
|
||||
beta2 = paddle.to_tensor([0.99], dtype="float32")
|
||||
adadelta = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
|
||||
back = out.backward()
|
||||
adadelta.step()
|
||||
adadelta.clear_grad()
|
||||
|
||||
"""
|
||||
|
||||
_avg_squared_grad_acc_str = "_avg_squared_grad"
|
||||
_avg_squared_update_acc_str = "_avg_squared_update"
|
||||
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
epsilon=1.0e-6,
|
||||
rho=0.95,
|
||||
parameters=None,
|
||||
weight_decay=None,
|
||||
grad_clip=None,
|
||||
name=None):
|
||||
if learning_rate is None:
|
||||
raise ValueError("learning_rate is not set.")
|
||||
if epsilon is None:
|
||||
raise ValueError("epsilon is not set.")
|
||||
if rho is None:
|
||||
raise ValueError("rho is not set.")
|
||||
super(Adadelta, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
parameters=parameters,
|
||||
weight_decay=weight_decay,
|
||||
grad_clip=grad_clip,
|
||||
name=name)
|
||||
self.type = "adadelta"
|
||||
self._epsilon = epsilon
|
||||
self._rho = rho
|
||||
|
||||
def _create_accumulators(self, block, parameters):
|
||||
if not isinstance(block, framework.Block):
|
||||
raise TypeError("block is not instance of framework.Block.")
|
||||
|
||||
for p in parameters:
|
||||
self._add_accumulator(self._avg_squared_grad_acc_str, p)
|
||||
self._add_accumulator(self._avg_squared_update_acc_str, p)
|
||||
|
||||
def _append_optimize_op(self, block, param_and_grad):
|
||||
if not isinstance(block, framework.Block):
|
||||
raise TypeError("block is not instance of framework.Block.")
|
||||
|
||||
avg_squared_grad_acc = self._get_accumulator(
|
||||
self._avg_squared_grad_acc_str, param_and_grad[0])
|
||||
avg_squared_update_acc = self._get_accumulator(
|
||||
self._avg_squared_update_acc_str, param_and_grad[0])
|
||||
|
||||
# Create the adadelta optimizer op
|
||||
adadelta_op = block.append_op(
|
||||
type=self.type,
|
||||
inputs={
|
||||
"Param": param_and_grad[0],
|
||||
"Grad": param_and_grad[1],
|
||||
"AvgSquaredGrad": avg_squared_grad_acc,
|
||||
"AvgSquaredUpdate": avg_squared_update_acc
|
||||
},
|
||||
outputs={
|
||||
"ParamOut": param_and_grad[0],
|
||||
"AvgSquaredGradOut": avg_squared_grad_acc,
|
||||
"AvgSquaredUpdateOut": avg_squared_update_acc
|
||||
},
|
||||
attrs={"epsilon": self._epsilon,
|
||||
"rho": self._rho},
|
||||
stop_gradient=True)
|
||||
|
||||
return adadelta_op
|
@ -0,0 +1,149 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizer import Optimizer
|
||||
from ..fluid import core
|
||||
from ..fluid import framework
|
||||
from ..fluid.framework import Variable, name_scope
|
||||
|
||||
__all__ = ["Momentum"]
|
||||
|
||||
|
||||
class Momentum(Optimizer):
|
||||
"""
|
||||
|
||||
Simple Momentum optimizer with velocity state
|
||||
|
||||
This optimizer has a flag for Nestrov Momentum.
|
||||
|
||||
The update equations are as follows:
|
||||
|
||||
.. math::
|
||||
|
||||
& velocity = mu * velocity + gradient
|
||||
|
||||
& if (use\_nesterov):
|
||||
|
||||
&\quad param = param - (gradient + mu * velocity) * learning\_rate
|
||||
|
||||
& else:
|
||||
|
||||
&\quad param = param - learning\_rate * velocity
|
||||
|
||||
Parameters:
|
||||
|
||||
learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
|
||||
It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
|
||||
momentum (float): Momentum factor. The default value is 0.9.
|
||||
parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
|
||||
This parameter is required in dygraph mode. \
|
||||
The default value is None in static mode, at this time all parameters will be updated.
|
||||
weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
|
||||
It canbe a float value as coeff of L2 regularization or \
|
||||
:ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
|
||||
If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
|
||||
the regularization setting here in optimizer will be ignored for this parameter. \
|
||||
Otherwise, the regularization setting here in optimizer will take effect. \
|
||||
Default None, meaning there is no regularization.
|
||||
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
|
||||
some derived class of ``GradientClipBase`` . There are three cliping strategies
|
||||
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
|
||||
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
|
||||
name (str, optional): The default value is None. Normally there is no need for user
|
||||
to set this property. For more information, please refer to
|
||||
:ref:`api_guide_Name` .
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
paddle.disable_static()
|
||||
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
|
||||
linear = paddle.nn.Linear(10, 10)
|
||||
inp = paddle.to_tensor(inp)
|
||||
out = linear(inp)
|
||||
loss = paddle.mean(out)
|
||||
beta1 = paddle.to_tensor([0.9], dtype="float32")
|
||||
beta2 = paddle.to_tensor([0.99], dtype="float32")
|
||||
momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
|
||||
back = out.backward()
|
||||
momentum.step()
|
||||
momentum.clear_grad()
|
||||
"""
|
||||
_velocity_acc_str = "velocity"
|
||||
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
momentum=0.9,
|
||||
parameters=None,
|
||||
use_nesterov=False,
|
||||
weight_decay=None,
|
||||
grad_clip=None,
|
||||
name=None):
|
||||
if learning_rate is None:
|
||||
raise ValueError("learning_rate is not set")
|
||||
if momentum is None:
|
||||
raise ValueError("momentum is not set")
|
||||
super(Momentum, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
parameters=parameters,
|
||||
weight_decay=weight_decay,
|
||||
grad_clip=grad_clip,
|
||||
name=name)
|
||||
self.type = "momentum"
|
||||
self._momentum = momentum
|
||||
self._use_nesterov = bool(use_nesterov)
|
||||
|
||||
def _create_accumulators(self, block, parameters):
|
||||
assert isinstance(block, framework.Block)
|
||||
|
||||
for p in parameters:
|
||||
self._add_accumulator(self._velocity_acc_str, p)
|
||||
|
||||
def _append_optimize_op(self, block, param_and_grad):
|
||||
assert isinstance(block, framework.Block)
|
||||
|
||||
velocity_acc = self._get_accumulator(self._velocity_acc_str,
|
||||
param_and_grad[0])
|
||||
lr = self._create_param_lr(param_and_grad)
|
||||
|
||||
if framework.in_dygraph_mode():
|
||||
_, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1],
|
||||
velocity_acc, lr, param_and_grad[0],
|
||||
velocity_acc, 'mu', self._momentum,
|
||||
'use_nesterov', self._use_nesterov)
|
||||
return None
|
||||
|
||||
attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
|
||||
inputs = {
|
||||
"Param": [param_and_grad[0]],
|
||||
"Grad": [param_and_grad[1]],
|
||||
"Velocity": [velocity_acc],
|
||||
"LearningRate": [lr]
|
||||
}
|
||||
|
||||
outputs = {
|
||||
"ParamOut": [param_and_grad[0]],
|
||||
"VelocityOut": [velocity_acc]
|
||||
}
|
||||
# create the momentum optimize op
|
||||
momentum_op = block.append_op(
|
||||
type=self.type,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
attrs=attrs,
|
||||
stop_gradient=True)
|
||||
|
||||
return momentum_op
|
@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizer import Optimizer
|
||||
from ..fluid import core
|
||||
from ..fluid import framework
|
||||
from ..fluid.framework import Variable, name_scope
|
||||
from ..fluid.dygraph import no_grad
|
||||
__all__ = ["SGD"]
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
"""
|
||||
Optimizer of the stochastic gradient descent algorithm.
|
||||
|
||||
.. math::
|
||||
|
||||
param\_out = param - learning\_rate * grad
|
||||
|
||||
Parameters:
|
||||
learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
|
||||
It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
|
||||
parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
|
||||
This parameter is required in dygraph mode. \
|
||||
The default value is None in static mode, at this time all parameters will be updated.
|
||||
weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
|
||||
It canbe a float value as coeff of L2 regularization or \
|
||||
:ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
|
||||
If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
|
||||
the regularization setting here in optimizer will be ignored for this parameter. \
|
||||
Otherwise, the regularization setting here in optimizer will take effect. \
|
||||
Default None, meaning there is no regularization.
|
||||
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
|
||||
some derived class of ``GradientClipBase`` . There are three cliping strategies
|
||||
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
|
||||
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
|
||||
name (str, optional): The default value is None. Normally there is no need for user
|
||||
to set this property. For more information, please refer to
|
||||
:ref:`api_guide_Name` .
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
paddle.disable_static()
|
||||
inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
|
||||
linear = paddle.nn.Linear(10, 10)
|
||||
inp = paddle.to_tensor(inp)
|
||||
out = linear(inp)
|
||||
loss = paddle.mean(out)
|
||||
beta1 = paddle.to_tensor([0.9], dtype="float32")
|
||||
beta2 = paddle.to_tensor([0.99], dtype="float32")
|
||||
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
|
||||
back = out.backward()
|
||||
sgd.step()
|
||||
sgd.clear_grad()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
parameters=None,
|
||||
weight_decay=None,
|
||||
grad_clip=None,
|
||||
name=None):
|
||||
if learning_rate is None:
|
||||
raise ValueError("learning_rate is not set")
|
||||
super(SGD, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
parameters=parameters,
|
||||
weight_decay=weight_decay,
|
||||
grad_clip=grad_clip,
|
||||
name=name)
|
||||
self.type = "sgd"
|
||||
|
||||
@no_grad()
|
||||
def _append_optimize_op(self, block, param_and_grad):
|
||||
lr = self._create_param_lr(param_and_grad)
|
||||
if framework.in_dygraph_mode():
|
||||
core.ops.sgd(param_and_grad[0], lr, param_and_grad[1],
|
||||
param_and_grad[0])
|
||||
return None
|
||||
|
||||
assert isinstance(block, framework.Block)
|
||||
# create the optimize op
|
||||
sgd_op = block.append_op(
|
||||
type=self.type,
|
||||
inputs={
|
||||
"Param": param_and_grad[0],
|
||||
"Grad": param_and_grad[1],
|
||||
"LearningRate": lr
|
||||
},
|
||||
outputs={"ParamOut": param_and_grad[0]},
|
||||
stop_gradient=True)
|
||||
|
||||
return sgd_op
|
Loading…
Reference in new issue