@@ -14,7 +14,6 @@
# ============================================================================
"""FTRL"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
@@ -23,6 +22,8 @@ from mindspore._checkparam import Rel
from .optimizer import Optimizer, apply_decay, grad_scale

ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")


@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
    """Apply ftrl optimizer to the weight parameter."""
@@ -30,8 +31,10 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
    success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
    return success
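# Illustrative reference (not part of this patch): a rough NumPy sketch of the
# per-element FTRL-Proximal update that the fused P.ApplyFtrl primitive is
# expected to perform. The helper name and the exact kernel details are
# assumptions added for clarity only.
import numpy as np

def _ftrl_reference_step(var, accum, linear, grad, lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5):
    """NumPy sketch of one FTRL-Proximal step on a single parameter array."""
    accum_new = accum + grad * grad
    # sigma = (accum_new^{-lr_power} - accum^{-lr_power}) / lr measures how much the
    # per-coordinate step size shrank this step; it keeps `linear` consistent with `var`.
    linear = linear + grad - (accum_new ** (-lr_power) - accum ** (-lr_power)) / lr * var
    quadratic = accum_new ** (-lr_power) / lr + 2.0 * l2
    # l1 soft-thresholding: coordinates with |linear| <= l1 are set exactly to zero,
    # which is what gives FTRL its sparse solutions.
    var = np.where(np.abs(linear) > l1, (np.sign(linear) * l1 - linear) / quadratic, 0.0)
    return var, accum_new, linear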
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
                 prim_name=None):
    """Check param."""
    validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
    validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
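    # For example (illustrative, based only on the two checks shown above):
    #     _check_param(initial_accum=-0.1, ...)   # rejected: initial_accum must be a float >= 0.0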
@@ -104,7 +107,7 @@ class FTRL(Optimizer):
        self.lr_power = lr_power
        self.reciprocal_scale = 1.0 / loss_scale
        self.weight_decay = weight_decay
-        self.decay_tf = tuple((lambda:True)() for x in self.parameters)
+        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyFtrl(use_locking=use_locking)
        self.one = Tensor(1, mstype.int32)
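        # Note (assumption about surrounding context, not shown in this hunk):
        # decay_tf holds one True flag per parameter; optimizers in this family
        # typically pair it with apply_decay inside construct() when
        # weight_decay > 0, roughly:
        #
        #     if self.weight_decay > 0.0:
        #         grads = self.hyper_map(F.partial(apply_decay, self.weight_decay),
        #                                self.decay_tf, params, grads)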
@@ -118,5 +121,6 @@ class FTRL(Optimizer):
        if self.reciprocal_scale != 1.0:
            grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
        lr = self.learning_rate
-        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power), linear, grads, params, moments)
+        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
+                                 linear, grads, params, moments)
        return success
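# Usage sketch (assumptions: the constructor signature matches the parameters
# checked above; MyNet and the training wrappers are placeholders, not part of
# this file):
#
#     from mindspore import nn
#
#     net = MyNet()
#     opt = nn.FTRL(net.trainable_params(), initial_accum=0.1, learning_rate=0.001,
#                   lr_power=-0.5, l1=0.0, l2=0.0, weight_decay=0.0)
#     # opt can then be handed to a training wrapper such as Model or
#     # TrainOneStepCell together with a loss function.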