diff --git a/mindspore/_extends/parse/__init__.py b/mindspore/_extends/parse/__init__.py
index 9366b5a2d2..62ba2e5406 100644
--- a/mindspore/_extends/parse/__init__.py
+++ b/mindspore/_extends/parse/__init__.py
@@ -22,11 +22,11 @@
 from .parser import (Parser, create_obj_instance, generate_scope, get_dataclass_attributes,
                      get_dataclass_methods, get_module_namespace, get_obj_type, get_object_key,
                      get_parse_method_of_class, get_scope_name,
-                     is_class_member, parse_cb, resolve_symbol)
+                     is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
 from .serialize import *

 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
            'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_obj_type',
            'create_obj_instance', 'get_module_namespace', 'get_class_member_namespace_symbol',
            'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'dump_obj', 'load_obj',
-           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj']
+           'get_dataclass_methods', 'get_scope_name', 'create_slice_obj', 'create_ellipsis_obj']
diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py
index d8039cd56a..34a3a6c59e 100644
--- a/mindspore/_extends/parse/parser.py
+++ b/mindspore/_extends/parse/parser.py
@@ -29,7 +29,7 @@
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.api import _MindSporeFunction
 from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
 from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
-from ..utils import Slice
+from ..utils import Slice, Ellipsis_

 # define return value
 RET_SUCCESS = 0
@@ -70,6 +70,11 @@ parse_expr_statement_white_list = (
     "append",
 )

+def create_ellipsis_obj():
+    """Create Ellipsis object"""
+    return Ellipsis_()
+
+
 def create_slice_obj(start, end, step):
     """Create Slice object"""
     return Slice(start, end, step)
diff --git a/mindspore/_extends/utils.py b/mindspore/_extends/utils.py
index d0457607b5..fecbf546f5 100644
--- a/mindspore/_extends/utils.py
+++ b/mindspore/_extends/utils.py
@@ -110,3 +110,10 @@ class Slice:
     start: int
     end: int
     step: int
+
+
+@dataclass
+class Ellipsis_:
+    """
+    Ellipsis class
+    """
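# --- Illustration only, not part of the patch -------------------------------------
# A minimal sketch of what the new Python hook hands back to the C++ parser: the
# PYTHON_PARSE_CLASS_ELLIPSIS constant added further down binds to
# create_ellipsis_obj(), which returns an Ellipsis_ marker, mirroring the existing
# create_slice_obj()/Slice pair. Assumes a build with this patch applied.
from mindspore._extends.parse import create_ellipsis_obj
from mindspore._extends.utils import Ellipsis_, Slice

marker = create_ellipsis_obj()     # stands in for the literal `...` in an index
assert isinstance(marker, Ellipsis_)
assert Slice(0, 2, 1).step == 1    # the existing Slice marker, shown for comparison
# -----------------------------------------------------------------------------------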
diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h
index a3ca67b60a..c7ce4e1196 100644
--- a/mindspore/ccsrc/pipeline/parse/parse_base.h
+++ b/mindspore/ccsrc/pipeline/parse/parse_base.h
@@ -80,6 +80,7 @@
 const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
 const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
 const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
+const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";

 // define the common name
 const char NAMED_PRIMITIVE_ITER[] = "iter";
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index 293f31707e..274f63844c 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -298,6 +298,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
   } else if (abs_base->isa()) {
     auto value = abs_base->cast()->ref();
     dic = ConvertAbstractToPython(value);
+  } else if (abs_base->isa()) {
+    auto arg_slice = dyn_cast(abs_base);
+    std::vector shape;
+    dic["shape"] = shape;
+    dic["dtype"] = arg_slice->BuildType();
+    dic["value"] = BuildValue(arg_slice->BuildValue());
   } else if (abs_base->isa()) {
     auto arg_tuple = dyn_cast(abs_base);
     size_t len = arg_tuple->size();
diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index df4a8656f5..edbfe8dc4c 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -98,6 +98,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
       i++;
     }
     ret = rets;
+  } else if (value->isa()) {
+    ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
   } else if (value->isa()) {
     auto slice = value->cast();
     auto start = ValuePtrToPyData(slice->start());
diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
index 3a44b1e483..d008f96648 100644
--- a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
+++ b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py
@@ -20,7 +20,7 @@
 import numpy as np
 from ...primitive import constexpr
 from ....common.tensor import Tensor
 from ....common import dtype as mstype
-from ...._extends.utils import Slice
+from ...._extends.utils import Slice, Ellipsis_

 @constexpr
 def check_equal(param1, param2, msg="{},{}"):
@@ -29,31 +29,40 @@ def check_equal(param1, param2, msg="{},{}"):
         raise ValueError(msg.format(param1, param2))
     return param1

+
+@constexpr
+def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
+    """Checks the shape and size of the tensor and value."""
+    if data_shape == value_shape or data_size == value_size or value_size == 1:
+        return True
+    raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
+
+
 @constexpr
 def check_tensor_setitem_index(index, element_type=None):
     """Checks tuple index type of tensor assignment."""
     if index is None:
-        raise ValueError("Tensor's index cannot be None.")
+        raise IndexError("Tensor's index cannot be None.")
     # eg. Tensor[Slice] = u
     if isinstance(index, Slice):
         return True
     # eg. Tensor[tuple] = u
     if isinstance(index, tuple):
         if not index:
-            raise ValueError("Tensor's index cannot be empty.")
+            raise IndexError("Tensor's index cannot be empty.")
         # eg. Tensor[tuple(Slice...)] = u
-        if isinstance(index[0], (Slice, int)):
+        if isinstance(index[0], (Slice, Ellipsis_, int)):
             return True
-        raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
+        raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
     # eg. Tensor[Tensor[dtype=bool]] = u
     if index == mstype.tensor:
         if element_type is None or element_type != mstype.bool_:
-            raise ValueError(
-                "The index of tensor should be a bool type tensor. \
-                {} type is not supported yet.".format(element_type))
+            raise TypeError(
+                "The index of tensor should be a bool type tensor. "
+                "{} type is not supported yet.".format(element_type))
         return True
-    raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
+    raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))


 @constexpr
@@ -90,10 +99,18 @@ def slice_expand(input_slices, shape):
     # Slice or tuple(Slice...)
     if isinstance(input_slices, Slice):
         slices = (input_slices,)
-    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
-        slices = input_slices
+    elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
+        is_have_ellipsis = False
+        for _, element in enumerate(input_slices):
+            if isinstance(element, Ellipsis_):
+                is_have_ellipsis = True
+                break
+        if is_have_ellipsis:
+            slices = ellipsis2slice(input_slices, shape)
+        else:
+            slices = input_slices
     else:
-        raise ValueError("Tensor's index type is not supported yet.")
+        raise IndexError("Tensor's index type is not supported yet.")
     for s in slices:
         start = 0 if (s.start is None) else s.start
@@ -111,6 +128,26 @@ def slice_expand(input_slices, shape):
     return begin, end, strides

+
+def ellipsis2slice(input_, shape):
+    """Converts ellipsis to slice."""
+    input_slice = input_
+    result = []
+    if isinstance(input_, Ellipsis_):
+        input_slice = (input_,)
+    ell_count = 0
+    for _, element in enumerate(input_slice):
+        if not isinstance(element, Ellipsis_):
+            result.append(element)
+            continue
+        ell_count += 1
+        if ell_count > 1:
+            raise IndexError("There cannot be more than one ellipsis (...) in the index of the tensor, "
+                             "but it is currently {}".format(input_slice))
+    for _ in range(len(shape) - len(input_slice) + 1):
+        result.append(Slice(None, None, None))
+    return tuple(result)
+
+
 @constexpr
 def slice2indices(input_slices, shape):
     """
@@ -139,7 +176,7 @@ def slice2indices(input_slices, shape):
 def check_indices(indices_size, index):
     """Checks indices whether is empty."""
     if indices_size < 1:
-        raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
+        raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
     return indices_size
@@ -151,8 +188,8 @@ def check_indices(indices_size, index):
     if value_size > 1:
         if value_size != indices_size:
             raise ValueError(
-                "The value given to tensor does not match the index size. \
-                value size:{}, indics size:{}".format(value_size, indices_size))
+                "The value given to tensor does not match the index size,"
+                " value size:{}, indices size:{}".format(value_size, indices_size))
     return value_size

 @constexpr
@@ -168,8 +205,11 @@ def integer_to_indices(index, shape):
 def tuple_element_is_slice(indexs):
     """Judges tuple element type."""
     if not indexs:
-        raise ValueError("Tensor's index cannot be empty.")
-    if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
+        raise IndexError("Tensor's index cannot be empty.")
+    if isinstance(indexs, tuple):
+        for _, ele in enumerate(indexs):
+            if not isinstance(ele, Slice):
+                return False
         return True
     return False
@@ -177,7 +217,10 @@ def tuple_element_is_int(indexs):
     """Judges tuple element type."""
     if not indexs:
-        raise ValueError("Tensor's index cannot be empty.")
-    if isinstance(indexs, tuple) and isinstance(indexs[0], int):
+        raise IndexError("Tensor's index cannot be empty.")
+    if isinstance(indexs, tuple):
+        for _, ele in enumerate(indexs):
+            if not isinstance(ele, int):
+                return False
         return True
     return False
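# --- Illustration only, not part of the patch -------------------------------------
# A self-contained sketch of the padding rule implemented by ellipsis2slice() above:
# a trailing `...` is replaced by enough full slices to cover the remaining
# dimensions. The Slice/Ellipsis_ stand-ins and ellipsis2slice_sketch() below are
# local re-implementations that mirror mindspore/_extends/utils.py; this is not the
# library code itself.
from dataclasses import dataclass

@dataclass
class Slice:
    start: int
    end: int
    step: int

@dataclass
class Ellipsis_:
    pass

def ellipsis2slice_sketch(index, shape):
    """Keep explicit entries, pad a trailing `...` with full slices."""
    index = (index,) if isinstance(index, Ellipsis_) else index
    result = [e for e in index if not isinstance(e, Ellipsis_)]
    result += [Slice(None, None, None)] * (len(shape) - len(index) + 1)
    return tuple(result)

# A[...] on a (3, 4, 5) tensor -> three full slices, one per dimension.
print(ellipsis2slice_sketch(Ellipsis_(), (3, 4, 5)))
# A[0:2, ...] -> the explicit slice plus two full slices for the trailing dims.
print(ellipsis2slice_sketch((Slice(0, 2, 1), Ellipsis_()), (3, 4, 5)))
# -----------------------------------------------------------------------------------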
= Tensor.""" + result = None + data_shape = F.shape(data) + data_dtype = F.dtype(data) + data_size = F.size(data) + value_shape = F.shape(value) + value_size = F.size(value) + check_result = mult_util.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) + if check_result: + if data_size == value_size: + result = F.reshape(value, data_shape) + result = F.cast(result, data_dtype) + elif value_size == 1: + param1 = F.fill(data_dtype, data_shape, 1) + param2 = F.cast(value, data_dtype) + result = F.tensor_mul(param1, param2) + return result diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index f713b1ea0c..32c4025368 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -103,6 +103,7 @@ class TensorAssignWithSliceError1(Cell): a[1:3:-1,::] = b return a + class TensorAssignWithSliceError2(Cell): def __init__(self): super(TensorAssignWithSliceError2, self).__init__() @@ -110,24 +111,29 @@ class TensorAssignWithSliceError2(Cell): def construct(self, a, b): a[1:3:-1] = b return a + + class TensorAssignWithSlice2(Cell): def __init__(self): super(TensorAssignWithSlice2, self).__init__() - def construct(self, a, b): + def construct(self, a, b, ck): a[1:5] = b a[3:4] = 5 a[-1:1:-1] = b a[-1:3:-1] = 5 a[::] = b a[::] = 9 - return a + z = a + ck + return z + + class TensorAssignWithSlice(Cell): def __init__(self): super(TensorAssignWithSlice, self).__init__() self.c = 2 - def construct(self, a, b): + def construct(self, a, b, ck): a[1:3,::] = b a[2:3:,3:] = b a[::] = b @@ -136,9 +142,10 @@ class TensorAssignWithSlice(Cell): a[::,::] = self.c a[2:3:,0:, 4:1:-1] = b a[2:3:,0:, 4:1:-1] = self.c - z = a + z = a + ck return z + def test_tensor_assign(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) net = TensorAssignWithSlice() @@ -146,95 +153,145 @@ def test_tensor_assign(): net_e1 = TensorAssignWithSliceError1() net_e2 = TensorAssignWithSliceError2() a = np.arange(60).reshape(3,4,5) - b = Tensor([1]) - Ta = Tensor(a) - Ta4d = Tensor(a.reshape(1,3,4,5)) - Tb= Tensor([1,3]) - Tc= Tensor([]) - t = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) - net(Ta, b) - net2(t, b) + ck = np.arange(60).reshape(3,4,5) + b = Tensor([1], dtype=mstype.float32) + Ta = Tensor(a, dtype=mstype.float32) + Tck = Tensor(ck, dtype=mstype.float32) + Ta4d = Tensor(a.reshape(1,3,4,5), dtype=mstype.float32) + Ta4d_ck = Tensor(ck.reshape(1,3,4,5), dtype=mstype.float32) + Tb= Tensor([1,3], dtype=mstype.float32) + Tc= Tensor([], dtype=mstype.float32) + t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) + tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32) + net(Ta, b, Tck) + net2(t, b, tck) # Error for A[Slice] = Number # 1. A[Slice] = Number, Slice error - with pytest.raises(ValueError): + with pytest.raises(IndexError): net_e2(t, 2) # Error for A[Slice] = U, U is a Tensor # 1. A[Slice] = U, u.size is error with pytest.raises(ValueError): - net2(t, Tb) + net2(t, Tb, tck) # 2. A[Slice] = U, U is empty with pytest.raises(ValueError): - net2(t, Tc) + net2(t, Tc, tck) # 3. A[Slice] = U, U.size error with pytest.raises(ValueError): - net2(t, Tb) + net2(t, Tb, tck) # Error for A[Tuple(Slice...)] = Tensor # 1. A[Tuple(Slice...)] = U, U is empty with pytest.raises(ValueError): - net(Ta, Tc) + net(Ta, Tc, Tck) # 2. A[Tuple(Slice...)] = U, U.size error with pytest.raises(ValueError): - net(Ta, Tb) + net(Ta, Tb, Tck) # 3. 
diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py
index f713b1ea0c..32c4025368 100644
--- a/tests/ut/python/ops/test_tensor_slice.py
+++ b/tests/ut/python/ops/test_tensor_slice.py
@@ -103,6 +103,7 @@ class TensorAssignWithSliceError1(Cell):
         a[1:3:-1,::] = b
         return a

+
 class TensorAssignWithSliceError2(Cell):
     def __init__(self):
@@ -110,24 +111,29 @@ class TensorAssignWithSliceError2(Cell):
     def construct(self, a, b):
         a[1:3:-1] = b
         return a
+
+
 class TensorAssignWithSlice2(Cell):
     def __init__(self):
         super(TensorAssignWithSlice2, self).__init__()
-    def construct(self, a, b):
+    def construct(self, a, b, ck):
         a[1:5] = b
         a[3:4] = 5
         a[-1:1:-1] = b
         a[-1:3:-1] = 5
         a[::] = b
         a[::] = 9
-        return a
+        z = a + ck
+        return z
+
+
 class TensorAssignWithSlice(Cell):
     def __init__(self):
         super(TensorAssignWithSlice, self).__init__()
         self.c = 2
-    def construct(self, a, b):
+    def construct(self, a, b, ck):
         a[1:3,::] = b
         a[2:3:,3:] = b
         a[::] = b
@@ -136,9 +142,10 @@ class TensorAssignWithSlice(Cell):
         a[::,::] = self.c
         a[2:3:,0:, 4:1:-1] = b
         a[2:3:,0:, 4:1:-1] = self.c
-        z = a
+        z = a + ck
         return z

+
 def test_tensor_assign():
     context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     net = TensorAssignWithSlice()
@@ -146,95 +153,145 @@ def test_tensor_assign():
     net_e1 = TensorAssignWithSliceError1()
     net_e2 = TensorAssignWithSliceError2()
     a = np.arange(60).reshape(3,4,5)
-    b = Tensor([1])
-    Ta = Tensor(a)
-    Ta4d = Tensor(a.reshape(1,3,4,5))
-    Tb= Tensor([1,3])
-    Tc= Tensor([])
-    t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
-    net(Ta, b)
-    net2(t, b)
+    ck = np.arange(60).reshape(3,4,5)
+    b = Tensor([1], dtype=mstype.float32)
+    Ta = Tensor(a, dtype=mstype.float32)
+    Tck = Tensor(ck, dtype=mstype.float32)
+    Ta4d = Tensor(a.reshape(1,3,4,5), dtype=mstype.float32)
+    Ta4d_ck = Tensor(ck.reshape(1,3,4,5), dtype=mstype.float32)
+    Tb= Tensor([1,3], dtype=mstype.float32)
+    Tc= Tensor([], dtype=mstype.float32)
+    t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
+    tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
+    net(Ta, b, Tck)
+    net2(t, b, tck)
     # Error for A[Slice] = Number
     # 1. A[Slice] = Number, Slice error
-    with pytest.raises(ValueError):
+    with pytest.raises(IndexError):
         net_e2(t, 2)
     # Error for A[Slice] = U, U is a Tensor
     # 1. A[Slice] = U, u.size is error
     with pytest.raises(ValueError):
-        net2(t, Tb)
+        net2(t, Tb, tck)
     # 2. A[Slice] = U, U is empty
     with pytest.raises(ValueError):
-        net2(t, Tc)
+        net2(t, Tc, tck)
     # 3. A[Slice] = U, U.size error
     with pytest.raises(ValueError):
-        net2(t, Tb)
+        net2(t, Tb, tck)
     # Error for A[Tuple(Slice...)] = Tensor
     # 1. A[Tuple(Slice...)] = U, U is empty
     with pytest.raises(ValueError):
-        net(Ta, Tc)
+        net(Ta, Tc, Tck)
     # 2. A[Tuple(Slice...)] = U, U.size error
     with pytest.raises(ValueError):
-        net(Ta, Tb)
+        net(Ta, Tb, Tck)
     # 3. A[Tuple(Slice...)] = U, Slice error
-    with pytest.raises(ValueError):
+    with pytest.raises(IndexError):
         net_e1(Ta, b)
     # Error for A[Tuple(Slice...)] = Number
     # 1. A[Tuple(Slice...)] = Number, Slice error
-    with pytest.raises(ValueError):
+    with pytest.raises(IndexError):
         net_e1(Ta, 2)
     net = TensorAssignWithInteger()
     # Error for A[Number] = scalar/Tensor
     # 1. A[Number] = U, U is a Tensor, u.size not match
     with pytest.raises(ValueError):
-        net(Ta, Tb)
+        net(Ta, Tb, Tck)
     with pytest.raises(ValueError):
-        net(Ta, Tc)
+        net(Ta, Tc, Tck)
     # 2. A[Number] = U, the number index error
     with pytest.raises(IndexError):
-        net(Ta4d, b)
+        net(Ta4d, b, Ta4d_ck)
     # Error for A[(n,m)] = scalar/Tensor
     # 1. A[(n,m)] = U, U is a tensor. u.size not match
     net = TensorAssignWithTupleInteger()
     with pytest.raises(ValueError):
-        net(Ta, Tc)
+        net(Ta, Tc, Tck)
     with pytest.raises(ValueError):
-        net(Ta, Tb)
+        net(Ta, Tb, Tck)
     # 2. A[(n,m)] = U, the number index error
     with pytest.raises(IndexError):
-        net(Ta4d, b)
+        net(Ta4d, b, Ta4d_ck)
+
+    #Error for A[...] = U or A[1:, ...] = u
+    #1. A[...] = scalar/tensor
+    net = TensorAssignWithEllipsis()
+    net(Ta, Ta4d)
+    with pytest.raises(ValueError):
+        net(Ta, Tc)
+    with pytest.raises(ValueError):
+        net(Ta, Tb)
+    #2. A[::, 1:, ...] = scalar/tensor
+    net = TensorAssignWithTupleEllipsis()
+    net(Ta, b)
+    with pytest.raises(ValueError):
+        net(Ta, Tc)
+    with pytest.raises(ValueError):
+        net(Ta, Tb)
+
+
+class TensorAssignWithTupleEllipsis2(Cell):
+    def __init__(self):
+        super(TensorAssignWithTupleEllipsis2, self).__init__()
+    def construct(self, a, b):
+        a[1:, ..., ::] = b
+        return a
+
+
+class TensorAssignWithTupleEllipsis(Cell):
+    def __init__(self):
+        super(TensorAssignWithTupleEllipsis, self).__init__()
+    def construct(self, a, b):
+        a[:2, ...] = 1
+        a[1:, ...] = b
+        return a
+
+
+class TensorAssignWithEllipsis(Cell):
+    def __init__(self):
+        super(TensorAssignWithEllipsis, self).__init__()
+    def construct(self, a, b):
+        a[...] = 1
+        a[...] = b
+        return a
+

 class TensorAssignWithInteger(Cell):
     def __init__(self):
         super(TensorAssignWithInteger, self).__init__()
-    def construct(self, a, b):
+    def construct(self, a, b, ck):
         a[1] = 1
         a[0] = b
-        return a
+        z = a + ck
+        return z


 class TensorAssignWithTupleInteger(Cell):
     def __init__(self):
         super(TensorAssignWithTupleInteger, self).__init__()
-    def construct(self, a, b):
+    def construct(self, a, b, ck):
         a[(1)] = 1
         a[(1)] = b
         a[(1,1)] = b
         a[(1,1)] = 1
-        return a
+        z = a + ck
+        return z


 class TensorAssignWithBoolTensorIndex(Cell):
     def __init__(self):
         super(TensorAssignWithBoolTensorIndex, self).__init__()
-        self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
+        self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float32)
+        self.u_scalar = 5

-    def construct(self, a, b, c, u_tensor, _scalar):
-        a[c] = u_scalar
+    def construct(self, a, b, c, u_tensor):
+        a[c] = self.u_scalar
         a[b] = u_tensor
         z = a + self.t
         return z
@@ -252,15 +309,16 @@ class TensorAssignWithBoolTensorIndexError(Cell):
 class TensorAssignWithBoolTensorIndex2(Cell):
     def __init__(self):
         super(TensorAssignWithBoolTensorIndex2, self).__init__()
-        self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
-        self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float64)
+        self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
+        self.t = Tensor(np.arange(60).reshape([3,4,5]), dtype = mstype.float32)
+        self.u_scalar = 5

-    def construct(self, a, u_tensor, _scalar):
+    def construct(self, a, u_tensor):
         a[a > 8] = u_tensor
-        a[a >= 6] = u_scalar
-        a[a < 3] = u_scalar
+        a[a >= 6] = self.u_scalar
+        a[a < 3] = self.u_scalar
         a[a <= 5] = u_tensor
-        a[a == 5] = u_scalar
+        a[a == 5] = self.u_scalar
         z = a + self.t
         return z
@@ -274,36 +332,41 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
         return a


-a = np.random.uniform(1,10,[3,4,5])
+a = np.arange(60).reshape(3, 4, 5)
+ck = np.arange(60).reshape(3, 4, 5)
+a4 = np.arange(60).reshape(3, 2, 2, 5)
 b = a > 5
 c = a < 3
-Ta = Tensor(a)
+Ta = Tensor(a, dtype=mstype.float32)
+Tck = Tensor(ck, dtype=mstype.float32)
+Ta4 = Tensor(a4, dtype=mstype.float32)
 Tb = Tensor(b)
 Tc = Tensor(c)
 Td = Tensor([True, True])
-u_tensor = Tensor([1])
-u_tensor_error = Tensor([1, 2])
-t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
+u_tensor = Tensor([1], dtype=mstype.float32)
+u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
+t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
+tck_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
 u_scalar = 5


 def test_tensor_assign_bool_index():
     net1 = TensorAssignWithBoolTensorIndex()
     net2 = TensorAssignWithBoolTensorIndex2()
-    net1(Ta, Tb, Tc, u_tensor, u_scalar)
-    net1(Ta, Tb, Tc, u_tensor, u_scalar)
-    with pytest.raises(ValueError):
-        net1(Ta, Td, Tc, u_tensor, u_scalar)
-    with pytest.raises(ValueError):
-        net1(Ta, u_tensor, Tc, u_tensor, u_scalar)
+    net1(Ta, Tb, Tc, u_tensor)
+    net1(Ta, Tb, Tc, u_tensor)
     with pytest.raises(ValueError):
-        net1(Ta, Tb, Td, u_tensor, u_scalar)
+        net1(Ta, Td, Tc, u_tensor)
+    with pytest.raises(TypeError):
+        net1(Ta, u_tensor, Tc, u_tensor)
     with pytest.raises(ValueError):
-        net1(Ta, Tb, Ta, u_tensor, u_scalar)
+        net1(Ta, Tb, Td, u_tensor)
+    with pytest.raises(TypeError):
+        net1(Ta, Tb, Ta, u_tensor)
     with pytest.raises(ValueError):
-        net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
+        net1(Ta, Tb, Tc, u_tensor_error)
     # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
     with pytest.raises(ValueError):
-        net2(Ta, u_tensor_error, u_scalar)
+        net2(Ta, u_tensor_error)
     net3 = TensorAssignWithBoolTensorIndexError()
     with pytest.raises(AttributeError):
         net3(Ta, Tb, Tc, u_tensor)
@@ -316,29 +379,41 @@ def test_tensor_assign_bool_index():
         net4(Ta, u_scalar)


 test_cases = [
+    ('TensorAssignWithTupleEllipsis2', {
+        'block': TensorAssignWithTupleEllipsis2(),
+        'desc_inputs': [Ta4, u_tensor],
+    }),
+    ('TensorAssignWithTupleEllipsis', {
+        'block': TensorAssignWithTupleEllipsis(),
+        'desc_inputs': [Ta, u_tensor],
+    }),
+    ('TensorAssignWithEllipsis', {
+        'block': TensorAssignWithEllipsis(),
+        'desc_inputs': [Ta, u_tensor],
+    }),
     ('TensorAssignWithTupleInteger', {
         'block': TensorAssignWithTupleInteger(),
-        'desc_inputs': [Ta, u_tensor],
+        'desc_inputs': [Ta, u_tensor, Tck],
     }),
     ('TensorAssignWithInteger', {
         'block': TensorAssignWithInteger(),
-        'desc_inputs': [Ta, u_tensor],
+        'desc_inputs': [Ta, u_tensor, Tck],
     }),
     ('TensorAssignWithSlice', {
         'block': TensorAssignWithSlice(),
-        'desc_inputs': [Ta, u_tensor],
+        'desc_inputs': [Ta, u_tensor, Tck],
     }),
     ('TensorAssignWithSlice2', {
         'block': TensorAssignWithSlice2(),
-        'desc_inputs': [t_1d, u_tensor],
+        'desc_inputs': [t_1d, u_tensor, tck_1d],
     }),
     ('TensorAssignWithBoolTensorIndex', {
         'block': TensorAssignWithBoolTensorIndex(),
-        'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar],
+        'desc_inputs': [Ta, Tb, Tc, u_tensor],
     }),
     ('TensorAssignWithBoolTensorIndex2', {
         'block': TensorAssignWithBoolTensorIndex2(),
-        'desc_inputs': [Ta, u_tensor, u_scalar],
+        'desc_inputs': [Ta, u_tensor],
     }),
     ('SlicePositive', {
         'block': NetWorkSlicePositive(),