|
|
|
@ -24,7 +24,7 @@ _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
|
|
|
|
|
|
|
|
|
|
@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
|
|
|
|
|
"Tensor")
|
|
|
|
|
def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
|
|
|
|
|
def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
|
|
|
|
|
"""Apply sparse proximal_ada_grad optimizer to the weight parameter."""
|
|
|
|
|
success = True
|
|
|
|
|
success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
|
|
|
|
|