tensor assign with ellipsis

Signed-off-by: candanzg <zhangshucheng@huawei.com>
pull/721/head
candanzg 5 years ago
parent decc8404a9
commit e886a3182c

@@ -22,11 +22,11 @@ from .parser import (Parser, create_obj_instance, generate_scope,
                      get_dataclass_attributes, get_dataclass_methods,
                      get_module_namespace, get_obj_type, get_object_key,
                      get_parse_method_of_class, get_scope_name,
-                     is_class_member, parse_cb, resolve_symbol)
+                     is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
 from .serialize import *

 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
            'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
            'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
            'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
-           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
+           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']

@@ -29,7 +29,7 @@ from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _MindSporeFunction
 from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
 from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
-from ..utils import Slice
+from ..utils import Slice, Ellipsis_

 # define return value
 RET_SUCCESS = 0
@@ -70,6 +70,11 @@ parse_expr_statement_white_list = (
     "append",
 )

+def create_ellipsis_obj():
+    """Create an Ellipsis object."""
+    return Ellipsis_()
+
+
 def create_slice_obj(start, end, step):
     """Create Slice object"""
     return Slice(start, end, step)

@@ -110,3 +110,10 @@ class Slice:
     start: int
     end: int
     step: int
+
+
+@dataclass
+class Ellipsis_:
+    """
+    Ellipsis class.
+    """

@@ -80,6 +80,7 @@ const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
 const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
 const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
+const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";

 // define the common name
 const char NAMED_PRIMITIVE_ITER[] = "iter";
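
For orientation: Ellipsis_ is a plain marker dataclass that stands in for Python's built-in ... on the parser side, just as Slice already stands in for slice, and create_ellipsis_obj is the factory the C++ front end looks up through the PYTHON_PARSE_CLASS_ELLIPSIS name registered above. A minimal, self-contained sketch of the pattern, using local definitions rather than the MindSpore modules:

    from dataclasses import dataclass


    @dataclass
    class Ellipsis_:
        """Marker standing in for the built-in `...` during parsing."""


    def create_ellipsis_obj():
        """Factory that a front end can look up and call by name."""
        return Ellipsis_()


    # The marker only matters for isinstance() dispatch when walking an index tuple.
    index = (create_ellipsis_obj(), 3)
    print([isinstance(item, Ellipsis_) for item in index])  # prints [True, False]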

@@ -298,6 +298,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   } else if (abs_base->isa<AbstractRef>()) {
     auto value = abs_base->cast<AbstractRefPtr>()->ref();
     dic = ConvertAbstractToPython(value);
+  } else if (abs_base->isa<AbstractEllipsis>()) {
+    auto arg_slice = dyn_cast<AbstractEllipsis>(abs_base);
+    std::vector<int> shape;
+    dic["shape"] = shape;
+    dic["dtype"] = arg_slice->BuildType();
+    dic["value"] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa<AbstractTuple>()) {
     auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
     size_t len = arg_tuple->size();

@@ -98,6 +98,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
       i++;
     }
     ret = rets;
+  } else if (value->isa<EllipsisObj>()) {
+    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
   } else if (value->isa<ValueSlice>()) {
     auto slice = value->cast<ValueSlicePtr>();
     auto start = ValuePtrToPyData(slice->start());

@@ -20,7 +20,7 @@ import numpy as np
 from ...primitive import constexpr
 from ....common.tensor import Tensor
 from ....common import dtype as mstype
-from ...._extends.utils import Slice
+from ...._extends.utils import Slice, Ellipsis_


 @constexpr
 def check_equal(param1, param2, msg="{},{}"):
@@ -29,31 +29,40 @@ def check_equal(param1, param2, msg="{},{}"):
         raise ValueError(msg.format(param1, param2))
     return param1


+@constexpr
+def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
+    """Checks the shape and size of the tensor and value."""
+    if data_shape == value_shape or data_size == value_size or value_size == 1:
+        return True
+    raise ValueError("The value (shape={}) cannot be assigned to the tensor (shape={}).".format(value_shape, data_shape))
+
+
 @constexpr
 def check_tensor_setitem_index(index, element_type=None):
     """Checks tuple index type of tensor assignment."""
     if index is None:
-        raise ValueError("Tensor's index cannot be None.")
+        raise IndexError("Tensor's index cannot be None.")
     # eg. Tensor[Slice] = u
     if isinstance(index, Slice):
         return True
     # eg. Tensor[tuple] = u
     if isinstance(index, tuple):
         if not index:
-            raise ValueError("Tensor's index cannot be empty.")
+            raise IndexError("Tensor's index cannot be empty.")
         # eg. Tensor[tuple(Slice...)] = u
-        if isinstance(index[0], (Slice, int)):
+        if isinstance(index[0], (Slice, Ellipsis_, int)):
             return True
-        raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
+        raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
     # eg. Tensor[Tensor[dtype=bool]] = u
     if index == mstype.tensor:
         if element_type is None or element_type != mstype.bool_:
-            raise ValueError(
-                "The index of tensor should be a bool type tensor. \
-                {} type is not supported yet.".format(element_type))
+            raise TypeError(
+                "The index of tensor should be a bool type tensor. "
+                "{} type is not supported yet.".format(element_type))
         return True
-    raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
+    raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))


 @constexpr
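
To illustrate the rule enforced by check_ellipsis_shape_size above: assignment through an ellipsis is accepted when the value's shape matches the target, when the element counts match, or when the value is a single element. A standalone restatement for illustration only (no MindSpore imports; the real function runs under @constexpr inside the graph compiler):

    def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
        """Standalone copy of the check above, for illustration."""
        if data_shape == value_shape or data_size == value_size or value_size == 1:
            return True
        raise ValueError("The value (shape={}) cannot be assigned to the tensor (shape={}).".format(
            value_shape, data_shape))

    print(check_ellipsis_shape_size((2, 3), (2, 3), 6, 6))  # True: shapes match exactly
    print(check_ellipsis_shape_size((2, 3), (6,), 6, 6))    # True: same element count, value is reshaped
    print(check_ellipsis_shape_size((2, 3), (1,), 6, 1))    # True: single element, broadcast by fill/mul
    # check_ellipsis_shape_size((2, 3), (4,), 6, 4)         # raises ValueError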
@@ -90,10 +99,18 @@ def slice_expand(input_slices, shape):
     # Slice or tuple(Slice...)
     if isinstance(input_slices, Slice):
         slices = (input_slices,)
-    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
-        slices = input_slices
+    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
+        is_have_ellipsis = False
+        for _, element in enumerate(input_slices):
+            if isinstance(element, Ellipsis_):
+                is_have_ellipsis = True
+                break
+        if is_have_ellipsis:
+            slices = ellipsis2slice(input_slices, shape)
+        else:
+            slices = input_slices
     else:
-        raise ValueError("Tensor's index type is not supported yet.")
+        raise IndexError("Tensor's index type is not supported yet.")

     for s in slices:
         start = 0 if (s.start is None) else s.start
@@ -111,6 +128,26 @@ def slice_expand(input_slices, shape):
     return begin, end, strides


+def ellipsis2slice(input_, shape):
+    """Converts ellipsis to slice."""
+    input_slice = input_
+    result = []
+    if isinstance(input_, Ellipsis_):
+        input_slice = (input_,)
+    ell_count = 0
+    for _, element in enumerate(input_slice):
+        if not isinstance(element, Ellipsis_):
+            result.append(element)
+            continue
+        ell_count += 1
+        if ell_count > 1:
+            raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
+                             "but it is currently {}".format(input_slice))
+        for _ in range(len(shape) - len(input_slice) + 1):
+            result.append(Slice(None, None, None))
+    return tuple(result)
+
+
 @constexpr
 def slice2indices(input_slices, shape):
     """
@@ -139,7 +176,7 @@ def slice2indices(input_slices, shape):
 def check_indices(indices_size, index):
     """Checks indices whether is empty."""
     if indices_size < 1:
-        raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
+        raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
     return indices_size
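
To make the expansion rule of ellipsis2slice (added above) concrete, here is a self-contained sketch. expand_ellipsis and the local Slice and Ellipsis_ classes are stand-ins introduced only for this example (the real dataclasses live in mindspore._extends.utils), and the duplicate-ellipsis error handling is omitted: one ellipsis expands into len(shape) - len(index) + 1 full slices.

    from dataclasses import dataclass


    @dataclass
    class Slice:
        start: int
        end: int
        step: int


    @dataclass
    class Ellipsis_:
        """Local stand-in for the parser's ellipsis marker."""


    def expand_ellipsis(index, shape):
        """Same expansion rule as ellipsis2slice, minus the duplicate-ellipsis check."""
        index = (index,) if isinstance(index, Ellipsis_) else index
        result = []
        for element in index:
            if isinstance(element, Ellipsis_):
                result.extend(Slice(None, None, None) for _ in range(len(shape) - len(index) + 1))
            else:
                result.append(element)
        return tuple(result)


    # Against a rank-3 shape, (..., Slice(0, 1, 1)) becomes two full slices plus the explicit one.
    print(expand_ellipsis((Ellipsis_(), Slice(0, 1, 1)), (4, 5, 6)))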
@@ -151,8 +188,8 @@ def check_indices_value_size(indices_size, value_size):
     if value_size > 1:
         if value_size != indices_size:
             raise ValueError(
-                "The value given to tensor does not match the index size. \
-                value size:{}, indics size:{}".format(value_size, indices_size))
+                "The value given to tensor does not match the index size,"
+                " value size:{}, indices size:{}".format(value_size, indices_size))
     return value_size


 @constexpr
@@ -168,8 +205,11 @@ def integer_to_indices(index, shape):
 def tuple_element_is_slice(indexs):
     """Judges tuple element type."""
     if not indexs:
-        raise ValueError("Tensor's index cannot be empty.")
-    if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
+        raise IndexError("Tensor's index cannot be empty.")
+    if isinstance(indexs, tuple):
+        for _, ele in enumerate(indexs):
+            if not isinstance(ele, Slice):
+                return False
         return True
     return False
@@ -177,7 +217,10 @@ def tuple_element_is_slice(indexs):
 def tuple_element_is_int(indexs):
     """Judges tuple element type."""
     if not indexs:
-        raise ValueError("Tensor's index cannot be empty.")
-    if isinstance(indexs, tuple) and isinstance(indexs[0], int):
+        raise IndexError("Tensor's index cannot be empty.")
+    if isinstance(indexs, tuple):
+        for _, ele in enumerate(indexs):
+            if not isinstance(ele, int):
+                return False
         return True
     return False
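
The two helpers above now require every element of the index tuple to be a Slice (or an int, respectively), whereas the old code only inspected indexs[0]. A standalone sketch of the new behavior, with the loop folded into all() for brevity:

    def tuple_element_is_int(indexs):
        """True only for a non-empty tuple made up entirely of ints."""
        if not indexs:
            raise IndexError("Tensor's index cannot be empty.")
        return isinstance(indexs, tuple) and all(isinstance(ele, int) for ele in indexs)


    print(tuple_element_is_int((1, 2, 3)))         # True
    print(tuple_element_is_int((1, slice(0, 2))))  # False: the old check only looked at indexs[0]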

@@ -254,10 +254,10 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
     data_dtype = F.dtype(data)
     indices_size = F.size(indices)
     indices_size = mult_util.check_indices(indices_size, index)
-    update = F.fill(data_dtype, (indices_size,), 1)
+    update = F.fill(mstype.int32, (indices_size,), 1)
     condition_1d = F.scatter_nd(indices, update, (data_size,))
+    condition_1d = F.cast(condition_1d, mstype.bool_)
     condition = F.reshape(condition_1d, data_shape)
-    condition = F.cast(condition, mstype.bool_)
     value_fill = None
     value_size = F.size(value)
@@ -336,10 +336,10 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
     data_dtype = F.dtype(data)
     indices_size = F.size(indices)
     indices_size = mult_util.check_indices(indices_size, index)
-    update = F.fill(data_dtype, (indices_size,), 1)
+    update = F.fill(mstype.int32, (indices_size,), 1)
     condition_1d = F.scatter_nd(indices, update, (data_size,))
+    condition_1d = F.cast(condition_1d, mstype.bool_)
     condition = F.reshape(condition_1d, data_shape)
-    condition = F.cast(condition, mstype.bool_)
     value_fill = F.fill(data_dtype, (indices_size,), value)
     value_1d = F.scatter_nd(indices, value_fill, (data_size,))
     u = F.reshape(value_1d, data_shape)
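
The two hunks above make the same change to the mask construction: the scatter update is filled as int32 and the cast to bool now happens on the flat 1-D tensor before the reshape. A NumPy sketch of the pattern, offered only as an illustration (F.scatter_nd is emulated here with flat-index assignment):

    import numpy as np

    data = np.arange(6, dtype=np.float32).reshape(2, 3)
    indices = np.array([[0], [4]])                    # flat positions to overwrite
    update = np.ones((indices.shape[0],), np.int32)   # F.fill(mstype.int32, (indices_size,), 1)
    condition_1d = np.zeros((data.size,), np.int32)
    condition_1d[indices[:, 0]] = update              # stands in for F.scatter_nd(indices, update, (data_size,))
    condition_1d = condition_1d.astype(bool)          # cast to bool on the flat tensor
    condition = condition_1d.reshape(data.shape)      # then reshape to the data shape
    print(condition)                                  # boolean mask marking the positions to assign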
@@ -360,3 +360,32 @@ def _tensor_setitem_with_int_v2(data, index, value):
     data_shape = F.shape(data)
     indices = mult_util.integer_to_indices(index, data_shape)
     return _tensor_indices_tensor(data, data_shape, index, indices, value)
+
+
+@setitem.register("Tensor", "Ellipsis", "Number")
+def _tensor_setitem_with_ellipsis_v1(data, index, value):
+    """Syntax: A[...] = number."""
+    data_shape = F.shape(data)
+    data_dtype = F.dtype(data)
+    return F.fill(data_dtype, data_shape, value)
+
+
+@setitem.register("Tensor", "Ellipsis", "Tensor")
+def _tensor_setitem_with_ellipsis_v2(data, index, value):
+    """Syntax: A[...] = Tensor."""
+    result = None
+    data_shape = F.shape(data)
+    data_dtype = F.dtype(data)
+    data_size = F.size(data)
+    value_shape = F.shape(value)
+    value_size = F.size(value)
+    check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
+    if check_result:
+        if data_size == value_size:
+            result = F.reshape(value, data_shape)
+            result = F.cast(result, data_dtype)
+        elif value_size == 1:
+            param1 = F.fill(data_dtype, data_shape, 1)
+            param2 = F.cast(value, data_dtype)
+            result = F.tensor_mul(param1, param2)
+    return result
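
With the two registrations above, graph-mode code can assign through an ellipsis index. A usage sketch, hedged: the import paths and the Tensor and context signatures follow the MindSpore release this commit targets and may differ in later versions.

    import numpy as np
    from mindspore import Tensor, context, nn
    from mindspore.common import dtype as mstype

    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def construct(self, x, value):
            x[...] = 2.0     # A[...] = number, handled by _tensor_setitem_with_ellipsis_v1
            x[...] = value   # A[...] = Tensor, handled by _tensor_setitem_with_ellipsis_v2
            return x

    net = Net()
    x = Tensor(np.zeros((2, 3)), mstype.float32)
    print(net(x, Tensor(np.ones((2, 3)), mstype.float32)))  # every element becomes 1.0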

File diff suppressed because it is too large