From e70b2f5430a378f94ba821acd98535319596f87a Mon Sep 17 00:00:00 2001
From: guohongzilong <2713219276@qq.com>
Date: Mon, 11 May 2020 14:03:08 +0800
Subject: [PATCH] add optimizer.get_lr_parameter() method

---
 mindspore/nn/optim/optimizer.py             | 31 +++++++++++++++
 .../test_optimize_with_parameter_groups.py  | 38 +++++++++++++++++++
 2 files changed, 69 insertions(+)

diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 6f7f60a216..d931e5a52f 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -257,6 +257,7 @@ class Optimizer(Cell):
                 logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
 
             for param in group_param['params']:
+                validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                 if param in params_store:
                     raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.")
                 params_store.append(param)
@@ -286,6 +287,36 @@ class Optimizer(Cell):
             F.control_depend(lr, self.assignadd(self.global_step, 1))
         return lr
 
+    def get_lr_parameter(self, param):
+        """
+        Get the learning rate of the parameter.
+
+        Args:
+            param (Union[Parameter, list[Parameter]]): The `Parameter` or list of `Parameter`.
+
+        Returns:
+            Parameter, a single `Parameter` or a `list[Parameter]` according to the input type.
+        """
+        if not isinstance(param, (Parameter, list)):
+            raise TypeError(f"The 'param' only supports 'Parameter' or 'list' type.")
+
+        if isinstance(param, list):
+            lr = []
+            for p in param:
+                validator.check_value_type("parameter", p, [Parameter], self.cls_name)
+                if self.is_group_lr:
+                    index = self.parameters.index(p)
+                    lr.append(self.learning_rate[index])
+                else:
+                    lr.append(self.learning_rate)
+        else:
+            if self.is_group_lr:
+                index = self.parameters.index(param)
+                lr = self.learning_rate[index]
+            else:
+                lr = self.learning_rate
+        return lr
+
     def construct(self, *hyper_params):
         raise NotImplementedError
 
diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
index 24ee9254a9..6755820488 100644
--- a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
+++ b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py
@@ -210,3 +210,41 @@ def test_group_repeat_param():
                     {'params': no_conv_params}]
     with pytest.raises(RuntimeError):
         Adam(group_params, learning_rate=default_lr)
+
+
+def test_get_lr_parameter_with_group():
+    net = LeNet5()
+    conv_lr = 0.1
+    default_lr = 0.3
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'lr': conv_lr},
+                    {'params': no_conv_params, 'lr': default_lr}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is True
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == 'lr_' + param.name
+
+    lr_list = opt.get_lr_parameter(conv_params)
+    for lr, param in zip(lr_list, conv_params):
+        assert lr.name == 'lr_' + param.name
+
+
+def test_get_lr_parameter_with_no_group():
+    net = LeNet5()
+    conv_weight_decay = 0.8
+
+    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
+    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
+    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
+                    {'params': no_conv_params}]
+    opt = SGD(group_params)
+    assert opt.is_group_lr is False
+    for param in opt.parameters:
+        lr = opt.get_lr_parameter(param)
+        assert lr.name == opt.learning_rate.name
+
+    params_error = [1, 2, 3]
+    with pytest.raises(TypeError):
+        opt.get_lr_parameter(params_error)
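
For reference, a minimal usage sketch of the new method outside the unit tests. It mirrors the grouped-parameter setup from test_get_lr_parameter_with_group above; LeNet5 stands in for any network whose parameter names split into 'conv' and non-'conv' groups, and nn.SGD is assumed to accept the same group format used in the tests.

    import mindspore.nn as nn

    net = LeNet5()  # assumption: same LeNet5 test network as in the tests above
    conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
    other_params = [p for p in net.trainable_params() if 'conv' not in p.name]
    group_params = [{'params': conv_params, 'lr': 0.1},
                    {'params': other_params, 'lr': 0.3}]
    opt = nn.SGD(group_params)

    # A single Parameter returns its learning-rate Parameter (named 'lr_' + param.name).
    single_lr = opt.get_lr_parameter(conv_params[0])

    # A list of Parameters returns the matching list of learning-rate Parameters.
    lr_list = opt.get_lr_parameter(conv_params)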