!18 Enable use of a float-type learning rate in the LARS optimizer

Merge pull request !18 from gziyan/master
pull/18/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 352c6faf85

@@ -13,12 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """lars optimizer"""
+from typing import Iterable
 from mindspore.common import dtype as mstype
+from mindspore.common import Tensor
 from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.common.parameter import Parameter
 from mindspore.nn.cell import Cell
 from .optimizer import grad_scale
@@ -111,7 +113,8 @@ class LARS(Cell):
         self.gather = None
         self.global_step = None
         self.axis = None
-        if not isinstance(self.learning_rate, float):
+        if isinstance(self.learning_rate.default_input, Iterable) or \
+                (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
             self.dynamic_lr = True
             self.assignadd = P.AssignAdd()
             self.gather = P.GatherV2()
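
The hunk above swaps the old `not isinstance(self.learning_rate, float)` test for an explicit schedule check: the learning rate is treated as dynamic when it is an iterable or a 1-D Tensor of per-step rates, so a plain float now takes the static path. A minimal sketch of that decision in plain Python, with a NumPy array standing in for a MindSpore Tensor and a hypothetical helper name:

    from typing import Iterable
    import numpy as np

    def is_dynamic_lr(learning_rate) -> bool:
        # A 1-D array models a per-step schedule; a 0-D array or a plain float
        # models a static learning rate. Any other iterable (e.g. a list of
        # rates) is also treated as a schedule.
        if isinstance(learning_rate, np.ndarray):
            return learning_rate.ndim == 1
        return isinstance(learning_rate, Iterable)

    assert is_dynamic_lr([0.1, 0.01, 0.001])      # list schedule -> dynamic
    assert is_dynamic_lr(np.array([0.1, 0.01]))   # 1-D schedule  -> dynamic
    assert not is_dynamic_lr(0.1)                 # float         -> static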
@@ -124,7 +127,7 @@
             lr = self.gather(self.learning_rate, self.global_step, self.axis)
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         else:
-            lr = F.scalar_to_array(self.learning_rate)
+            lr = self.learning_rate
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
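
At step time the branch above picks the rate fed into the LARS update: the dynamic path gathers the schedule entry at the current global step, while the new float path uses the scalar directly instead of converting it with F.scalar_to_array. A rough NumPy stand-in for that lookup (the current_lr helper is illustrative, not part of MindSpore):

    import numpy as np

    def current_lr(learning_rate, global_step):
        # 1-D schedule: index by the global step (what P.GatherV2 does above).
        if isinstance(learning_rate, np.ndarray) and learning_rate.ndim == 1:
            return learning_rate[global_step]
        # Float path: the scalar is used as-is, with no scalar_to_array conversion.
        return learning_rate

    schedule = np.array([0.1, 0.01, 0.001], dtype=np.float32)
    print(current_lr(schedule, 1))   # 0.01
    print(current_lr(0.1, 1))        # 0.1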

@@ -46,7 +46,7 @@ class Net(nn.Cell):
         return x


-def test_lars():
+def test_lars_multi_step_lr():
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
@@ -61,3 +61,20 @@ def test_lars():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+
+
+def test_lars_float_lr():
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    lr = 0.1
+    SGD = Momentum(net.trainable_params(), lr, 0.9)
+    optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+                     lars_filter=lambda x: 'bn' not in x.name)
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
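
test_lars_float_lr above exercises the new static path. For contrast, a hypothetical variant (not part of this merge) that should take the dynamic GatherV2 branch instead, assuming the inner Momentum optimizer accepts a 1-D Tensor of per-step rates and reusing the imports and Net already defined in this test file:

    def test_lars_tensor_lr_sketch():
        inputs = Tensor(np.ones([1, 64]).astype(np.float32))
        label = Tensor(np.zeros([1, 10]).astype(np.float32))
        net = Net()
        net.set_train()
        loss = nn.SoftmaxCrossEntropyWithLogits()
        # Hypothetical per-step schedule; a 1-D Tensor should select the
        # dynamic-LR branch rather than the new float path.
        lr_schedule = Tensor(np.array([0.1] * 5 + [0.01] * 5).astype(np.float32))
        SGD = Momentum(net.trainable_params(), lr_schedule, 0.9)
        optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02,
                         decay_filter=lambda x: 'bn' not in x.name,
                         lars_filter=lambda x: 'bn' not in x.name)
        net_with_loss = WithLossCell(net, loss)
        train_network = TrainOneStepCell(net_with_loss, optimizer)
        _executor.compile(train_network, inputs, label)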