fix bug in sparse proximal ada grad

pull/3327/head
Author: wangnan39@huawei.com (5 years ago)
Parent: e09d50e4d6
Commit: 19762375a5
@@ -24,10 +24,22 @@ from .._checkparam import Rel
 class LearningRateSchedule(Cell):
     """Basic class of learning rate schedule."""

     def __init__(self):
         super(LearningRateSchedule, self).__init__()

     def construct(self, global_step):
+        """
+        Defines the computation to get the current learning rate.
+
+        This method should be overridden by all subclasses.
+
+        Note:
+            The output should be a Tensor of scalar.
+
+        Inputs:
+            Tensor. The current step number.
+        """
         raise NotImplementedError
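
The docstring added in this hunk spells out the contract for subclasses of the learning rate schedule base class: construct takes the current global_step and must return a scalar Tensor. A minimal sketch of such a subclass is shown below; the ExponentialLR name, the decay formula, and the import path are illustrative assumptions, not part of this commit.

# Minimal sketch of a LearningRateSchedule subclass (illustrative only;
# the class name, formula and import path are assumptions, not from this commit).
import mindspore.common.dtype as mstype
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
from mindspore.ops import operations as P


class ExponentialLR(LearningRateSchedule):
    """Multiplies the base learning rate by gamma ** (global_step / decay_steps)."""

    def __init__(self, learning_rate, gamma, decay_steps):
        super(ExponentialLR, self).__init__()
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.decay_steps = decay_steps
        self.pow = P.Pow()
        self.cast = P.Cast()

    def construct(self, global_step):
        # Must return a scalar Tensor, per the contract described above.
        p = self.cast(global_step, mstype.float32) / self.decay_steps
        return self.learning_rate * self.pow(self.gamma, p)

Depending on the MindSpore version, an instance of such a schedule can typically be passed as the learning_rate argument of an optimizer, so the rate is recomputed from the current global_step on every update.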

@@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
 @_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
                                  "Tensor")
-def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
+def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
     success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
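
The signature change in this hunk is the actual bug fix: the registered function is called positionally, and the reordering suggests the optimizer binds l1 and l2 ahead of learning_rate when it dispatches _proximal_ada_grad_opt, so the old parameter order shifted every argument after sparse_opt by one. A plain-Python analogue of that mismatch is sketched below; the names and values are illustrative, not MindSpore code.

# Plain-Python analogue of the argument-order bug fixed above (illustrative
# values and names; not MindSpore code). The caller binds l1 and l2 first,
# so the callee must declare them before learning_rate.
from functools import partial


def run_opt_buggy(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
    # Old parameter order: with partial(..., l1, l2) the value meant for l1
    # lands in learning_rate and the real learning rate lands in l2.
    return {"lr": learning_rate, "l1": l1, "l2": l2}


def run_opt_fixed(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
    # New parameter order: matches the positional binding, so every value
    # reaches the parameter it was intended for.
    return {"lr": learning_rate, "l1": l1, "l2": l2}


opt, sparse_opt = object(), object()
l1, l2, lr = 0.001, 0.002, 0.1

buggy = partial(run_opt_buggy, opt, sparse_opt, l1, l2)
fixed = partial(run_opt_fixed, opt, sparse_opt, l1, l2)

print(buggy(lr, "grad", "weight", "accum"))  # {'lr': 0.001, 'l1': 0.002, 'l2': 0.1}
print(fixed(lr, "grad", "weight", "accum"))  # {'lr': 0.1, 'l1': 0.001, 'l2': 0.002}

With the old order the inner sparse_opt call received l1 where it expected the learning rate, which silently corrupted the sparse ProximalAdagrad update; the reordered signature restores the intended mapping without changing the call itself.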
