diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index c007097ca6..5bacb5741b 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -232,6 +232,11 @@ class Tensor(Tensor_): raise TypeError("virtual_flag must be bool.") self._virtual_flag = value + @staticmethod + def from_numpy(array): + """Convert numpy array to Tensor without copy data.""" + return Tensor(Tensor_.from_numpy(array)) + def asnumpy(self): """Convert tensor to numpy array.""" return Tensor_.asnumpy(self) diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 09d2f2eaa8..9ed92b418d 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -480,6 +480,7 @@ def test_tensor_operation(): def test_tensor_from_numpy(): a = np.ones((2, 3)) t = ms.Tensor.from_numpy(a) + assert isinstance(t, ms.Tensor) assert np.all(t.asnumpy() == 1) # 't' and 'a' share same data. a[1] = 2 @@ -489,3 +490,6 @@ def test_tensor_from_numpy(): del a assert np.all(t.asnumpy()[0] == 1) assert np.all(t.asnumpy()[1] == 2) + with pytest.raises(TypeError): + # incorrect input. + t = ms.Tensor.from_numpy([1, 2, 3])