diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index 538bce1dbc..c8873039ab 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -45,8 +45,9 @@ class Embedding(Cell): Inputs: - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of - the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero - if larger than vocab_size. + the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero + if larger than vocab_size. + Outputs: Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 140d08a912..ac2cbe3e52 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -17,6 +17,7 @@ from .. import operations as P from ..operations import _grad_ops as G +from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..functional import broadcast_gradient_args from .. import functional as F @@ -341,7 +342,7 @@ def get_bprop_sparse_gather_v2(self): return bprop -@bprop_getters.register(P.Range) +@bprop_getters.register(inner.Range) def get_bprop_range(self): """Generate bprop for Range""" diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index b5ca29f126..00c7db7cec 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -23,7 +23,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Diag, DiagPart, DType, ExpandDims, Eye, Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, - Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, + Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, EmbeddingLookup, @@ -75,7 +75,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, - CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix) + CheckValid, MakeRefKey, Partial, Depend, CheckBprop) from . import _quant_ops from ._quant_ops import * from .thor_ops import * @@ -303,13 +303,12 @@ __all__ = [ "Atan", "Atanh", "BasicLSTMCell", - "ConfusionMatrix", "BroadcastTo", - "Range", "DataFormatDimMap", "ApproximateEqual", "InplaceUpdate", "InTopK", + "DataFormatDimMap" ] __all__.extend(_quant_ops.__all__) diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 38f399316a..f2b589558e 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -15,9 +15,10 @@ """Inner operators.""" +from ..._checkparam import Rel from ..._checkparam import Validator as validator from ...common import dtype as mstype -from ..primitive import PrimitiveWithInfer, prim_attr_register +from ..primitive import PrimitiveWithInfer, prim_attr_register class ExtractImagePatches(PrimitiveWithInfer): @@ -96,3 +97,61 @@ class ExtractImagePatches(PrimitiveWithInfer): """infer dtype""" validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) return input_x + + +class Range(PrimitiveWithInfer): + r""" + Creates a sequence of numbers. + Set `input_x` as :math:`x_i` for each element, `output` as follows: + + .. math:: + \text{output}(x_i) = x_i * \text{delta} + \text{start} + + Args: + start (float): If `limit` is `None`, the value acts as limit in the range and first entry + defaults to `0`. Otherwise, it acts as first entry in the range. + limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` + while set the first entry of the range to `0`. It can not be equal to `start`. + delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. + + Inputs: + - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. + + Outputs: + Tensor, has the same shape and dtype as `input_x`. + + Examples: + >>> range = P.Range(1.0, 8.0, 2.0) + >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32) + >>> range(x) + [3, 5, 7, 5] + """ + + @prim_attr_register + def __init__(self, start, limit=None, delta=1.0): + self.init_prim_io_names(inputs=['x'], outputs=['y']) + self.delta = validator.check_value_type("delta", delta, [float], self.name) + validator.check_value_type("start", start, [float], self.name) + if limit is None: + self.start = 0.0 + self.limit = start + self.add_prim_attr("start", self.start) + self.add_prim_attr("limit", self.limit) + else: + validator.check_value_type("limit", limit, [float], self.name) + validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name) + if self.delta == 0.0: + raise ValueError("The input of `delta` can not be equal to zero.") + if self.delta > 0.0 and self.start > self.limit: + raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, " + f"but got start:{self.start}, limit:{self.limit}") + if self.delta < 0.0 and self.start < self.limit: + raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, " + f"but got start:{self.start}, limit:{self.limit}") + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) + return x_dtype diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 8eb7011154..86074954c7 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -556,64 +556,6 @@ class SparseGatherV2(GatherV2): """ -class Range(PrimitiveWithInfer): - r""" - Creates a sequence of numbers. - Set `input_x` as :math:`x_i` for each element, `output` as follows: - - .. math:: - \text{output}(x_i) = x_i * \text{delta} + \text{start} - - Args: - start (float): If `limit` is `None`, the value acts as limit in the range and first entry - defaults to `0`. Otherwise, it acts as first entry in the range. - limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start` - while set the first entry of the range to `0`. It can not be equal to `start`. - delta (float): Increment of the range. It can not be equal to zero. Default: 1.0. - - Inputs: - - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32. - - Outputs: - Tensor, has the same shape and dtype as `input_x`. - - Examples: - >>> range = P.Range(1.0, 8.0, 2.0) - >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32) - >>> range(x) - [3, 5, 7, 5] - """ - - @prim_attr_register - def __init__(self, start, limit=None, delta=1.0): - self.init_prim_io_names(inputs=['x'], outputs=['y']) - self.delta = validator.check_value_type("delta", delta, [float], self.name) - validator.check_value_type("start", start, [float], self.name) - if limit is None: - self.start = 0.0 - self.limit = start - self.add_prim_attr("start", self.start) - self.add_prim_attr("limit", self.limit) - else: - validator.check_value_type("limit", limit, [float], self.name) - validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name) - if self.delta == 0.0: - raise ValueError("The input of `delta` can not be equal to zero.") - if self.delta > 0.0 and self.start > self.limit: - raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, " - f"but got start:{self.start}, limit:{self.limit}") - if self.delta < 0.0 and self.start < self.limit: - raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, " - f"but got start:{self.start}, limit:{self.limit}") - - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) - return x_dtype - - class EmbeddingLookup(PrimitiveWithInfer): """ Returns a slice of input tensor based on the specified indices. diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index 8c266bf8c6..ae71579973 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -25,6 +25,7 @@ from mindspore import Tensor from mindspore.common import dtype as mstype from mindspore.nn import Cell from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner from mindspore.ops import prim_attr_register from mindspore.ops.primitive import PrimitiveWithInfer import mindspore.context as context @@ -286,19 +287,10 @@ class SpaceToBatchNDNet(Cell): return self.space_to_batch_nd(x) -class ConfusionMatrixNet(Cell): - def __init__(self): - super(ConfusionMatrixNet, self).__init__() - self.confusion_matrix = P.ConfusionMatrix(4, "int32") - - def construct(self, x, y): - return self.confusion_matrix(x, y) - - class RangeNet(Cell): def __init__(self): super(RangeNet, self).__init__() - self.range_ops = P.Range(1.0, 8.0, 2.0) + self.range_ops = inner.Range(1.0, 8.0, 2.0) def construct(self, x): return self.range_ops(x) @@ -344,9 +336,6 @@ test_case_array_ops = [ ('BatchToSpaceNDNet', { 'block': BatchToSpaceNDNet(), 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), - ('ConfusionMatrixNet', { - 'block': ConfusionMatrixNet(), - 'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}), ('RangeNet', { 'block': RangeNet(), 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}), diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index b662537b7e..500bafe9ff 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -25,6 +25,7 @@ from mindspore.common import dtype as mstype from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G +from mindspore.ops.operations import _inner_ops as inner from ..ut_filter import non_graph_engine from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.pipeline.forward.compile_forward \ @@ -1051,7 +1052,7 @@ test_case_nn_ops = [ 'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))], 'desc_bprop': [[2, 1, 2]]}), ('Range', { - 'block': P.Range(1.0, 5.0), + 'block': inner.Range(1.0, 5.0), 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], 'desc_bprop': [[10]]}), ('UnsortedSegmentSum', { @@ -1454,7 +1455,7 @@ test_case_array_ops = [ 'desc_inputs': [(Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)))], - 'desc_bprop': [[3,]]}), + 'desc_bprop': [[3, ]]}), ('Pack_0', { 'block': NetForPackInput(P.Pack()), 'desc_inputs': [[2, 2], [2, 2], [2, 2]], @@ -1527,7 +1528,7 @@ test_case_array_ops = [ Tensor(np.array([0, 1, 1]).astype(np.int32))], 'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}), ('BroadcastTo', { - 'block': P.BroadcastTo((2,3)), + 'block': P.BroadcastTo((2, 3)), 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], 'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}), ('InTopK', {