!1407 support mixed tensor indexing for tensor getitem and setitem, and support the in operator.

Merge pull request !1407 from zhangbuxue/support_mixed_tensor_for_tensor_get_item_and_tensor_set_item
pull/1407/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
commit ad279e90fd

@ -105,7 +105,7 @@ convert_object_map = {
T.ge: multitype_ops.greater_equal,
T.is_: F.is_,
T.is_not: F.is_not,
T.contains: F.in_dict,
T.contains: multitype_ops.in_,
T.not_contains: F.not_in_dict,
# system function

@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
(void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init());
}));
const TypePtr kTypeExternal = std::make_shared<External>();

@ -95,6 +95,8 @@ string = typing.String()
type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
slice_type = typing.Slice
ellipsis_type = typing.Ellipsis
number_type = (int8,
int16,

@ -37,6 +37,7 @@ from .logical_and_impl import logical_and
from .logical_or_impl import logical_or
from .logic_not_impl import logical_not
from .uadd_impl import uadd
from .in_impl import in_
__all__ = [
'add',
'sub',
@ -59,5 +60,6 @@ __all__ = [
'setitem',
'logical_and',
'logical_or',
'logical_not'
'logical_not',
'in_'
]

@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""constexpr util"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ... import operations as P
from ...composite import base
from ....common import dtype as mstype
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
def broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
if F.shape(x) == broadcast_shape:
return x
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
if multiples:
return F.tile(x, multiples)
return x
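For reference, a minimal NumPy sketch of what this tile-based broadcast does, assuming equal ranks and compatible shapes (np.tile stands in for F.tile, and the inline repeat-count computation stands in for const_utils.compute_multiples):

import numpy as np

def broadcast_sketch(broadcast_shape, x):
    # Tile x along each axis until it matches broadcast_shape.
    if x.shape == tuple(broadcast_shape):
        return x
    multiples = tuple(b // s for s, b in zip(x.shape, broadcast_shape))
    return np.tile(x, multiples)

print(broadcast_sketch((2, 3), np.array([[0, 1, 2]])).shape)  # (2, 3)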
def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
"""Transform indexing tensor to the required."""
x = broadcast(broadcast_shape, x)
return broadcast(final_shape, F.reshape(x, new_shape))
def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices
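A hedged NumPy sketch of the same idea: broadcast every index tensor to a common shape, then stack them on a new trailing axis so each trailing row is one coordinate, which is the layout gather_nd consumes:

import numpy as np

rows = np.array([0, 1])                               # shape (2,)
cols = np.array([[0], [2]])                           # shape (2, 1)
shape = np.broadcast_shapes(rows.shape, cols.shape)   # (2, 2)
indices = np.stack([np.broadcast_to(rows, shape),
                    np.broadcast_to(cols, shape)], axis=-1)
data = np.arange(12).reshape(3, 4)
print(data[indices[..., 0], indices[..., 1]])         # gather_nd equivalent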
def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
indexes_types = hyper_map(F.typeof, tuple_index)
int_positions = const_utils.get_pos_of_int_index(indexes_types)
for i in int_positions:
tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
indexes_types = hyper_map(F.typeof, tuple_index)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes = []
slice_indexes = []
for i in tensor_positions:
tensor_indexes.append(tuple_index[i])
for j in slice_positions:
slice_indexes.append(tuple_index[j])
data_shape = F.shape(data)
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name)
slice_number = 0
final_index_tensors = []
tuple_index_size = len(tuple_index)
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
if i in tensor_positions:
transform_tensor = transform_indexing_tensor(broadcast_shape,
final_shape,
index_tensor_new_shape,
tuple_index[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
final_shape,
indexes_shapes_info,
op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name)
for ele in ellipsis_tensors:
final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims
indices = pack(final_index_tensors)
return indices
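To make the rewrite concrete, a small NumPy sketch of how a mixed index such as data[idx, 1:3] can be reduced to pure coordinate indexing: the slice is materialized as an index tensor and broadcast against the tensor index before packing (illustrative names only; the actual shape bookkeeping lives in the const_utils helpers):

import numpy as np

data = np.arange(20).reshape(4, 5)
idx = np.array([0, 3])                       # tensor index on axis 0
sl = np.arange(1, 3)                         # the slice 1:3 as a tensor
rows, cols = np.broadcast_arrays(idx[:, None], sl[None, :])
indices = np.stack([rows, cols], axis=-1)    # shape (2, 2, 2)
print(data[indices[..., 0], indices[..., 1]])
print(np.array_equal(data[idx, 1:3],
                     data[indices[..., 0], indices[..., 1]]))  # True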
def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)
def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
if value_elements_type == const_utils.ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)
data_shape = F.shape(data)
index_shape = F.shape(index)
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM)
if check_dtype_same:
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return broadcast(updates_shape, value)
return value
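The same broadcast helper serves setitem: the value tensor is expanded to the updates shape the scatter expects. A hedged NumPy sketch of that step (the updates shape is hand-computed here; generate_updates_shape derives it from the data and index shapes):

import numpy as np

data = np.zeros((3, 4), np.float32)
row_ids = np.array([0, 2])                     # two rows to update
value = np.array([1., 2., 3., 4.], np.float32)
updates = np.broadcast_to(value, (2, 4))       # broadcast to updates shape
data[row_ids] = updates                        # scatter-like assignment
print(data)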

@ -14,11 +14,11 @@
# ============================================================================
"""Implementation for getitem."""
from . import _utils as multi_utils
from ..import base
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from .. import base
from ... import functional as F
from ....common import dtype as mstype
getitem = base.MultitypeFuncGraph('getitem')
"""
@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index):
Outputs:
Tensor, element type is same as the element type of data.
"""
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
const_utils.TENSOR_GETITEM)
result = None
if check_dtypes:
result = F.gather(data, tensor_index, 0)
@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Outputs:
Tensor, element type is same as the element type of data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_slice(data, tuple_index)
if index_elements_type == multi_utils.ALL_TENSOR:
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
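The dispatch now distinguishes three tuple-index forms. A hedged NumPy illustration of the shapes involved (NumPy's placement rule for non-adjacent advanced indices may differ from MindSpore's; this only shows which path each form takes):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
i = np.array([0, 1])
j = np.array([2, 0])
print(x[0:1, 2].shape)     # NO_TENSOR: only slices/ints -> _tensor_slice
print(x[i, j].shape)       # ALL_TENSOR: -> ..._by_tuple_of_tensor
print(x[i, 1:3, j].shape)  # mixed: -> ..._by_tuple_of_mixed_tensors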
@getitem.register("Tensor", "Ellipsis")
@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""in_impl"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
in_ = base.MultitypeFuncGraph("in")
"""
in_ is a metafuncgraph object which determines whether a is in b,
with implementations registered via the ".register" decorator
"""
@in_.register("Number", "Tuple")
def _number_in_tuple(x, y):
"""
Determine whether a number is in a tuple.
Args:
x (Number): The number to look up.
y (tuple): The tuple to search.
Returns:
bool, True if x is in y, otherwise False.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("Number", "List")
def _number_in_list(x, y):
"""
Determine whether a number is in a list.
Args:
x (Number): The number to look up.
y (list): The list to search.
Returns:
bool, True if x is in y, otherwise False.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("String", "Tuple")
def _string_in_tuple(x, y):
"""
Determine whether a string is in a tuple.
Args:
x (str): The string to look up.
y (tuple): The tuple to search.
Returns:
bool, True if x is in y, otherwise False.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("String", "List")
def _string_in_list(x, y):
"""
Determine whether a string is in a list.
Args:
x (str): The string to look up.
y (list): The list to search.
Returns:
bool, True if x is in y, otherwise False.
"""
return const_utils.scalar_in_sequence(x, y)
@in_.register("String", "Dictionary")
def _str_in_dict(x, y):
"""
Determine whether a string key is in a dict.
Args:
x (str): The key to look up.
y (dict): The dict to search.
Returns:
bool, True if x is in y, otherwise False.
"""
return F.in_dict(x, y)
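A minimal illustration of the semantics these registrations give to the in keyword (plain Python below; inside a Cell's construct the same checks are resolved at graph-compile time via scalar_in_sequence and F.in_dict):

tuple_ = (1, 2, 3, "ok")
print(2 in tuple_)         # Number in Tuple      -> True
print("ok" in tuple_)      # String in Tuple      -> True
print("k" in {"k": 1})     # String in Dictionary -> True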

File diff suppressed because it is too large

@ -1419,7 +1419,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shape[0])
N = len(x_shape)
out_shape = x_shape[0]

@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent):
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
}
print("buxue------------------------------------------------")
print("inputs")
print(ret[keyword.desc_inputs])
print("outputs")
print(ret[keyword.result])
return ret

@ -19,9 +19,9 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
@ -133,7 +133,7 @@ def test_list_append_2():
class ListOperate(nn.Cell):
def __init__(self,):
def __init__(self, ):
super(ListOperate, self).__init__()
def construct(self, t, l):
@ -152,6 +152,20 @@ class ListOperate(nn.Cell):
return x
class InListNet(nn.Cell):
def __init__(self, ):
super(InListNet, self).__init__()
self.list_ = [1, 2, 3, 4, 5, "ok"]
def construct(self, x):
ret = x
if 2 in self.list_:
ret = x + x
if "ok" in self.list_:
ret = x - x
return ret
class AxisListNet(nn.Cell):
def __init__(self):
super(AxisListNet, self).__init__()
@ -204,10 +218,15 @@ test_case_ops = [
('AxisListDefault', {
'block': AxisListDefaultNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('InList', {
'block': InListNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
]
test_case_lists = [test_case_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcases
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm

@ -19,9 +19,9 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell):
return self.layers[0][1](x)
class InTupleNet(nn.Cell):
def __init__(self, ):
super(InTupleNet, self).__init__()
self.tuple_ = (1, 2, 3, 4, 5, "ok")
def construct(self, x):
ret = x
if 2 in self.tuple_:
ret = x + x
if "ok" in self.tuple_:
ret = x - x
return ret
test_case_ops = [
('TupleGraph', {
'block': TupleGraphNet(),
@ -59,6 +73,9 @@ test_case_ops = [
('NestTupleGraph', {
'block': NestTupleGraphNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('InTuple', {
'block': InTupleNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
]
test_case_lists = [test_case_ops]

File diff suppressed because it is too large