diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc index 52a444db31..0810222e43 100644 --- a/mindspore/ccsrc/frontend/operator/prim_others.cc +++ b/mindspore/ccsrc/frontend/operator/prim_others.cc @@ -387,16 +387,10 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be positive, but got " << dense_shape_vec[i]; } - if (i == 0) { - if (dense_shape_vec[i] < values_shp[i]) { - MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape should be greator or equal to the " << i - << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; - } - } else { - if (dense_shape_vec[i] != values_shp[i]) { - MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i - << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; - } + // The 0th mode might be less or exceed dense_shape[0] due to duplicated selection + if (i != 0 && dense_shape_vec[i] != values_shp[i]) { + MS_EXCEPTION(TypeError) << "The " << i << "th element of dense_shape must be same with the " << i + << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i]; } } auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 3fff045fea..a6db3da76f 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -213,9 +213,73 @@ class Tensor(Tensor_): class IndexedSlices: + """ + A sparse representation of a set of tensor slices at given indices. + + An IndexedSlices is typically used to represent a subset of a larger + tensor dense of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. + The values in indices are the indices in the first dimension of the slices + that have been extracted from the larger tensor. + The dense tensor dense represented by an IndexedSlices slices has + `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`. + IndexedSlices can only be used in `Cell`'s contruct method. + + Args: + indices (Tensor): A 1-D integer Tensor of shape [D0]. + values (Tensor): A Tensor of any dtype of shape [D0, D1, ..., Dn]. + dense_shape: (tuple): A integer tuple containing the shape + of the corresponding dense tensor. + + Returns: + IndexedSlices, composed of `indices`, `values`, `dense_shape`. + + Examples: + >>> # Create a IndexedSlices. + >>> indices = Tensor([1, 2]) + >>> values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) + >>> dense_shape = (3, 2) + >>> indexed_slices = IndexedSlices(indices, values, dense_shape) + >>> + >>> # Get atrr. + >>> indices = indexed_slices.indices() + >>> values = indexed_slices.values() + >>> dense_shape = indexed_slices.dense_shape() + """ def __init__(self, indices, values, dense_shape): raise NotImplementedError + class SparseTensor: + """ + A sparse representation of a set of nonzero elememts from a tensor at given indices. + + SparseTensor can only be used in `Cell`'s contruct method. + For a tensor dense, its SparseTensor(indices, values, dense_shape) has + `dense[indices[i]] = values[i]`. + + Args: + indices (Tensor): A 2-D integer Tensor of shape `[N, ndims]`, + where N and ndims are the number of values and number of dimensions in + the SparseTensor, respectively. + values (Tensor): A 1-D tensor of any type and shape `[N]`, which + supplies the values for each element in indices. + dense_shape: (tuple): A integer tuple of size `ndims`, + which specifies the dense_shape of the sparse tensor. + + Returns: + SparseTensor, composed of `indices`, `values`, `dense_shape`. + + Examples: + >>> # Create a SparseTensor. + >>> indices = Tensor([[0, 1], [1, 2]]) + >>> values = Tensor([1, 2], dtype=ms.float32) + >>> dense_shape = (3, 4) + >>> sparse_tensor = SparseTensor(indices, values, dense_shape) + >>> + >>> # Get atrr. + >>> indices = sparse_tensor.indices() + >>> values = sparse_tensor.values() + >>> dense_shape = sparse_tensor.dense_shape() + """ def __init__(self, indices, values, dense_shape): raise NotImplementedError diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py index 22f0485285..db84ab26eb 100644 --- a/tests/ut/python/parallel/test_embeddinglookup.py +++ b/tests/ut/python/parallel/test_embeddinglookup.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================ import numpy as np -import pytest import mindspore as ms import mindspore.nn as nn @@ -100,7 +99,6 @@ def test_embeddinglookup_reducescatter_true_grad(): _executor.compile(net, x, y) -@pytest.mark.skip(reason="waiting for fix by parallel strategy") def test_embeddinglookup_semi_auto1(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 32] @@ -115,7 +113,6 @@ def test_embeddinglookup_semi_auto1(): _executor.compile(net, x, y) -@pytest.mark.skip(reason="waiting for fix by parallel strategy") def test_embeddinglookup_semi_auto2(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 32] diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py index f9f430d1ee..2d4d0c2bf2 100644 --- a/tests/ut/python/parallel/test_sparse_gather_v2.py +++ b/tests/ut/python/parallel/test_sparse_gather_v2.py @@ -61,7 +61,6 @@ class Net(nn.Cell): return out -@pytest.mark.skip(reason="waiting for fix by parallel strategy") def test_gatherv2_semi_auto0(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") strategy1 = ((1, 8), (1, 1)) @@ -134,7 +133,6 @@ def test_gatherv2_semi_auto5(): _executor.compile(net, x, y) -@pytest.mark.skip(reason="waiting for fix by parallel strategy") def test_gatherv2_semi_auto6(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") strategy2 = ((4, 2, 1), (4, 2, 1)) @@ -169,7 +167,6 @@ def test_gatherv2_semi_auto8(): _executor.compile(net, x, y) -@pytest.mark.skip(reason="waiting for fix by parallel strategy") def test_gatherv2_auto0(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") net = GradWrap(NetWithLoss(Net(0)))