add bool expand dims and wait to test

pull/10853/head
Payne 4 years ago
parent b1ca1fbdb9
commit b12ca2165c

@@ -93,6 +93,7 @@ env_type = typing.EnvType()
env_type_type = typing.EnvType
type_type = typing.TypeType()
type_none = typing.TypeNone()
type_bool = typing.Bool()
string = typing.String()
type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType

@@ -137,13 +137,37 @@ def _expand_data_dims_with_none(data, tuple_index, op_name):
none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none)
tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else (index,)
none_positions += (i,) if none_type_tag else ()
for dim in none_positions:
data = F.expand_dims(data, dim)
return data, tuple_index_without_none
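
As a side note for reviewers, here is a minimal NumPy sketch (an editor's illustration, not part of this commit) of the behaviour _expand_data_dims_with_none mirrors: each None in a tuple index inserts a size-1 axis, which the helper reproduces by calling F.expand_dims at the recorded positions and replacing the None entries with full slices.

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
# None in a tuple index inserts a size-1 axis at that position.
assert x[:, None, :].shape == (2, 1, 3, 4)
# Equivalent two-step form, mirroring the helper: expand the dim first,
# then index with the None replaced by a full slice.
y = np.expand_dims(x, 1)
assert (x[:, None, :] == y[:, slice(None), :]).all()
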
def _expand_data_dims_with_bool(data, tuple_index, op_name):
"""expand the data's dim with 'True/False' in tuple_index"""
indexes_types = hyper_map(F.typeof, tuple_index)
bool_positions, tuple_index_without_bool = (), ()
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
bool_type_tag = const_utils.judge_index_type(index_type, mstype.type_bool)
if bool_type_tag:
if index:
tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),)
else:
# TODO: once the operators support zero-size dims, build a 0-length tensor here to replace 'False'.
return const_utils.raise_index_error("When a tensor is indexed by a tuple that contains a bool object, "
"only 'True' is currently supported.")
else:
tuple_index_without_bool += (index,)
bool_positions += (i,) if bool_type_tag else ()
for dim in bool_positions:
data = F.expand_dims(data, dim)
return data, tuple_index_without_bool
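
Similarly, a hedged NumPy sketch (editor's illustration) of what _expand_data_dims_with_bool emulates: a scalar True in a tuple index adds a size-1 axis, which the helper reproduces with F.expand_dims plus an index tensor [0]; a scalar False would need a zero-size axis, which the TODO above defers, hence the IndexError.

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
# A scalar True in an index adds a size-1 axis (NumPy semantics).
assert x[True].shape == (1, 2, 3, 4)
# The helper's emulation: insert the axis, then pick element 0 along it.
y = np.expand_dims(x, 0)[np.array([0])]
assert (x[True] == y).all()
# A scalar False would yield a zero-size axis, not yet supported in the tuple path.
assert x[False].shape == (0, 2, 3, 4)
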
def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
shape = F.shape(data)
@@ -168,7 +192,7 @@ def _tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data))
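
For the single-bool branch, the zero-size-tensor return appears to follow NumPy's convention, where indexing by a scalar False selects nothing but still prepends an axis (editor's sketch):

import numpy as np

x = np.arange(12).reshape(3, 4)
assert x[True].shape == (1, 3, 4)   # True: like expand_dims(data, 0)
assert x[False].shape == (0, 3, 4)  # False: empty result of shape (0,) + shape(data)
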
def _tensor_index_by_integer(data, number):
@@ -207,8 +231,11 @@ def tensor_index_by_tuple(data, tuple_index):
op_name = const_utils.TENSOR_GETITEM
if len(tuple_index) == 1:
return data[tuple_index[0]]
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_bool(data, tuple_index, op_name)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR:
@@ -228,8 +255,8 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
def _tensor_getitem_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
data_shape = F.shape(data)
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
begin_strides, end_strides, step_strides, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
data_shape, tuple_index)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
@@ -259,8 +286,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tuple_index_len = len(tuple_index)
tensor_indexes, slice_indexes = [], []
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, _, _, int_positions, _, \
tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name)
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
@@ -296,8 +323,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_len):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
@@ -321,6 +348,7 @@ def _generate_updates_from_tuple(data, index, value, op_type):
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
if value_elements_type == const_utils.ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)

@@ -73,6 +73,13 @@ def make_empty_slice():
return slice(None, None, None)
@constexpr
def make_tensor(data, data_type, data_shape=None):
"""Make a Tensor from data, or a zeros Tensor of data_shape when a shape is given."""
if data_shape:
return Tensor(np.zeros(data_shape), data_type)
return Tensor(data, data_type)
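
A quick sketch of the two make_tensor branches (editor's note, using NumPy arrays as a stand-in for Tensor): with data_shape given, the helper returns zeros of that shape, which is how the zero-size result above is built; without it, the tensor comes directly from data.

import numpy as np

assert np.zeros((0, 3, 4)).shape == (0, 3, 4)        # data_shape given: zeros of that shape
assert np.array([0], dtype=np.int64).shape == (1,)   # data_shape omitted: built from data
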
@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
@@ -158,6 +165,36 @@ def check_indexes_types_valid(dtypes, target_type, op_name):
f"but got {dtype}.")
@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
sequence_positions = [], [], [], [], [], [], []
for i, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.slice_type):
slice_positions.append(i)
elif isinstance(index_type, mstype.ellipsis_type):
ellipsis_positions.append(i)
elif isinstance(index_type, mstype.none_type):
none_positions.append(i)
elif isinstance(index_type, mstype.Int):
int_positions.append(i)
elif isinstance(index_type, mstype.bool_type):
bool_positions.append(i)
elif isinstance(index_type, mstype.tensor_type):
tensor_positions.append(i)
elif isinstance(index_type, (list, tuple)):
sequence_positions.append(i)
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
if len(ellipsis_positions) > 1:
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
tensor_positions, sequence_positions
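
To illustrate the bucketing this function performs, here is a simplified pure-Python analogue (an editor's sketch that inspects raw index objects rather than mstype type objects, so it is not the graph-mode implementation):

def classify_index_positions(tuple_index):
    """Bucket index positions by the kind of each element (simplified analogue)."""
    buckets = {"slice": [], "ellipsis": [], "none": [], "int": [],
               "bool": [], "tensor": [], "sequence": []}
    for i, idx in enumerate(tuple_index):
        if isinstance(idx, bool):        # check bool before int: bool is a subclass of int
            buckets["bool"].append(i)
        elif isinstance(idx, slice):
            buckets["slice"].append(i)
        elif idx is Ellipsis:
            buckets["ellipsis"].append(i)
        elif idx is None:
            buckets["none"].append(i)
        elif isinstance(idx, int):
            buckets["int"].append(i)
        elif isinstance(idx, (list, tuple)):
            buckets["sequence"].append(i)
        else:
            buckets["tensor"].append(i)  # assume anything else is a Tensor-like index
    return buckets

positions = classify_index_positions((1, [2, 1, 3], True, slice(1, 3, 1), ..., None))
assert positions["int"] == [0] and positions["bool"] == [2] and positions["none"] == [5]
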
def slice_expand(input_slices, shape):
"""
Converts slice to indices.
@@ -293,13 +330,6 @@ def tuple_element_is_int(indexs):
return False
@constexpr
def tuple_index_tensor_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
return ALL_TENSOR if tensor_cnt == len(types) else NO_TENSOR if tensor_cnt == 0 else CONTAIN_TENSOR
@ constexpr
def tuple_index_int_cnt(types, op_name):
"""count the int type of types which contains the tuple elements' type."""
@@ -344,6 +374,8 @@ def check_value_elements(data_dtype, types):
raise TypeError(
f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
# TODO: to be deleted
@ constexpr
def get_index_tensor_dtype(dtype):
@@ -356,6 +388,7 @@ def get_index_tensor_dtype(dtype):
f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
# TODO: to be deleted
@ constexpr
def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type."""
@@ -366,6 +399,7 @@ def check_index_tensors_dtype(indexes_types, op_name):
return True
# TODO: to be deleted
@ constexpr
def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type."""
@@ -375,6 +409,7 @@ def check_index_tensor_dtype(index_type, op_name):
f"but got {index_type}.")
# TODO: to be deleted
@ constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
@@ -645,36 +680,6 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
sequence_positions = [], [], [], [], [], [], []
for i, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.slice_type):
slice_positions.append(i)
elif isinstance(index_type, mstype.ellipsis_type):
ellipsis_positions.append(i)
elif isinstance(index_type, mstype.none_type):
none_positions.append(i)
elif isinstance(index_type, mstype.Int):
int_positions.append(i)
elif isinstance(index_type, mstype.bool_type):
bool_positions.append(i)
elif isinstance(index_type, mstype.tensor_type):
tensor_positions.append(i)
elif isinstance(index_type, (list, tuple)):
sequence_positions.append(i)
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
if len(ellipsis_positions) > 1:
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
tensor_positions, sequence_positions
@ constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""

@@ -14,6 +14,7 @@
# ============================================================================
""" test_tensor_slice """
import numpy as np
import pytest
from mindspore import Tensor
from mindspore import context
@@ -48,7 +49,7 @@ def test_tensor_fancy_index_boolean_list():
net(input_me)
def test_tensor_fancy_integer_boolean_list_graph():
def test_tensor_fancy_index_integer_boolean_list_graph():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [1, 2, True, False]
net = NetWorkFancyIndex(index)
@@ -57,7 +58,7 @@ def test_tensor_fancy_integer_boolean_list_graph():
net(input_me)
def test_tensor_fancy_integer_list_mixed():
def test_tensor_fancy_index_integer_list_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
@@ -66,7 +67,7 @@ def test_tensor_fancy_integer_list_mixed():
net(input_me)
def test_tensor_fancy_integer_tuple_mixed():
def test_tensor_fancy_index_integer_tuple_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
@@ -75,10 +76,29 @@ def test_tensor_fancy_integer_tuple_mixed():
net(input_me)
def test_tensor_fancy_integer_list_tuple_mixed():
def test_tensor_fancy_index_integer_list_tuple_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
input_me = Tensor(input_np, dtype=mstype.float32)
net(input_me)
def test_tensor_fancy_index_integer_list_tuple_bool_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], True, (3, 2, 1), slice(1, 3, 1), ..., True, 4)
net = NetWorkFancyIndex(index)
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
input_me = Tensor(input_np, dtype=mstype.float32)
net(input_me)
def test_tensor_fancy_index_integer_list_tuple_bool_mixed_error():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], True, (3, 2, 1), slice(1, 3, 1), ..., False, 4)
net = NetWorkFancyIndex(index)
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
input_me = Tensor(input_np, dtype=mstype.float32)
with pytest.raises(IndexError):
net(input_me)
