From 4cbcd8e90736158f54e0fceb87ec9262a0f85acc Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Mon, 30 Mar 2020 15:17:05 +0800
Subject: [PATCH] enable use of float type learning rate in lars optimizer

---
 mindspore/nn/optim/lars.py            |  9 ++++++---
 tests/ut/python/nn/optim/test_lars.py | 19 ++++++++++++++++++-
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py
index cdfe45de62..a69057215d 100755
--- a/mindspore/nn/optim/lars.py
+++ b/mindspore/nn/optim/lars.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """lars optimizer"""
+from typing import Iterable
 from mindspore.common import dtype as mstype
+from mindspore.common import Tensor
 from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.common.parameter import Parameter
 from mindspore.nn.cell import Cell
 from .optimizer import grad_scale
 
@@ -111,7 +113,8 @@ class LARS(Cell):
         self.gather = None
         self.global_step = None
         self.axis = None
-        if not isinstance(self.learning_rate, float):
+        if isinstance(self.learning_rate.default_input, Iterable) or \
+                (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
             self.dynamic_lr = True
             self.assignadd = P.AssignAdd()
             self.gather = P.GatherV2()
@@ -124,7 +127,7 @@ class LARS(Cell):
             lr = self.gather(self.learning_rate, self.global_step, self.axis)
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         else:
-            lr = F.scalar_to_array(self.learning_rate)
+            lr = self.learning_rate
 
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)

diff --git a/tests/ut/python/nn/optim/test_lars.py b/tests/ut/python/nn/optim/test_lars.py
index 92d218a32b..17bbe69fe6 100644
--- a/tests/ut/python/nn/optim/test_lars.py
+++ b/tests/ut/python/nn/optim/test_lars.py
@@ -46,7 +46,7 @@ class Net(nn.Cell):
         return x
 
 
-def test_lars():
+def test_lars_multi_step_lr():
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
@@ -61,3 +61,20 @@ def test_lars():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+
+
+def test_lars_float_lr():
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+
+    lr = 0.1
+    SGD = Momentum(net.trainable_params(), lr, 0.9)
+    optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+                     lars_filter=lambda x: 'bn' not in x.name)
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
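
Notes (commentary on the change, not part of the patch):

The lars.py hunk replaces the old "not isinstance(self.learning_rate, float)" test with an inspection of the value wrapped by the learning-rate Parameter: the dynamic-LR path (gather by global step) is now taken only when default_input is a Python iterable or a 1-D Tensor of per-step rates, while a plain float falls through to the static path and is used directly instead of being routed through F.scalar_to_array. A minimal sketch of the new predicate, assuming only the Parameter.default_input and Tensor.dim() API that the patch itself uses (is_dynamic_lr is a hypothetical helper, written here just to isolate the branch):

    from typing import Iterable
    from mindspore.common import Tensor

    def is_dynamic_lr(lr_parameter):
        """Mirror the patched LARS check: the LR counts as dynamic iff the
        Parameter wraps an iterable schedule or a 1-D Tensor of rates."""
        value = lr_parameter.default_input
        return isinstance(value, Iterable) or \
            (isinstance(value, Tensor) and value.dim() == 1)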
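The added test_lars_float_lr case exercises only the scalar path. A hedged usage sketch of how both LR forms would be driven through the same wrapper, reusing Net, Momentum, and LARS exactly as the tests above do (the three-element schedule and the loop are illustrative assumptions, not taken from the patch):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.optim import LARS, Momentum

    net = Net()  # Net as defined in test_lars.py
    for lr in (0.1,                                                # float -> static LR path
               Tensor(np.array([0.1, 0.01, 0.001], np.float32))):  # 1-D Tensor -> dynamic LR path
        sgd = Momentum(net.trainable_params(), lr, 0.9)
        optimizer = LARS(sgd, epsilon=1e-08, hyperpara=0.02,
                         decay_filter=lambda x: 'bn' not in x.name,
                         lars_filter=lambda x: 'bn' not in x.name)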