|
|
|
@ -18,6 +18,9 @@
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ..._c_expression import signature_rw as sig_rw
|
|
|
|
|
from ..._c_expression import signature_kind as sig_kind
|
|
|
|
|
from ..._c_expression import signature_dtype as sig_dtype
|
|
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -330,6 +333,183 @@ class EmbeddingLookup(PrimitiveWithInfer):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparseApplyFtrlNoReturn(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Update relevant entries according to the FTRL-proximal scheme.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
lr (float): The learning rate value, must be positive.
|
|
|
|
|
l1 (float): l1 regularization strength, must be greater than or equal to zero.
|
|
|
|
|
l2 (float): l2 regularization strength, must be greater than or equal to zero.
|
|
|
|
|
lr_power (float): Learning rate power controls how the learning rate decreases during training,
|
|
|
|
|
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
|
|
|
|
|
use_locking (bool): Use locks for update operation if True . Default: False.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **var** (Parameter): The variable to be updated. The data type must be float32.
|
|
|
|
|
- **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
|
|
|
|
|
- **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
|
|
|
|
|
- **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
|
|
|
|
|
- **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
|
|
|
|
|
of `indices` must be the same as `grad` in first dimension. The type must be int32.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
|
|
|
|
|
|
|
|
|
|
- **var** (Tensor) - A Tensor with shape (1,).
|
|
|
|
|
- **accum** (Tensor) - A Tensor with shape (1,).
|
|
|
|
|
- **linear** (Tensor) - A Tensor with shape (1,).
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> import mindspore
|
|
|
|
|
>>> import mindspore.nn as nn
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>> from mindspore import Parameter
|
|
|
|
|
>>> from mindspore import Tensor
|
|
|
|
|
>>> from mindspore.ops import operations as P
|
|
|
|
|
>>> class SparseApplyFtrlNet(nn.Cell):
|
|
|
|
|
>>> def __init__(self):
|
|
|
|
|
>>> super(SparseApplyFtrlNet, self).__init__()
|
|
|
|
|
>>> self.sparse_apply_ftrl = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
|
|
|
|
|
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
|
|
|
|
|
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
|
|
|
|
|
>>> self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
|
|
|
|
|
>>>
|
|
|
|
|
>>> def construct(self, grad, indices):
|
|
|
|
|
>>> out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
|
|
|
|
|
>>> return out
|
|
|
|
|
>>>
|
|
|
|
|
>>> net = SparseApplyFtrlNet()
|
|
|
|
|
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
|
|
|
|
|
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
|
|
|
|
|
>>> output = net(grad, indices)
|
|
|
|
|
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, lr, l1, l2, lr_power, use_locking=False):
|
|
|
|
|
self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
|
|
|
|
|
outputs=['output'])
|
|
|
|
|
validator.check_value_type("lr", lr, [float], self.name)
|
|
|
|
|
validator.check_value_type("l1", l1, [float], self.name)
|
|
|
|
|
validator.check_value_type("l2", l2, [float], self.name)
|
|
|
|
|
validator.check_value_type("lr_power", lr_power, [float], self.name)
|
|
|
|
|
self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
|
|
|
|
|
self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
|
|
|
|
self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
|
|
|
|
|
self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
|
|
|
|
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
|
|
|
|
self.add_prim_attr('primitive_target', 'CPU')
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
|
|
|
|
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
|
|
|
|
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
|
|
|
|
if len(var_shape) > 1:
|
|
|
|
|
validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
|
|
|
|
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
|
|
|
|
validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
|
|
|
|
|
return [1], [1], [1]
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
|
|
|
|
|
args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
|
|
|
|
|
"linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, [mstype.float32], self.name)
|
|
|
|
|
validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
|
|
|
|
|
return var_dtype, accum_dtype, linear_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparseApplyProximalAdagradNoReturn(PrimitiveWithInfer):
|
|
|
|
|
r"""
|
|
|
|
|
Updates relevant entries according to the proximal adagrad algorithm.
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
accum += grad * grad
|
|
|
|
|
.. math::
|
|
|
|
|
\text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
|
|
|
|
|
.. math::
|
|
|
|
|
var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
|
|
|
|
|
- **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
|
|
|
|
|
- **lr** (Tensor): The learning rate value. The data type must be float32.
|
|
|
|
|
- **l1** (Tensor): l1 regularization strength. The data type must be float32.
|
|
|
|
|
- **l2** (Tensor): l2 regularization strength. The data type must be float32.
|
|
|
|
|
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
|
|
|
|
|
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
|
|
|
|
|
must be int32.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless.
|
|
|
|
|
|
|
|
|
|
- **var** (Tensor) - A Tensor with shape (1,).
|
|
|
|
|
- **accum** (Tensor) - A Tensor with shape (1,).
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> import numpy as np
|
|
|
|
|
>>> import mindspore.nn as nn
|
|
|
|
|
>>> from mindspore import Tensor, Parameter
|
|
|
|
|
>>> from mindspore.ops import operations as P
|
|
|
|
|
>>> class Net(nn.Cell):
|
|
|
|
|
>>> def __init__(self):
|
|
|
|
|
>>> super(Net, self).__init__()
|
|
|
|
|
>>> self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagradV2()
|
|
|
|
|
>>> self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
|
|
|
|
|
>>> self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
|
|
|
|
|
>>> self.lr = Tensor(0.01, mstype.float32)
|
|
|
|
|
>>> self.l1 = Tensor(0.0, mstype.float32)
|
|
|
|
|
>>> self.l2 = Tensor(0.0, mstype.float32)
|
|
|
|
|
>>> def construct(self, grad, indices):
|
|
|
|
|
>>> out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
|
|
|
|
|
>>> self.l2, grad, indices)
|
|
|
|
|
>>> return out
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
|
|
|
|
|
>>> indices = Tensor(np.array([0, 1]).astype(np.int32))
|
|
|
|
|
>>> output = net(grad, indices)
|
|
|
|
|
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
|
|
|
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, use_locking=False):
|
|
|
|
|
self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
|
|
|
|
|
outputs=['output'])
|
|
|
|
|
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
|
|
|
|
self.add_prim_attr('primitive_target', 'CPU')
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
|
|
|
|
|
validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
|
|
|
|
|
return [1], [1]
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
|
|
|
|
|
args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
|
|
|
|
|
validator.check_tensor_type_same(args, [mstype.float32], self.name)
|
|
|
|
|
validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
|
|
|
|
|
validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
|
|
|
|
|
validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
|
|
|
|
|
valid_types = [mstype.int16, mstype.int32, mstype.int64,
|
|
|
|
|
mstype.uint16, mstype.uint32, mstype.uint64]
|
|
|
|
|
validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
|
|
|
|
|
return var_dtype, accum_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LinSpace(PrimitiveWithInfer):
|
|
|
|
|
r"""
|
|
|
|
|
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
|
|
|
|