support for tensor indexing in pynative

support tensor slice using constexpr

remove tensorslice metagraph

add pynative testcases
pull/1920/head
huangdongrun 5 years ago
parent 8de8289cfd
commit 9522f59b87

@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key,
get_default_input, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol, create_ellipsis_obj)
is_class_member, parse_cb, resolve_symbol)
from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
'create_slice_obj', 'create_ellipsis_obj']
'create_slice_obj']

@@ -29,7 +29,6 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.api import _MindSporeFunction
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from ..utils import Slice, Ellipsis_
# define return value
RET_SUCCESS = 0
@@ -70,14 +69,9 @@ parse_expr_statement_white_list = (
"append",
)
def create_ellipsis_obj():
"""Create Slice object"""
return Ellipsis_()
def create_slice_obj(start, end, step):
"""Create Slice object"""
return Slice(start, end, step)
"""Create slice object"""
return slice(start, end, step)
def parse_cb(func, parse_method=None):

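The parser now returns Python's builtin slice instead of the removed Slice dataclass. A minimal sketch of what the builtin type buys (illustrative, not part of the diff): slice already knows how to normalize its members against a dimension length via slice.indices().

# Illustrative only: the builtin slice carries start/stop/step and can
# clamp them against a dimension length, which the Slice dataclass could not.
s = slice(1, 6, 2)
assert (s.start, s.stop, s.step) == (1, 6, 2)
assert s.indices(4) == (1, 4, 2)  # normalized for a length-4 dimension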
@@ -19,7 +19,6 @@ import logging
import os
import inspect
from functools import wraps
from dataclasses import dataclass
def cal_sha256(file_path):
@@ -100,20 +99,3 @@ def cell_attr_register(fn=None, attrs=None):
if fn is not None:
return wrap_cell(fn)
return wrap_cell
@dataclass
class Slice:
"""
Slice class
"""
start: int
end: int
step: int
@dataclass
class Ellipsis_:
"""
Ellipsis class
"""

@@ -932,206 +932,6 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
return ret;
}
int ConvertBinaryToDecimal(const std::vector<unsigned int> &number_bin) {
unsigned int number_dec = 0;
for (size_t index = 0; index < number_bin.size(); index++) {
number_dec |= number_bin[index] << index;
}
return static_cast<int>(number_dec);
}
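ConvertBinaryToDecimal packs the per-dimension flags into an int, least significant bit first; that packed value later serves as the shrink_axis_mask. The same packing, sketched in Python (illustrative):

def pack_bits(flags):
    # Least significant bit first, as in ConvertBinaryToDecimal above.
    mask = 0
    for i, flag in enumerate(flags):
        mask |= flag << i
    return mask

assert pack_bits([0, 1, 0]) == 2  # only dimension 1 sets its bit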
void ParseSlice(const AbstractSlicePtr &slice, std::vector<int> *begin, std::vector<int> *end,
std::vector<int> *strides, int length) {
MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides);
if (length <= 0) {
MS_LOG(EXCEPTION) << "Could not slice a dim when it's length less than 1";
}
int start_default = 0;
int stop_default = length;
int step_default = 1;
int step_value = CheckSliceMember(slice->step(), step_default, "step");
if (step_value < 0) {
start_default = -1;
stop_default = -(length + 1);
}
begin->push_back(CheckSliceMember(slice->start(), start_default, "begin"));
end->push_back(CheckSliceMember(slice->stop(), stop_default, "stop"));
strides->push_back(step_value);
}
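ParseSlice flips the start/stop defaults when the step is negative, so an open-ended reverse slice walks from the last element to one past the front. The same defaulting rule, sketched in Python (illustrative):

def slice_defaults(length, step):
    # Negative step: start at the last element, stop one past the front.
    if step < 0:
        return -1, -(length + 1)
    return 0, length

assert slice_defaults(5, 1) == (0, 5)
assert slice_defaults(5, -1) == (-1, -6)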
int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector<int> &shape,
std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(slice_tuple);
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides);
size_t slice_tuple_size = slice_tuple->size();
size_t shape_size = shape.size();
if (slice_tuple_size > shape_size) {
MS_LOG(EXCEPTION) << "The number of slice data to slice tensor should be less than the rank of tensor,"
"when the rank of tensor is "
<< shape_size << ", the number of slice is " << slice_tuple_size;
}
std::vector<unsigned int> shrink;
auto slice_tuple_eles = slice_tuple->elements();
size_t ellipsis_num = 0;
for (size_t index = 0; index < slice_tuple_size; index++) {
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
ParseSlice(slice, begin, end, strides, shape[index]);
shrink.push_back(0);
continue;
}
if (slice_tuple_eles[index]->isa<AbstractScalar>()) {
int ele_index = GetArgScalarValue(dyn_cast<AbstractScalar>(slice_tuple_eles[index]), "slice_tuple");
begin->push_back(ele_index);
end->push_back(ele_index + 1);
strides->push_back(1);
shrink.push_back(1);
continue;
}
if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
ellipsis_num++;
if (ellipsis_num > 1) {
MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
}
size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
begin->insert(begin->end(), ellipsis_len, 0);
end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
strides->insert(strides->end(), ellipsis_len, 1);
shrink.insert(shrink.end(), ellipsis_len, 0);
continue;
}
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
<< slice_tuple_eles[index]->ToString();
}
if (ellipsis_num == 0) {
for (size_t index = slice_tuple_size; index < shape_size; index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
}
}
return ConvertBinaryToDecimal(shrink);
}
int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector<int> &shape,
std::vector<int> *begin, std::vector<int> *end, std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides);
size_t shape_size = shape.size();
if (shape_size == 0) {
MS_LOG(EXCEPTION) << "Could slice a scalar tensor";
}
ParseSlice(slice, begin, end, strides, shape[0]);
for (size_t index = 1; index < shape_size; index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
}
return 0;
}
int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector<int> &shape,
std::vector<int> *begin, std::vector<int> *end,
std::vector<int> *strides) {
MS_EXCEPTION_IF_NULL(begin);
MS_EXCEPTION_IF_NULL(end);
MS_EXCEPTION_IF_NULL(strides);
int ele_index = GetArgScalarValue(scalar, "slice_tuple");
begin->push_back(ele_index);
end->push_back(ele_index + 1);
strides->push_back(1);
for (size_t index = 1; index < shape.size(); index++) {
begin->push_back(0);
end->push_back(shape[index]);
strides->push_back(1);
}
return 1;
}
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tensor
// args: tensor, slice or slice tuple
const std::string op_name = std::string("TensorSlice");
abstract::CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tensor_node = ret_graph->add_parameter();
(void)ret_graph->add_parameter();
auto shape = tensorPtr->shape()->shape();
std::vector<int> begin;
std::vector<int> end;
std::vector<int> strides;
int shrink_axis_mask;
if (args_spec_list[1]->isa<AbstractTuple>()) {
AbstractTuplePtr tuple_ptr = dyn_cast<AbstractTuple>(args_spec_list[1]);
shrink_axis_mask = GenerateStridedSliceParametersFromTuple(tuple_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractSlice>()) {
AbstractSlicePtr slice_ptr = dyn_cast<AbstractSlice>(args_spec_list[1]);
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractScalar>()) {
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
return ExpandADim(ret_graph, tensor_node);
}
MS_LOG(EXCEPTION) << "TensorSlice not support the index is False.";
}
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
ret_graph->set_output(tensor_node);
return ret_graph;
} else if (args_spec_list[1]->isa<AbstractNone>()) {
return ExpandADim(ret_graph, tensor_node);
} else {
std::ostringstream args_info;
for (const auto &arg : args_spec_list) {
MS_EXCEPTION_IF_NULL(arg);
args_info << arg->ToString() << "\n";
}
MS_LOG(EXCEPTION)
<< "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
<< args_info.str();
}
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
ret_graph->set_output(ret_graph->NewCNode(
{PrimStridedSlice, tensor_node, NewValueNode(begin), NewValueNode(end), NewValueNode(strides)}));
return ret_graph;
}
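The scalar-bool, None, and ellipsis branches above follow NumPy's indexing conventions; a NumPy comparison sketch (illustrative only, not MindSpore API):

# NumPy analogues of the special cases above (for comparison only):
import numpy as np
x = np.ones((6, 7, 8))
assert x[True].shape == (1, 6, 7, 8)  # index True expands a leading dim
assert x[None].shape == (1, 6, 7, 8)  # None also expands a leading dim
assert x[...].shape == (6, 7, 8)      # a lone ellipsis is the identity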
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// select indexed item
// args: tuple of items, index
@@ -1162,11 +962,6 @@ REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
(void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
(void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
*m, "TupleGetItemTensor_")

@@ -175,16 +175,6 @@ class TupleSlice : public MetaFuncGraph {
};
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
class TensorSlice : public MetaFuncGraph {
public:
explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
~TensorSlice() override = default;
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
};
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
class TupleGetItemTensor : public MetaFuncGraph {
public:
explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}

@@ -209,6 +209,28 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
return true;
}
bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object";
py::slice slice_obj = obj.cast<py::slice>();
auto convert_func = [obj](std::string attr) -> ValuePtr {
auto py_attr = py::getattr(obj, attr.c_str());
if (py::isinstance<py::none>(py_attr)) {
return kNone;
} else if (py::isinstance<py::int_>(py_attr)) {
int value = py::cast<int>(py_attr);
return MakeValue(value);
} else {
MS_LOG(EXCEPTION) << "Slice should contain only int or none";
}
};
ValuePtr start = convert_func("start");
ValuePtr stop = convert_func("stop");
ValuePtr step = convert_func("step");
*data = std::make_shared<ValueSlice>(start, stop, step);
return true;
}
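ConvertSlice reads the start/stop/step attributes off the Python slice, each of which must be an int or None. What those attributes look like from the Python side (illustrative):

# What ConvertSlice sees: each slice attribute is an int or None.
s = slice(2, None, -1)
assert (s.start, s.stop, s.step) == (2, None, -1)
# A non-int member such as slice(0.5, 3) would hit the exception above.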
FuncGraphPtr ConvertToBpropCut(py::object obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
@@ -321,6 +343,10 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
} else if (py::isinstance<py::dict>(obj)) {
ret = ConvertDict(obj, &converted, use_signature);
} else if (py::isinstance<py::slice>(obj)) {
ret = ConvertSlice(obj, &converted);
} else if (py::isinstance<py::ellipsis>(obj)) {
converted = kEllipsis;
} else if (py::isinstance<py::tuple>(obj)) {
ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {

@@ -353,11 +353,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
auto value = abs_base->cast<AbstractRefPtr>()->ref();
dic = ConvertAbstractToPython(value);
} else if (abs_base->isa<AbstractEllipsis>()) {
auto arg_slice = dyn_cast<AbstractEllipsis>(abs_base);
std::vector<int> shape;
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
dic["shape"] = py::none();
dic["dtype"] = py::ellipsis();
dic["value"] = py::ellipsis();
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();

@@ -106,7 +106,7 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
}
ret = rets;
} else if (value->isa<EllipsisObj>()) {
ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_CLASS_ELLIPSIS);
ret = py::ellipsis();
} else if (value->isa<ValueSlice>()) {
auto slice = value->cast<ValueSlicePtr>();
auto start = ValuePtrToPyData(slice->start());

@@ -206,6 +206,9 @@ class Parameter:
res.default_input = res.default_input / other
return res
def __setitem__(self, index, value):
return self
def set_parameter_data(self, data):
"""Set `default_input` of current `Parameter`."""
if isinstance(data, bool):

@@ -144,6 +144,13 @@ class Tensor(Tensor_):
out = tensor_operator_registry.get('__le__')(self, other)
return out
def __getitem__(self, index):
out = tensor_operator_registry.get('__getitem__')(self, index)
return out
def __setitem__(self, index, value):
return self
def __gt__(self, other):
out = tensor_operator_registry.get('__gt__')(self, other)
return out

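With __getitem__ routed through tensor_operator_registry, indexing runs eagerly in pynative mode. A hypothetical usage sketch (shapes follow the stride rules added in _compile_utils below):

import numpy as np
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.arange(6 * 7 * 8).reshape(6, 7, 8).astype(np.float32))
assert x[1:6:2].asnumpy().shape == (3, 7, 8)  # slice -> tensor_index_by_slice
assert x[2].asnumpy().shape == (7, 8)         # int   -> tensor_index_by_integer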
@@ -19,7 +19,7 @@
from functools import partial
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func
@@ -27,7 +27,7 @@ from .. import functional as F
from ...common.parameter import Parameter
__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
def add_flags(fn, **flags):

@@ -18,7 +18,9 @@ from . import _constexpr_utils as const_utils
from ... import functional as F
from ... import operations as P
from ...composite import base
from ....common.tensor import Tensor
from ....common import dtype as mstype
from ....common._register_for_tensor import tensor_operator_registry
hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)
@@ -152,3 +154,101 @@ def generate_updates_from_tensor(data, index, value, op_type):
if need_broadcast:
return broadcast(updates_shape, value)
return value
def tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
return tensor_index_by_tensor(self, index)
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
if isinstance(index, int):
return tensor_index_by_number(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if isinstance(index, bool):
return tensor_index_by_bool(self, index)
if index is ...:
return self
raise IndexError("Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32,\
got {} with type{}".format(index, type(index)))
tensor_operator_registry.register("__getitem__", tensor_getitem)
def tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
def tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(F.shape(data), slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)
def tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(F.shape(data), number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
def tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("bool value as indexing ,false is not supported")
def tensor_index_by_number(data, number):
"""Tensor getitem by a Number which may be integer/float/bool value"""
number_type = const_utils.check_number_index_type(number)
if number_type == const_utils.BOOL_:
return tensor_index_by_bool(data, number)
if number_type == const_utils.INT_:
return tensor_index_by_integer(data, number)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
def tensor_index_by_tensor(data, tensor_index):
"""Tensor getitem by a single tensor"""
dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
const_utils.TENSOR_GETITEM)
if dtype_valid:
return F.gather(data, tensor_index, 0)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool")
def tensor_index_by_tuple_slice(data, t):
"""Tensor getitem by a tuple of slice"""
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(F.shape(data), t)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types"""
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return tensor_index_by_tuple_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)

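A manual walk-through of how a mixed tuple index maps onto StridedSlice parameters under the rules above (hand-computed; it matches the deleted C++ UT test_TensorSliceBySliceTupleToReduceDimension):

# x[1:5:2, 1, 2:6] on a (6, 7, 8) tensor, hand-mapped to StridedSlice:
begin, end, strides = (1, 1, 2), (5, 2, 6), (2, 1, 1)
shrink_axis_mask = 0b010  # the int index shrinks dimension 1
# dim 0: indices 1, 3  -> length 2
# dim 1: dropped by shrink_axis_mask
# dim 2: indices 2..5  -> length 4
# result shape: (2, 4)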
File diff suppressed because it is too large.

@@ -15,7 +15,6 @@
"""Implementation for getitem."""
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from .. import base
from ... import functional as F
@@ -50,29 +49,6 @@ _tuple_slice = _TupleSlice('tuple_slice')
"""_tuple_slice is a metafuncgraph object which will slice a tuple."""
class _TensorSlice(base.TensorSlice_):
"""
Slices a tensor.
Inputs:
data (Tensor): A tensor to be sliced.
s (slice): The index used to slice the tensor.
Outputs:
Tensor, consists of some elements of data.
"""
def __init__(self, name):
base.TensorSlice_.__init__(self, name)
def __call__(self, *args):
pass
_tensor_slice = _TensorSlice('tensor_slice')
"""_tensor_slice is an metafuncgraph object which will slice a tensor."""
class _TupleGetItemTensor(base.TupleGetItemTensor_):
"""
Getting item of tuple by tensor index.
@@ -182,13 +158,13 @@ def _tensor_getitem_by_number(data, number_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return _tensor_slice(data, number_index)
return compile_utils.tensor_index_by_number(data, number_index)
@getitem.register("Tensor", "None")
def _tensor_getitem_by_none(data, index):
"""
Getting item of tensor by None.
For None indexing, expand the data with one dim.
Inputs:
data (Tensor): A tensor.
@@ -197,7 +173,7 @@ def _tensor_getitem_by_none(data, index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return _tensor_slice(data, index)
return F.expand_dims(data, 0)
@getitem.register("Tensor", "Slice")
@@ -212,13 +188,13 @@ def _tensor_getitem_by_slice(data, slice_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return _tensor_slice(data, slice_index)
return compile_utils.tensor_index_by_slice(data, slice_index)
@getitem.register("Tensor", "Tensor")
def _tensor_getitem_by_tensor(data, tensor_index):
"""
Getting item of tensor by slice.
Getting item of tensor by tensor index.
Inputs:
data (Tensor): A tensor.
@@ -227,18 +203,13 @@ def _tensor_getitem_by_tensor(data, tensor_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
const_utils.TENSOR_GETITEM)
result = None
if check_dtypes:
result = F.gather(data, tensor_index, 0)
return result
return compile_utils.tensor_index_by_tensor(data, tensor_index)
@getitem.register("Tensor", "Tuple")
def _tensor_getitem_by_tuple(data, tuple_index):
"""
Getting item of tensor by slice tuple.
Getting item of tensor by tuple.
Inputs:
data (Tensor): A tensor.
@@ -247,13 +218,7 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)
return compile_utils.tensor_index_by_tuple(data, tuple_index)
@getitem.register("Tensor", "Ellipsis")
@@ -268,22 +233,4 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Outputs:
Tensor, same as data.
"""
return _tensor_slice(data, ellipsis_index)
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result
return data

File diff suppressed because it is too large.

@@ -240,156 +240,6 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
ASSERT_EQ(real, expect);
}
TEST_F(TestComposite, test_TensorSliceBySlice) {
MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
FuncGraphPtr tensorSlicePtrGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
AbstractBasePtrList eles;
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tensor, slice};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 7, 8});
ASSERT_EQ(*ret, *expect);
}
TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
AbstractBasePtrList eles;
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(0);
AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(6);
AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
eles.push_back(slice);
start_index = std::make_shared<AbstractScalar>(1);
stop_index = std::make_shared<AbstractScalar>(5);
step = std::make_shared<AbstractScalar>(1);
slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
eles.push_back(slice);
start_index = std::make_shared<AbstractScalar>(2);
stop_index = std::make_shared<AbstractScalar>(8);
step = std::make_shared<AbstractScalar>(3);
slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
eles.push_back(slice);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({3, 4, 2});
ASSERT_EQ(*ret, *expect);
}
TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
AbstractBasePtrList eles;
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(1);
AbstractScalarPtr stop_index = std::make_shared<AbstractScalar>(5);
AbstractScalarPtr step = std::make_shared<AbstractScalar>(2);
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
eles.push_back(slice);
AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
eles.push_back(elem_index);
start_index = std::make_shared<AbstractScalar>(2);
stop_index = std::make_shared<AbstractScalar>(6);
step = std::make_shared<AbstractScalar>(1);
slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
eles.push_back(slice);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({2, 4});
ASSERT_EQ(*ret, *expect);
}
TEST_F(TestComposite, test_TensorSliceByScalar) {
MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
AbstractBasePtrList args_spec_list = {tensor, start_index};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({7, 8});
ASSERT_EQ(*ret, *expect);
}
TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
AbstractBasePtrList eles;
AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(1);
eles.push_back(elem_index);
elem_index = std::make_shared<AbstractScalar>(3);
eles.push_back(elem_index);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({8});
ASSERT_EQ(*ret, *expect);
}
TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
MetaFuncGraphPtr tensorSlicePtr = std::make_shared<prim::TensorSlice>("tensor_slice");
FuncGraphPtr tensorSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tensorSlicePtr, 2);
AbstractBasePtrList eles;
AbstractScalarPtr elem_index = std::make_shared<AbstractScalar>(3);
eles.push_back(elem_index);
elem_index = std::make_shared<AbstractScalar>(0);
eles.push_back(elem_index);
elem_index = std::make_shared<AbstractScalar>(6);
eles.push_back(elem_index);
AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({6, 7, 8});
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
AbstractTensorPtr expect = UTCompositeUtils::ArrayInt32Of({});
ASSERT_EQ(*ret, *expect);
}
TEST_F(TestComposite, test_UnpackCall_3args) {
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);

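The deleted C++ UTs above have pynative-level equivalents; a sketch of the same shape checks in Python (illustrative; the actual test names in the suppressed test file may differ):

import numpy as np
from mindspore import Tensor

x = Tensor(np.ones((6, 7, 8)).astype(np.float32))
assert x[1:6:2].asnumpy().shape == (3, 7, 8)               # slice
assert x[0:6:2, 1:5, 2:8:3].asnumpy().shape == (3, 4, 2)   # slice tuple
assert x[2].asnumpy().shape == (7, 8)                      # scalar
assert x[1, 3].asnumpy().shape == (8,)                     # scalar tuple
assert x[3, 0, 6].asnumpy().shape == ()                    # down to a scalar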
@@ -107,5 +107,5 @@ class TestUnsupportParam():
def test_Sgd_init(self):
with pytest.raises(TypeError):
paramsTensor = Tensor(np.zeros([1, 2, 3]))
paramsTensor = Parameter(Tensor(np.zeros([1, 2, 3])), "x")
SGD(paramsTensor)

@@ -25,7 +25,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class NetWorkSlicePositive(Cell):
def __init__(self):
super(NetWorkSlicePositive, self).__init__()
@@ -528,6 +527,7 @@ def test_tensor_assign():
# 2. A[::, 1:, ...] = scalar/tensor
net = TensorAssignWithTupleEllipsis()
net(Ta, b)
Tc = Tensor(1, mstype.float32)
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):

@@ -168,7 +168,7 @@ def test_select_grad():
sens = Tensor(np.ones_like(out.asnumpy()).astype(np.float32))
args = [cond, x, y, sens]
gout = gfn(*args)
expect_cond = np.zeros_like(cond)
expect_cond = np.zeros_like(cond.asnumpy())
expect_x = np.array([[1, 0, 0], [0, 1, 1]])
expect_y = np.array([[0, 1, 1], [1, 0, 0]])
assert np.all(gout[0].asnumpy() == expect_cond)
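Why the one-line fix above is needed: relying on NumPy to coerce a MindSpore Tensor inside np.zeros_like is unreliable, so the expectation is built explicitly from the NumPy view (illustrative):

import numpy as np
from mindspore import Tensor

cond = Tensor(np.array([[True, False], [False, True]]))
expect_cond = np.zeros_like(cond.asnumpy())  # ndarray of zeros, same shape
assert expect_cond.shape == (2, 2)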
