add AdamWeightDecayOp

pull/12826/head
VectorSL 4 years ago
parent f4a5eb5219
commit c1a619ccfe

@@ -40,12 +40,47 @@ __global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *b
}
}
template <typename T>
__global__ void AdamWeightDecayKernel(const size_t size, const T *gradient, const float *learning_rate,
const float *beta1, const float *beta2, const float *epsilon, const float *decay,
T *variable, T *m, T *v) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
T next_m = beta1[0] * m[i] + (1 - beta1[0]) * gradient[i];
T next_v = beta2[0] * v[i] + (1 - beta2[0]) * gradient[i] * gradient[i];
T update = next_m / (sqrt(next_v) + epsilon[0]);
update += decay[0] * variable[i];
variable[i] -= learning_rate[0] * update;
m[i] = next_m;
v[i] = next_v;
}
}
template <>
__global__ void AdamWeightDecayKernel(const size_t size, const half *gradient, const float *learning_rate,
const float *beta1, const float *beta2, const float *epsilon, const float *decay,
half *variable, half *m, half *v) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
half next_m = __float2half(beta1[0]) * m[i] + __float2half(1 - beta1[0]) * gradient[i];
half next_v = __float2half(beta2[0]) * v[i] + __float2half(1 - beta2[0]) * gradient[i] * gradient[i];
half update = next_m / (hsqrt(next_v) + __float2half(epsilon[0]));
update += __float2half(decay[0]) * variable[i];
variable[i] -= __float2half(learning_rate[0]) * update;
m[i] = next_m;
v[i] = next_v;
}
}
template <typename T>
void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate,
const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) {
ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
}
template <typename T>
void AdamWeightDecayOp(const size_t size, const T *gradient, const float *learning_rate, const float *beta1,
const float *beta2, const float *epsilon, const float *decay, T *variable, T *m, T *v,
cudaStream_t cuda_stream) {
AdamWeightDecayKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, gradient, learning_rate, beta1, beta2,
epsilon, decay, variable, m, v);
}
template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power,
const float *beta2_power, const float *learning_rate, const float *beta1,
@@ -54,3 +89,9 @@ template void ApplyAdam<float>(const size_t size, const float *gradient, const f
template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power,
const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon,
half *variable, half *m, half *v, cudaStream_t cuda_stream);
template void AdamWeightDecayOp<float>(const size_t size, const float *gradient, const float *learning_rate,
const float *beta1, const float *beta2, const float *epsilon, const float *decay,
float *variable, float *m, float *v, cudaStream_t cuda_stream);
template void AdamWeightDecayOp<half>(const size_t size, const half *gradient, const float *learning_rate,
const float *beta1, const float *beta2, const float *epsilon, const float *decay,
half *variable, half *m, half *v, cudaStream_t cuda_stream);
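
For reference, the element-wise update performed by AdamWeightDecayKernel can be written out in a few lines of NumPy. This is a minimal sketch of the same math as the kernel's grid-stride loop, not part of the commit; the arrays stand in for the device buffers and the plain scalars for the single-element device pointers.

import numpy as np

def adam_weight_decay_reference(variable, m, v, gradient, lr, beta1, beta2, epsilon, decay):
    """NumPy model of AdamWeightDecayKernel: updates variable, m and v in place."""
    next_m = beta1 * m + (1 - beta1) * gradient
    next_v = beta2 * v + (1 - beta2) * gradient * gradient
    update = next_m / (np.sqrt(next_v) + epsilon) + decay * variable
    variable -= lr * update
    m[...] = next_m
    v[...] = next_v
    return variable, m, v

# Illustrative call with arbitrary values.
var = np.ones(4, np.float32)
adam_weight_decay_reference(var, np.zeros(4, np.float32), np.zeros(4, np.float32),
                            np.full(4, 0.5, np.float32), lr=1e-3, beta1=0.9,
                            beta2=0.999, epsilon=1e-8, decay=0.01)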

@@ -21,5 +21,9 @@
template <typename T>
void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate,
const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream);
template <typename T>
void AdamWeightDecayOp(const size_t size, const T *gradient, const float *learning_rate, const float *beta1,
const float *beta2, const float *epsilon, const float *decay, T *variable, T *m, T *v,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_

@@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/adam_weight_decay_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamWeightDecayGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(AdamWeightDecay,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
AdamWeightDecayGpuKernel, half)
} // namespace kernel
} // namespace mindspore
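
Note that the second registration pairs float16 variable, moments and gradient with float32 scalar inputs (lr, beta1, beta2, epsilon, decay), matching the half specialization of AdamWeightDecayKernel. Below is a hedged sketch of calling the primitive in that mixed-precision layout; it mirrors the ops.AdamWeightDecay docstring example further down in this commit, and the shapes and hyper-parameter values are illustrative only.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as ops

class NetFP16(nn.Cell):
    def __init__(self):
        super(NetFP16, self).__init__()
        self.adam_weight_decay = ops.AdamWeightDecay()
        # var/m/v and the gradient are float16; the scalar hyper-parameters stay float32.
        self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float16)), name="var")
        self.m = Parameter(Tensor(np.zeros([2, 2]).astype(np.float16)), name="m")
        self.v = Parameter(Tensor(np.zeros([2, 2]).astype(np.float16)), name="v")

    def construct(self, lr, beta1, beta2, epsilon, decay, grad):
        return self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
                                      epsilon, decay, grad)

net = NetFP16()
gradient = Tensor(np.random.rand(2, 2).astype(np.float16))
output = net(1e-3, 0.9, 0.999, 1e-8, 1e-5, gradient)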

@@ -0,0 +1,137 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class AdamWeightDecayGpuKernel : public GpuKernel {
public:
AdamWeightDecayGpuKernel()
: variable_size_(0),
m_size_(0),
v_size_(0),
learning_rate_size_(0),
beta1_size_(0),
beta2_size_(0),
epsilon_size_(0),
decay_size_(0),
gradient_size_(0) {}
~AdamWeightDecayGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *m = GetDeviceAddress<T>(inputs, 1);
T *v = GetDeviceAddress<T>(inputs, 2);
float *lr = GetDeviceAddress<float>(inputs, 3);
float *beta1 = GetDeviceAddress<float>(inputs, 4);
float *beta2 = GetDeviceAddress<float>(inputs, 5);
float *epsilon = GetDeviceAddress<float>(inputs, 6);
float *decay = GetDeviceAddress<float>(inputs, 7);
T *gradient = GetDeviceAddress<T>(inputs, 8);
AdamWeightDecayOp(inputs[0]->size / sizeof(T), gradient, lr, beta1, beta2, epsilon, decay, variable, m, v,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 9) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 9 inputs.";
return false;
}
variable_size_ = sizeof(T);
m_size_ = sizeof(T);
v_size_ = sizeof(T);
learning_rate_size_ = sizeof(float);
beta1_size_ = sizeof(float);
beta2_size_ = sizeof(float);
epsilon_size_ = sizeof(float);
decay_size_ = sizeof(float);
gradient_size_ = sizeof(T);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i];
}
auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < m_shape.size(); i++) {
m_size_ *= m_shape[i];
}
auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < v_shape.size(); i++) {
v_size_ *= v_shape[i];
}
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 8);
for (size_t i = 0; i < gradient_shape.size(); i++) {
gradient_size_ *= gradient_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(variable_size_);
input_size_list_.push_back(m_size_);
input_size_list_.push_back(v_size_);
input_size_list_.push_back(learning_rate_size_);
input_size_list_.push_back(beta1_size_);
input_size_list_.push_back(beta2_size_);
input_size_list_.push_back(epsilon_size_);
input_size_list_.push_back(decay_size_);
input_size_list_.push_back(gradient_size_);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
}
private:
size_t variable_size_;
size_t m_size_;
size_t v_size_;
size_t learning_rate_size_;
size_t beta1_size_;
size_t beta2_size_;
size_t epsilon_size_;
size_t decay_size_;
size_t gradient_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAM_WEIGHT_DECAY_GPU_KERNEL_H_

@@ -42,7 +42,7 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey,
FusedWeightScaleApplyMomentum)
FusedWeightScaleApplyMomentum, AdamWeightDecay)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr,
@@ -149,6 +149,7 @@ __all__ = [
'TopK',
'LinSpace',
'Adam',
'AdamWeightDecay',
'FusedSparseAdam',
'FusedSparseLazyAdam',
'AdamNoUpdateParam',

@@ -441,3 +441,93 @@ class FusedWeightScaleApplyMomentum(PrimitiveWithInfer):
validator.check_scalar_or_tensor_types_same({"d_dtype": d_dtype}, valid_dtypes, self.name)
validator.check_scalar_or_tensor_types_same({"s_dtype": s_dtype}, valid_dtypes, self.name)
return v_dtype
class AdamWeightDecay(PrimitiveWithInfer):
r"""
Updates parameters by the Adaptive Moment Estimation algorithm with weight decay (AdamWeightDecay).
The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
m = \beta_1 * m + (1 - \beta_1) * g \\
v = \beta_2 * v + (1 - \beta_2) * g * g \\
w = w - lr * (\frac{m}{\sqrt{v} + \epsilon} + weight\_decay * w)
\end{array}
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`lr` represents `learning_rate`,
:math:`w` represents `var`, :math:`weight\_decay` represents `decay`, and :math:`\epsilon` represents `epsilon`.
Args:
use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
If true, updates of the var, m, and v tensors will be protected by a lock.
If false, the result is unpredictable. Default: False.
Inputs:
- **var** (Tensor) - Weights to be updated.
- **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `var`.
- **v** (Tensor) - The 2nd moment vector in the updating formula.
Mean square gradients with the same type as `var`.
- **lr** (float) - :math:`lr` in the updating formula.
- **beta1** (float) - The exponential decay rate for the 1st moment estimations.
- **beta2** (float) - The exponential decay rate for the 2nd moment estimations.
- **epsilon** (float) - Term added to the denominator to improve numerical stability.
- **decay** (float) - The weight decay value, :math:`weight\_decay` in the updating formula.
- **gradient** (Tensor) - Gradient, has the same type as `var`.
Outputs:
Tuple of 3 Tensor, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **m** (Tensor) - The same shape and data type as `m`.
- **v** (Tensor) - The same shape and data type as `v`.
Supported Platforms:
``GPU``
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as ops
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.adam_weight_decay = ops.AdamWeightDecay()
... self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
... self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
... self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
... def construct(self, lr, beta1, beta2, epsilon, decay, grad):
... out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
... epsilon, decay, grad)
... return out
>>> np.random.seed(0)
>>> net = Net()
>>> gradient = Tensor(np.random.rand(2, 2).astype(np.float32))
>>> output = net(0.9, 0.9, 0.999, 1e-8, 1e-5, gradient)
"""
@prim_attr_register
def __init__(self, use_locking=False):
validator.check_value_type("use_locking", use_locking, [bool], self.name)
def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
epsilon_shape, decay_shape, grad_shape):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
return var_shape, m_shape, v_shape
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
epsilon_dtype, decay, grad_dtype):
args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
"decay": decay}
validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
return var_dtype, m_dtype, v_dtype

@@ -36,7 +36,7 @@ from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithL
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
BertTrainAccumulationAllReducePostWithLossScaleCell, \
BertTrainOneStepWithLossScaleCellForAdam, \
AdamWeightDecayForBert
AdamWeightDecayForBert, AdamWeightDecayOp
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack, BertLearningRate
@@ -96,6 +96,8 @@ def _get_optimizer(args_opt, network):
{'order_params': params}]
if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU':
optimizer = AdamWeightDecayOp(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
else:
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
elif cfg.optimizer == "Thor":

@@ -23,7 +23,7 @@ from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
SaturateCast, CreateAttentionMaskFromInputMask
from .adam import AdamWeightDecayForBert
from .adam import AdamWeightDecayForBert, AdamWeightDecayOp
__all__ = [
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
@@ -33,5 +33,5 @@ __all__ = [
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask",
"BertTrainOneStepWithLossScaleCellForAdam"
"BertTrainOneStepWithLossScaleCellForAdam", "AdamWeightDecayOp"
]

@@ -28,6 +28,20 @@ _adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
"""
Update parameters by AdamWeightDecay op.
"""
if optim_filter:
adam = P.AdamWeightDecay()
if decay_flags:
next_param = adam(param, m, v, lr, beta1, beta2, eps, Tensor(weight_decay, mstype.float32), gradient)
else:
next_param = adam(param, m, v, lr, beta1, beta2, eps, Tensor(0.0, mstype.float32), gradient)
return next_param
return gradient
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
@@ -252,7 +266,7 @@ class AdamWeightDecayForBert(Optimizer):
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
>>> optim = AdamWeightDecay(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
@@ -260,7 +274,7 @@ class AdamWeightDecayForBert(Optimizer):
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
>>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
@@ -305,3 +319,105 @@ class AdamWeightDecayForBert(Optimizer):
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
class AdamWeightDecayOp(Optimizer):
"""
Implements the Adam algorithm with weight decay. It is a single fused operator, not a combination of smaller ops.
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
To improve performance when parameter groups are used, a customized order of parameters can be supplied.
Args:
params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
"lr", "weight_decay" and "order_params" are the keys can be parsed.
- params: Required. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
If not, the `learning_rate` in the API will be used.
- weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
will be used. If not, the `weight_decay` in the API will be used.
- order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
this order will be followed in the optimizer. There must be no other keys in this `dict`, and every parameter
listed in 'order_params' must also appear in the 'params' of one of the groups.
learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used: the i-th step
takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
learning rate is also used, and the learning rate for each step is computed during training according to
the formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor,
a fixed learning rate is used. Other cases are not supported. The float learning rate must be
equal to or greater than 0. If the type of `learning_rate` is int, it is converted to float.
Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[bool], all elements are True.
Supported Platforms:
``GPU``
Examples:
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = AdamWeightDecayOp(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
... {'params': no_conv_params, 'lr': 0.01},
... {'order_params': net.trainable_params()}]
>>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
>>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
>>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
_check_param_value(beta1, beta2, eps, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
self.hyper_map = C.HyperMap()
def construct(self, gradients):
"""AdamWeightDecayOp"""
lr = self.get_lr()
if self.is_group:
if self.is_group_lr:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
self.weight_decay, self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
else:
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
self.parameters, self.moments1, self.moments2,
gradients, self.decay_flags, self.optim_filter)
if self.use_parallel:
self.broadcast_params(optim_result)
return optim_result
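
As a usage note, AdamWeightDecayOp consumes a tuple of gradients in its construct, so outside of the BERT train-one-step cells it can be driven directly with GradOperation, e.g. in the PyNative mode for which run_pretrain.py selects it. Below is a minimal sketch under stated assumptions: AdamWeightDecayOp is imported from the src.adam module shown above, and the tiny network, shapes and hyper-parameters are illustrative only.

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, ParameterTuple, context

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

class TinyNet(nn.Cell):
    """A one-layer net whose reduced output acts as a scalar loss."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.dense = nn.Dense(4, 2)
        self.reduce_mean = ops.ReduceMean()

    def construct(self, x):
        return self.reduce_mean(self.dense(x))

net = TinyNet()
# AdamWeightDecayOp as defined above (src.adam); hyper-parameters are illustrative.
optim = AdamWeightDecayOp(net.trainable_params(), learning_rate=1e-3, weight_decay=0.01)
grad_fn = ops.GradOperation(get_by_list=True)(net, ParameterTuple(net.trainable_params()))

x = Tensor(np.random.rand(3, 4).astype(np.float32))
grads = grad_fn(x)   # tuple of gradients, one per trainable parameter
optim(grads)         # one fused AdamWeightDecay step per parameter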
