diff --git a/mindspore/nn/learning_rate_schedule.py b/mindspore/nn/learning_rate_schedule.py
index b8cde1673b..181db58e44 100644
--- a/mindspore/nn/learning_rate_schedule.py
+++ b/mindspore/nn/learning_rate_schedule.py
@@ -24,10 +24,22 @@ from .._checkparam import Rel
 
 
 class LearningRateSchedule(Cell):
+    """Base class for learning rate schedules."""
     def __init__(self):
         super(LearningRateSchedule, self).__init__()
 
     def construct(self, global_step):
+        """
+        Defines the computation to get the current learning rate.
+
+        This method should be overridden by all subclasses.
+
+        Note:
+            The output should be a scalar Tensor.
+
+        Inputs:
+            - **global_step** (Tensor) - The current step number.
+        """
         raise NotImplementedError
 
 
diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index 2ef320fd9c..616f070d32 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
 
 @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
                                  "Tensor")
-def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
+def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
     success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
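For reviewers trying out the documented base class, below is a minimal sketch of how a subclass is expected to override `construct`. The `MyDecayLR` name and its exponential decay formula are illustrative assumptions, not part of this patch; only `LearningRateSchedule` and its `construct(global_step)` contract come from the diff above.

```python
# Illustrative only: a hypothetical schedule built on the LearningRateSchedule
# base class documented in this patch. The decay rule itself is an assumption.
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule


class MyDecayLR(LearningRateSchedule):
    """Exponentially decays the learning rate as the global step grows."""

    def __init__(self, learning_rate, decay_rate, decay_steps):
        super(MyDecayLR, self).__init__()
        self.learning_rate = learning_rate
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        self.pow = P.Pow()
        self.cast = P.Cast()

    def construct(self, global_step):
        # Returns a scalar Tensor, as required by LearningRateSchedule.construct.
        p = self.cast(global_step, mstype.float32) / self.decay_steps
        return self.pow(self.decay_rate, p) * self.learning_rate
```

An instance of such a schedule can be passed to an optimizer in place of a fixed learning rate, and the optimizer then calls `construct` with the current step each iteration. For the second hunk, the reordering of `l1`, `l2`, and `learning_rate` appears to align the sparse overload's signature with the dense overload and with the order in which the optimizer partially applies these arguments, so the hyper-parameters are no longer bound to the wrong positional slots.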