diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 5968a07bba..a24166c5b5 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -65,6 +65,7 @@ static std::map tbe_func_adapter_map = { {"dropout_do_mask", "drop_out_do_mask"}, {"strided_slice", "strided_slice_d"}, {"strided_slice_grad", "strided_slice_grad_d"}, + {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, {"transpose", "transpose_d"}, {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 6d11ba3752..a85db03759 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -112,6 +112,9 @@ from .softplus_grad import _softplus_grad_tbe from .softmax_grad_ext import _softmax_grad_ext_tbe from .square import _square_tbe from .sqrt import _sqrt_tbe +from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d +from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad +from .apply_proximal_adagrad import _apply_proximal_adagrad from .transpose_d import _transpose_d_tbe from .unsorted_segment_sum import _unsorted_segment_sum_tbe from .logsoftmax_grad import _logsoftmax_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py b/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py new file mode 100644 index 0000000000..9099c6e24f --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""ApplyProximalAdagrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +apply_proximal_adagrad_op_info = TBERegOp("ApplyProximalAdagrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("apply_proximal_adagrad.so") \ + .compute_cost(10) \ + .kernel_name("apply_proximal_adagrad") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "true,false", "false") \ + .input(0, "var", False, "required", "all") \ + .input(1, "accum", False, "required", "all") \ + .input(2, "lr", False, "required", "all") \ + .input(3, "l1", False, "required", "all") \ + .input(4, "l2", False, "required", "all") \ + .input(5, "grad", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ + .get_op_info() + + +@op_info_register(apply_proximal_adagrad_op_info) +def _apply_proximal_adagrad(): + """ApplyProximalAdagrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py new file mode 100644 index 0000000000..a61f6174b9 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""SparseApplyFtrl op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("sparse_apply_ftrl.so") \ + .compute_cost(10) \ + .kernel_name("sparse_apply_ftrl") \ + .partial_flag(True) \ + .attr("lr", "required", "float", "all") \ + .attr("l1", "required", "float", "all") \ + .attr("l2", "required", "float", "all") \ + .attr("lr_power", "required", "float", "all") \ + .attr("use_locking", "optional", "bool", "true,false", "false") \ + .input(0, "var", False, "required", "all") \ + .input(1, "accum", False, "required", "all") \ + .input(2, "linear", False, "required", "all") \ + .input(3, "grad", False, "required", "all") \ + .input(4, "indices", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .output(1, "accum", False, "required", "all") \ + .output(2, "linear", False, "required", "all") \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(sparse_apply_ftrl_d_op_info) +def _sparse_apply_ftrl_d(): + """SparseApplyFtrl TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py b/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py new file mode 100644 index 0000000000..f665890c55 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""SparseApplyProximalAdagrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sparse_apply_proximal_adagrad_op_info = TBERegOp("SparseApplyProximalAdagrad") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("sparse_apply_proximal_adagrad.so") \ + .compute_cost(10) \ + .kernel_name("sparse_apply_proximal_adagrad") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "true,false", "false") \ + .input(0, "var", False, "required", "all") \ + .input(1, "accum", False, "required", "all") \ + .input(2, "lr", False, "required", "all") \ + .input(3, "l1", False, "required", "all") \ + .input(4, "l2", False, "required", "all") \ + .input(5, "grad", False, "required", "all") \ + .input(6, "indices", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW, DataType.I16_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.I16_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC, DataType.F32_NHWC, DataType.I16_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I16_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ, DataType.F32_FracZ, DataType.I16_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC, DataType.F32_NHWC, DataType.I64_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ, DataType.F32_FracZ, 
DataType.I64_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW, DataType.U16_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.U16_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC, DataType.F32_NHWC, DataType.U16_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.U16_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ, DataType.F32_FracZ, DataType.U16_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW, DataType.U32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.U32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC, DataType.F32_NHWC, DataType.U32_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.U32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ, DataType.F32_FracZ, DataType.U32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW, DataType.U64_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD, DataType.U64_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.F32_NHWC, DataType.F32_NHWC, DataType.U64_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.U64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, + DataType.F32_FracZ, DataType.F32_FracZ, DataType.U64_FracZ, DataType.F32_FracZ) \ + .get_op_info() + + +@op_info_register(sparse_apply_proximal_adagrad_op_info) +def _sparse_apply_proximal_adagrad(): + """SparseApplyProximalAdagrad TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 39cd97e5c0..71c11f492d 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, SmoothL1Loss, Softmax, Softplus, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, - TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, + TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, + ApplyProximalAdagrad, SparseApplyProximalAdagrad, ApplyRMSProp, ApplyCenteredRMSProp) from .other_ops import 
Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop from . import _quant_ops @@ -265,6 +266,9 @@ __all__ = [ "Round", "ApplyFtrl", "SpaceToBatch", + "SparseApplyFtrl", + "ApplyProximalAdagrad", + "SparseApplyProximalAdagrad", "BatchToSpace", "Atan2", "ApplyRMSProp", diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 358087ccb1..ed7237b04c 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2807,6 +2807,126 @@ class SparseApplyAdagrad(PrimitiveWithInfer): return var_type +class ApplyProximalAdagrad(PrimitiveWithInfer): + r""" + Update relevant entries according to the proximal adagrad algorithm. + + .. math:: + accum += grad * grad + .. math:: + prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} + .. math:: + var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) + + Args: + use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. + + Inputs: + - **var** (Tensor) - Variable to be updated. + - **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape. + - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. It should be + a scalar tensor or number. + - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. + It should be a scalar tensor or number. + - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. + It should be a scalar tensor or number. + - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape. + + Outputs: + Tensor, has the same shape and type as `var`. + + Examples: + >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> lr = 0.01 + >>> l1 = 0.0 + >>> l2 = 0.0 + >>> apply_proximal_ada_grad = P.ApplyProximalAdagrad() + >>> output = apply_proximal_ada_grad(var, accum, lr, l1, l2, grad) + """ + + @prim_attr_register + def __init__(self, use_locking=False): + self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['output']) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + + def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): + validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name) + return var_shape + + def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): + valid_types = [mstype.float16, mstype.float32] + args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} + validator.check_tensor_type_same(args, valid_types, self.name) + scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype} + validator.check_scalar_or_tensor_type_same(scalar_args, valid_types, self.name) + return var_dtype + + +class SparseApplyProximalAdagrad(PrimitiveWithInfer): + r""" + Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad, + an additional index tensor is input. + + .. math:: + accum += grad * grad + .. math:: + prox_v = var - lr * grad * \frac{1}{\sqrt{accum}} + .. 
math:: + var = \frac{sign(prox_v)}{1 + lr * l2} * \max(\left| prox_v \right| - lr * l1, 0) + + Args: + use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False. + + Inputs: + - **var** (Tensor) - Variable tensor to be updated. + - **accum** (Tensor) - Variable tensor to be updated. The shape must be the same as `var`'s shape. + - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. It should be + a scalar tensor or number. + - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero. + It should be a scalar tensor or number. + - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero. + It should be a scalar tensor or number. + - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. + - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. + + Outputs: + Tensor, has the same shape and type as `var`. + + Examples: + >>> var = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> accum = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> grad = Tensor(np.random.random((3, 3)), mindspore.float32) + >>> indices = Tensor(np.ones((3,), np.int32)) + >>> lr = 0.01 + >>> l1 = 0.0 + >>> l2 = 0.0 + >>> sparse_apply_proximal_ada_grad = P.SparseApplyProximalAdagrad() + >>> output = sparse_apply_proximal_ada_grad(var, accum, lr, l1, l2, grad, indices) + """ + + @prim_attr_register + def __init__(self, use_locking=False): + self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], + outputs=['output']) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + + def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): + return var_shape + + def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): + args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} + validator.check_tensor_type_same(args, [mstype.float32], self.name) + scalar_args = {"lr": lr_dtype, "l1": l1_dtype, "l2": l2_dtype} + validator.check_scalar_or_tensor_type_same(scalar_args, [mstype.float32], self.name) + valid_types = [mstype.int16, mstype.int32, mstype.int64, + mstype.uint16, mstype.uint32, mstype.uint64] + validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) + return var_dtype + + class LARSUpdate(PrimitiveWithInfer): """ Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient. @@ -2963,6 +3083,85 @@ class ApplyFtrl(PrimitiveWithInfer): return var_type +class SparseApplyFtrl(PrimitiveWithInfer): + """ + Update relevant entries according to the FTRL-proximal scheme. + + Args: + lr (float): The learning rate value, must be positive. + l1 (float): l1 regularization strength, must be greater than or equal to zero. + l2 (float): l2 regularization strength, must be greater than or equal to zero. + lr_power (float): Learning rate power controls how the learning rate decreases during training, + must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero. + use_locking (bool): Use locks for update operation if True . Default: False. + + Inputs: + - **var** (Tensor): The variable to be updated. + - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`. + - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`. 
+        - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
+        - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`.
+          The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.
+
+    Outputs:
+        - **var** (Tensor): Tensor, has the same shape and type as `var`.
+        - **accum** (Tensor): Tensor, has the same shape and type as `accum`.
+        - **linear** (Tensor): Tensor, has the same shape and type as `linear`.
+
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class SparseApplyFtrlNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(SparseApplyFtrlNet, self).__init__()
+        >>>         self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
+        >>>         self.var = Parameter(Tensor(np.random.random((3, 3)).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.random((3, 3)).astype(np.float32)), name="accum")
+        >>>         self.linear = Parameter(Tensor(np.random.random((3, 3)).astype(np.float32)), name="linear")
+        >>>
+        >>>     def construct(self, grad, indices):
+        >>>         out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
+        >>>         return out
+        >>>
+        >>> net = SparseApplyFtrlNet()
+        >>> grad = Tensor(np.random.random((3, 3)).astype(np.float32))
+        >>> indices = Tensor(np.ones([3]), mindspore.int32)
+        >>> output = net(grad, indices)
+    """
+
+    @prim_attr_register
+    def __init__(self, lr, l1, l2, lr_power, use_locking=False):
+        validator.check_value_type("lr", lr, [float], self.name)
+        validator.check_value_type("l1", l1, [float], self.name)
+        validator.check_value_type("l2", l2, [float], self.name)
+        validator.check_value_type("lr_power", lr_power, [float], self.name)
+        self.lr = validator.check_number("lr", lr, 0.0, Rel.GT, self.name)
+        self.l1 = validator.check_number("l1", l1, 0.0, Rel.GE, self.name)
+        self.l2 = validator.check_number("l2", l2, 0.0, Rel.GE, self.name)
+        self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
+        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+
+    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
+        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
+        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
+        if len(var_shape) > 1:
+            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
+        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
+        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
+        return var_shape, accum_shape, linear_shape
+
+    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
+        args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
+                "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
+        return var_dtype, accum_dtype, linear_dtype
+
+
 class ConfusionMulGrad(PrimitiveWithInfer):
     """
     `output0` is the result of which input0 dot multily input1.
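The proximal-adagrad math blocks in the ApplyProximalAdagrad/SparseApplyProximalAdagrad docstrings above translate directly into a NumPy reference. The following is a minimal sketch, illustrative only and not part of the patch: the helper name proximal_adagrad_reference is invented here, dense float32 arrays are assumed, and the sparse variant applies the same update restricted to the rows selected by `indices`.

import numpy as np

def proximal_adagrad_reference(var, accum, lr, l1, l2, grad):
    # accum += grad * grad
    accum = accum + grad * grad
    # prox_v = var - lr * grad * 1 / sqrt(accum)
    prox_v = var - lr * grad / np.sqrt(accum)
    # var = sign(prox_v) / (1 + lr * l2) * max(|prox_v| - lr * l1, 0)
    var = np.sign(prox_v) / (1.0 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0.0)
    return var, accum

var = np.random.rand(3, 3).astype(np.float32)
accum = np.random.rand(3, 3).astype(np.float32) + 1e-6   # keep accum positive before sqrt
grad = np.random.rand(3, 3).astype(np.float32)
new_var, new_accum = proximal_adagrad_reference(var, accum, lr=0.01, l1=0.0, l2=0.0, grad=grad)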
@@ -3124,7 +3323,7 @@ class CTCLoss(PrimitiveWithInfer): >>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64) >>> labels_values = Tensor(np.array([2, 2]), mindspore.int32) >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32) - >>> ctc_loss = P.CTCloss() + >>> ctc_loss = P.CTCLoss() >>> output = ctc_loss(inputs, labels_indices, labels_values, sequence_length) """ diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 2a915aafb7..9482d7b1ee 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -225,6 +225,46 @@ class ApplyFtrlNet(nn.Cell): out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2, self.lr_power) return out + +class SparseApplyFtrlNet(nn.Cell): + def __init__(self): + super(SparseApplyFtrlNet, self).__init__() + self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5) + self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var") + self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum") + self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear") + + def construct(self, grad, indices): + out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices) + return out + + +class SparseApplyProximalAdagradNet(nn.Cell): + def __init__(self): + super(SparseApplyProximalAdagradNet, self).__init__() + self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + self.lr = 0.01 + self.l1 = 0.0 + self.l2 = 0.0 + + def construct(self, var, accum, grad, indices): + out = self.sparse_apply_proximal_adagrad(var, accum, self.lr, self.l1, self.l2, grad, indices) + return out + + +class ApplyProximalAdagradNet(nn.Cell): + def __init__(self): + super(ApplyProximalAdagradNet, self).__init__() + self.apply_proximal_adagrad = P.ApplyProximalAdagrad() + self.lr = 0.01 + self.l1 = 0.0 + self.l2 = 0.0 + + def construct(self, var, accum, grad): + out = self.apply_proximal_adagrad(var, accum, self.lr, self.l1, self.l2, grad) + return out + + class ApplyRMSNet(nn.Cell): def __init__(self): super(ApplyRMSNet, self).__init__() @@ -982,6 +1022,18 @@ test_case_nn_ops = [ 'block': P.SparseApplyAdagrad(0.5), 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], 'skip': ['backward']}), + ('SparseApplyFtrl', { + 'block': SparseApplyFtrlNet(), + 'desc_inputs': [[3, 3], Tensor(np.ones((3,), np.int32))], + 'skip': ['backward']}), + ('ApplyProximalAdagrad', { + 'block': ApplyProximalAdagradNet(), + 'desc_inputs': [[3, 3], [3, 3], [3, 3]], + 'skip': ['backward']}), + ('SparseApplyProximalAdagrad', { + 'block': SparseApplyProximalAdagradNet(), + 'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))], + 'skip': ['backward']}), ('Flatten_1', { 'block': NetForFlatten(), 'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],
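For a quick manual check of the new primitives outside the unit-test harness above, a minimal standalone run could look like the sketch below. It mirrors SparseApplyFtrlNet from the docstring and tests; the class name SparseFtrlDemo is invented here, and a backend with these TBE kernels available (Ascend) is assumed.

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P

class SparseFtrlDemo(nn.Cell):
    def __init__(self):
        super(SparseFtrlDemo, self).__init__()
        self.sparse_apply_ftrl = P.SparseApplyFtrl(lr=0.001, l1=0.0, l2=0.0, lr_power=-0.5)
        self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")

    def construct(self, grad, indices):
        # Returns the updated (var, accum, linear) tensors.
        return self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)

grad = Tensor(np.random.rand(3, 3).astype(np.float32))
indices = Tensor(np.ones((3,), np.int32))   # indices must be int32
net = SparseFtrlDemo()
var_out, accum_out, linear_out = net(grad, indices)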