diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 049d77c540..dd2e3dc14e 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -446,7 +446,9 @@ Tensor::Tensor(const Tensor &tensor) event_(tensor.event_), sync_status_(tensor.sync_status_), device_sync_(tensor.device_sync_), - padding_type_(tensor.padding_type()) {} + padding_type_(tensor.padding_type()) { + CheckShape(tensor.shape_); +} Tensor::Tensor(const Tensor &tensor, TypeId data_type) : MetaTensor(data_type, tensor.shape_), @@ -456,29 +458,43 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) event_(tensor.event_), sync_status_(tensor.sync_status_), device_sync_(tensor.device_sync_), - padding_type_(tensor.padding_type()) {} + padding_type_(tensor.padding_type()) { + CheckShape(tensor.shape_); +} Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data) - : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} + : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) { + CheckShape(shape); +} Tensor::Tensor(TypeId data_type, const ShapeVector &shape) - : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {} + : Tensor(data_type, shape, MakeTensorData(data_type, shape)) { + CheckShape(shape); +} Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len) - : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {} + : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) { + CheckShape(shape); +} Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type) - : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {} + : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) { + CheckShape(shape); +} Tensor::Tensor(const std::vector &input, const TypePtr &data_type) : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast(input.size())}), data_(MakeTensorData(data_type_, shape_, input.data(), input.size())), - id_(MakeId()) {} + id_(MakeId()) { + CheckShape(shape_); +} Tensor::Tensor(const std::vector &input, const TypePtr &data_type) : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast(input.size())}), data_(MakeTensorData(data_type_, shape_, input.data(), input.size())), - id_(MakeId()) {} + id_(MakeId()) { + CheckShape(shape_); +} Tensor::Tensor(int64_t input, const TypePtr &data_type) : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}), @@ -497,6 +513,7 @@ bool Tensor::operator==(const Tensor &tensor) const { bool Tensor::ValueEqual(const Tensor &tensor) const { return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); } + // assgin value to this tensor Tensor &Tensor::AssignValue(const Tensor &tensor) { if (this != &tensor) { @@ -573,6 +590,17 @@ std::string Tensor::ToStringRepr() const { return buf.str(); } +void Tensor::CheckShape(const ShapeVector &shape) const { + // Check tensor's shape, ignore one-dimensional tensor, including empty tensor with shape=(0,). + if (shape.size() > 1) { + for (const auto &s : shape) { + if (s == 0) { + MS_EXCEPTION(ValueError) << "Zero is not supported in the shape of Tensor !"; + } + } + } +} + void Tensor::data_sync(bool need_wait) const { if (need_wait) { Wait(); diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 5d45b280af..97b9924caa 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -261,6 +261,8 @@ class Tensor : public MetaTensor { std::string ToStringRepr() const; + void CheckShape(const ShapeVector &shape) const; + bool is_init() const { return init_flag_; } void set_init_flag(bool flag) { init_flag_ = flag; } diff --git a/tests/st/ops/gpu/test_tensoradd.py b/tests/st/ops/gpu/test_tensoradd.py index 24958f10b2..b77f66d169 100644 --- a/tests/st/ops/gpu/test_tensoradd.py +++ b/tests/st/ops/gpu/test_tensoradd.py @@ -33,11 +33,6 @@ class TensroAdd(nn.Cell): self.add = P.TensorAdd() - self.x = Parameter(initializer( - Tensor(np.random.randn(2, 0).astype(np.float32)), [2, 0]), name='x') - self.y = Parameter(initializer( - Tensor(np.random.randn(2, 1).astype(np.float32)), [2, 1]), name='y') - self.x1 = Parameter(initializer( Tensor(np.arange(3).reshape(3).astype(np.float32)), [3]), name='x1') self.y1 = Parameter(initializer( @@ -55,20 +50,17 @@ class TensroAdd(nn.Cell): @ms_function def construct(self): - return ( - self.add(self.x, self.y), self.add(self.x1, self.y1), self.add(self.x2, self.y2), - self.add(self.x3, self.y3)) + return (self.add(self.x1, self.y1), self.add(self.x2, self.y2), self.add(self.x3, self.y3)) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_TensroAdd(): +def test_TensorAdd(): add = TensroAdd() output = add() - expect0 = np.array([]) - expect1 = np.array([2, 3, 4]) - expect2 = np.array( + expect0 = np.array([2, 3, 4]) + expect1 = np.array( [[[[0., 2., 4.], [6., 8., 10.], [12., 14., 16.]], @@ -96,7 +88,7 @@ def test_TensroAdd(): [[144., 146., 148.], [150., 152., 154.], [156., 158., 160.]]]]) - expect3 = np.array( + expect2 = np.array( [[[[0., 2., 4.], [6., 8., 10.], [12., 14., 16.]], @@ -128,4 +120,26 @@ def test_TensroAdd(): assert (output[0].asnumpy() == expect0).all() assert (output[1].asnumpy() == expect1).all() assert (output[2].asnumpy() == expect2).all() - assert (output[3].asnumpy() == expect3).all() + + +class TensorAdd2(nn.Cell): + def __init__(self): + super(TensorAdd2, self).__init__() + self.add = P.TensorAdd() + self.x = Parameter(initializer( + Tensor(np.random.randn(2, 0).astype(np.float32)), [2, 0]), name='x') + self.y = Parameter(initializer( + Tensor(np.random.randn(2, 1).astype(np.float32)), [2, 1]), name='y') + + @ms_function + def construct(self): + return self.add(self.x, self.y) + + +# Constructing a tensor with 0 in shape is not support, excluding empty tensor. +@pytest.mark.skip(reason='0 in shape is not support') +def test_TensorAdd_shape_has_zero(): + add = TensorAdd2() + output = add() + expect = np.array([]) + assert (output.asnumpy() == expect).all() diff --git a/tests/st/pynative/test_tensor_index.py b/tests/st/pynative/test_tensor_index.py index d1d496e034..2cc87fb1bd 100644 --- a/tests/st/pynative/test_tensor_index.py +++ b/tests/st/pynative/test_tensor_index.py @@ -786,8 +786,8 @@ def test_tensor_assign_exception(): t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) # Error for A[Slice] = Number - # 1. A[Slice] = Number, Slice error - with pytest.raises(IndexError): + # 1. A[Slice] = Number, 0 in shape + with pytest.raises(ValueError): net_e2(t, 2) # Error for A[Slice] = U, U is a Tensor diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 9ed92b418d..f8fcb79763 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -68,6 +68,18 @@ def test_tensor(): assert t4.dtype == ms.int64 +def test_tensor_empty(): + t = ms.Tensor(np.ones(0), ms.float32) + assert isinstance(t, ms.Tensor) + assert t.shape == (0,) + + +def test_tensor_shape_has_zero(): + with pytest.raises(ValueError): + t = ms.Tensor(np.ones((1, 0)), ms.float32) + print(t) + + def test_tensor_type_float16(): t_float16 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float16)) assert isinstance(t_float16, ms.Tensor) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 068be69000..0279bb19bd 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -14,6 +14,7 @@ # ============================================================================ """ test ops """ import functools +import pytest import numpy as np @@ -770,6 +771,8 @@ class StridedSliceNet(nn.Cell): return out_0, out_1, out_2, out_3 +# Constructing a tensor with 0 in shape is not support, excluding empty tensor. +@pytest.mark.skip(reason='0 in shape is not support') def test_strided_slice_const(): class StridedSLiceConstNet(nn.Cell): """StridedSLiceConstNet net definition""" diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 0856d6c12d..71abbdd2e2 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -464,8 +464,8 @@ def test_tensor_assign(): net(Ta, b, Tck) net2(t, b, tck) # Error for A[Slice] = Number - # 1. A[Slice] = Number, Slice error - with pytest.raises(IndexError): + # 1. A[Slice] = Number, 0 in shape + with pytest.raises(ValueError): net_e2(t, Tensor(2, mstype.int32)) # Error for A[Slice] = U, U is a Tensor