enable use of a float learning rate in the LARS optimizer

pull/18/head
Ziyan 5 years ago
parent 930a1fb0a8
commit 4cbcd8e907
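
As a rough usage sketch of what this change enables (mirroring the new test_lars_float_lr test at the bottom of the diff; the import paths are assumed from the module layout, and Net stands in for the small network defined in the test file):

    import mindspore.nn as nn
    from mindspore.nn.optim import Momentum, LARS  # assumed import path

    net = Net()  # the simple test network from the test file
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()

    # With this commit a plain Python float is accepted as the LARS learning
    # rate, in addition to a per-step schedule (an iterable or a 1-D Tensor).
    sgd = Momentum(net.trainable_params(), 0.1, 0.9)
    optimizer = LARS(sgd, epsilon=1e-08, hyperpara=0.02,
                     decay_filter=lambda x: 'bn' not in x.name,
                     lars_filter=lambda x: 'bn' not in x.name)
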

@@ -13,12 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """lars optimizer"""
+from typing import Iterable
 from mindspore.common import dtype as mstype
+from mindspore.common import Tensor
 from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.common.parameter import Parameter
 from mindspore.nn.cell import Cell
 from .optimizer import grad_scale
@@ -111,7 +113,8 @@ class LARS(Cell):
         self.gather = None
         self.global_step = None
         self.axis = None
-        if not isinstance(self.learning_rate, float):
+        if isinstance(self.learning_rate.default_input, Iterable) or \
+                (isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
             self.dynamic_lr = True
             self.assignadd = P.AssignAdd()
             self.gather = P.GatherV2()
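
For readability, a standalone sketch of the dispatch the new condition implements (is_dynamic_lr is a hypothetical helper, not part of the patch; it assumes the learning rate is stored as a Parameter whose default_input holds the user-supplied value):

    from typing import Iterable
    from mindspore.common import Tensor

    def is_dynamic_lr(default_input):
        # A per-step schedule is either an iterable of values or a 1-D Tensor;
        # a scalar built from a float falls through to the static-lr branch.
        return isinstance(default_input, Iterable) or \
            (isinstance(default_input, Tensor) and default_input.dim() == 1)
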
@@ -124,7 +127,7 @@
             lr = self.gather(self.learning_rate, self.global_step, self.axis)
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         else:
-            lr = F.scalar_to_array(self.learning_rate)
+            lr = self.learning_rate
         if self.reciprocal_scale != 1.0:
             gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
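
Note on the static branch: the old code turned a Python float into an array at graph time with F.scalar_to_array, while the new code uses self.learning_rate directly, which (judging from the default_input access in the previous hunk) is now held as a Parameter/Tensor rather than a raw float.
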

@@ -46,7 +46,7 @@ class Net(nn.Cell):
         return x


-def test_lars():
+def test_lars_multi_step_lr():
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
@@ -61,3 +61,20 @@ def test_lars():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
+
+
+def test_lars_float_lr():
+    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
+    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+    net = Net()
+    net.set_train()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    lr = 0.1
+    SGD = Momentum(net.trainable_params(), lr, 0.9)
+    optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
+                     lars_filter=lambda x: 'bn' not in x.name)
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    _executor.compile(train_network, inputs, label)
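
Taken together, the tests now cover both learning-rate forms: the renamed test_lars_multi_step_lr presumably keeps exercising a per-step schedule (per its new name), while the new test_lars_float_lr compiles the same training graph with a plain float learning rate of 0.1.
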