modify Tensor shape

pull/11935/head
lilei 4 years ago
parent b220b2185f
commit 63eb3ed2d9

@ -406,6 +406,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const TypePtr &type_ptr, const py::list &shape) {
auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
}),
py::arg("dtype"), py::arg("shape"))
.def(py::init([](const py::array &input, const TypePtr &type_ptr) {
return TensorPy::MakeTensor(input, type_ptr);
}),

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""Tensor implementation."""
import numbers
import numpy as np
from mindspore import log as logger
@ -43,6 +44,9 @@ class Tensor(Tensor_):
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None.
init (class:'Initializer'): the information of init data.
'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to
use 'init' interface to initialize parameters in other conditions. If 'init' interface is used
to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data.
Outputs:
Tensor, with the same shape as `input_data`.
@ -76,6 +80,9 @@ class Tensor(Tensor_):
if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
raise TypeError("input_data and init can not be None at the same time.")
if isinstance(shape, numbers.Number):
shape = (shape,)
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None:
validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool),
@ -575,6 +582,7 @@ class Tensor(Tensor_):
def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
"""
Get the tensor format data of this Tensor.
The init_data function can be called once for the same tensor.
Args:
slice_index (int): Slice index of a parameter's slices.

Loading…
Cancel
Save