From 63eb3ed2d99bd60daa8dc7fe79723d65a9bc0f55 Mon Sep 17 00:00:00 2001 From: lilei Date: Mon, 1 Feb 2021 15:26:41 +0800 Subject: [PATCH] modify Tensor shape --- mindspore/ccsrc/pybind_api/ir/tensor_py.cc | 5 +++++ mindspore/common/tensor.py | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 758da1d558..204db5479f 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -406,6 +406,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { return std::make_shared(data_type, GetShapeFromTuple(shape)); }), py::arg("dtype"), py::arg("shape")) + .def(py::init([](const TypePtr &type_ptr, const py::list &shape) { + auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64; + return std::make_shared(data_type, GetShapeFromTuple(shape)); + }), + py::arg("dtype"), py::arg("shape")) .def(py::init([](const py::array &input, const TypePtr &type_ptr) { return TensorPy::MakeTensor(input, type_ptr); }), diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index f49b8510cd..5ac9e05626 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Tensor implementation.""" +import numbers import numpy as np from mindspore import log as logger @@ -43,6 +44,9 @@ class Tensor(Tensor_): shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of output. Default: None. init (class:'Initializer'): the information of init data. + 'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to + use 'init' interface to initialize parameters in other conditions. If 'init' interface is used + to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data. Outputs: Tensor, with the same shape as `input_data`. @@ -76,6 +80,9 @@ class Tensor(Tensor_): if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False: raise TypeError("input_data and init can not be None at the same time.") + if isinstance(shape, numbers.Number): + shape = (shape,) + # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. if init is None: validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), @@ -575,6 +582,7 @@ class Tensor(Tensor_): def init_data(self, slice_index=None, shape=None, opt_shard_group=None): """ Get the tensor format data of this Tensor. + The init_data function can be called once for the same tensor. Args: slice_index (int): Slice index of a parameter's slices.