From 8c5db0f9198a80a1d826ba9fa3f4b869cc664b5b Mon Sep 17 00:00:00 2001 From: buxue Date: Thu, 6 Aug 2020 14:49:42 +0800 Subject: [PATCH] add attr 'shape' and 'dtype' and interface 'asnumpy' for Tensor --- mindspore/ccsrc/utils/tensor_py.cc | 4 +-- mindspore/common/tensor.py | 42 +++++++++++++++++++++--------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/utils/tensor_py.cc b/mindspore/ccsrc/utils/tensor_py.cc index fe3b7deb6d..606b95527d 100644 --- a/mindspore/ccsrc/utils/tensor_py.cc +++ b/mindspore/ccsrc/utils/tensor_py.cc @@ -268,7 +268,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { }), py::arg("input"), py::arg("dtype") = nullptr) .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag) - .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( + .def_property_readonly("_dtype", &Tensor::Dtype, R"mydelimiter( Get the tensor's data type. Returns: @@ -279,7 +279,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { >>> data.dtype Int32 )mydelimiter") - .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter( + .def_property_readonly("_shape", TensorPy::GetPyTupleShape, R"mydelimiter( Get the tensor's shape. Returns: diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 106d0eccad..86e175a03b 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -208,13 +208,41 @@ class Tensor(Tensor_): return "Unknown Tensor type!" return str(self.asnumpy()) + @property + def shape(self): + """The shape of tensor.""" + return self._shape + + @property + def dtype(self): + """The dtype of tensor.""" + return self._dtype + + @property + def virtual_flag(self): + """Mark tensor is virtual.""" + return self._virtual_flag + + @virtual_flag.setter + def virtual_flag(self, value): + """The setter of virtual_flag.""" + if not isinstance(value, bool): + raise TypeError("virtual_flag must be bool.") + self._virtual_flag = value + + def asnumpy(self): + """Convert tensor to numpy array.""" + return Tensor_.asnumpy(self) + def all(self, axis=(), keep_dims=False): """ Check all array elements along a given axis evaluate to True. Args: axis (Union[None, int, tuple(int)): Dimensions of reduction. + Default: (), reduce all dimensions. keep_dims (bool): Whether to keep the reduced dimensions. + Default : False, don't keep these reduced dimensions. Returns: Tensor, has the same data type as x. @@ -228,7 +256,9 @@ class Tensor(Tensor_): Args: axis (Union[None, int, tuple(int)): Dimensions of reduction. + Default: (), reduce all dimensions. keep_dims (bool): Whether to keep the reduced dimensions. + Default : False, don't keep these reduced dimensions. Returns: Tensor, has the same data type as x. @@ -236,18 +266,6 @@ class Tensor(Tensor_): return tensor_operator_registry.get('any')(keep_dims)(self, axis) - @property - def virtual_flag(self): - """Mark tensor is virtual.""" - return self._virtual_flag - - @virtual_flag.setter - def virtual_flag(self, value): - """The setter of virtual_flag.""" - if not isinstance(value, bool): - raise TypeError("virtual_flag must be bool.") - self._virtual_flag = value - class IndexedSlices: """