@@ -13,12 +13,14 @@
# limitations under the License.
# ============================================================================
"""lars optimizer"""
from typing import Iterable
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.nn.cell import Cell
from .optimizer import grad_scale
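
For context: the guard that follows in the next hunk distinguishes two learning-rate forms. A minimal sketch of both, with illustrative values (the schedule here is an assumption, not part of this change):

import numpy as np
from mindspore.common import Tensor
from mindspore.common import dtype as mstype

# A plain float keeps the optimizer on the static-lr path.
static_lr = 0.1

# An Iterable or a 1-D Tensor of per-step values trips the dynamic-lr
# check below and enables the GatherV2/AssignAdd machinery.
lr_schedule = Tensor(np.linspace(0.1, 0.0, num=1000), mstype.float32)
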
@@ -111,7 +113,8 @@ class LARS(Cell):
        self.gather = None
        self.global_step = None
        self.axis = None
        if not isinstance(self.learning_rate, float):
            if isinstance(self.learning_rate.default_input, Iterable) or \
                    (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
                self.dynamic_lr = True
                self.assignadd = P.AssignAdd()
                self.gather = P.GatherV2()
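
The hunk is truncated here; presumably the dynamic-lr branch goes on to create the step counter and gather axis that construct uses in the next hunk. A sketch of that continuation, consistent with the initializer/Parameter/mstype imports above (the parameter name is an assumption):

                # presumed continuation of the dynamic-lr branch (not shown in this hunk):
                self.global_step = Parameter(initializer(0, [1], mstype.int32),
                                             name="global_step")  # name assumed
                self.axis = 0  # gather axis into the 1-D lr schedule
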
@@ -124,7 +127,7 @@ class LARS(Cell):
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, self.axis)
            F.control_depend(lr, self.assignadd(self.global_step, 1))
        else:
-            lr = F.scalar_to_array(self.learning_rate)
+            lr = self.learning_rate
        if self.reciprocal_scale != 1.0:
            gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
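
The changed line stops wrapping a static float learning rate in F.scalar_to_array and passes it through unchanged. For intuition, a plain-NumPy analogy of what the updated construct does each step; values are illustrative, and graph-mode details such as the control_depend ordering between the lr gather and the step increment are elided:

import numpy as np

lr_schedule = np.linspace(0.1, 0.0, num=1000).astype(np.float32)
global_step = 0

lr = lr_schedule[global_step]  # what P.GatherV2 does with global_step and axis 0
global_step += 1               # what P.AssignAdd does, sequenced by control_depend

reciprocal_scale = 1.0 / 1024.0                # assumed reciprocal of a loss scale
grads = [np.ones(3, np.float32) for _ in range(2)]
grads = [g * reciprocal_scale for g in grads]  # grad_scale mapped over all gradients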