From 416017bf6599dbaf1ceaeda4da99063dd7d8ab26 Mon Sep 17 00:00:00 2001 From: buxue Date: Sun, 7 Feb 2021 21:39:23 +0800 Subject: [PATCH] make slice support tensor --- mindspore/core/abstract/prim_structures.cc | 54 +++++++++-- tests/ut/python/ops/test_tuple_slice.py | 107 +++++++++++++-------- 2 files changed, 111 insertions(+), 50 deletions(-) diff --git a/mindspore/core/abstract/prim_structures.cc b/mindspore/core/abstract/prim_structures.cc index b1fca4e747..859a6d1f9e 100644 --- a/mindspore/core/abstract/prim_structures.cc +++ b/mindspore/core/abstract/prim_structures.cc @@ -19,6 +19,7 @@ #include "abstract/infer_functions.h" #include "abstract/utils.h" #include "abstract/param_validator.h" + namespace mindspore { namespace abstract { AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, @@ -102,19 +103,52 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr // Inputs: three scalars whose value is an int32 number. CheckArgsSize(primitive->name(), args_spec_list, 3); size_t args_size = args_spec_list.size(); + AbstractBasePtrList slice_args; for (size_t index = 0; index < args_size; index++) { MS_EXCEPTION_IF_NULL(args_spec_list[index]); - if (!args_spec_list[index]->isa() && !args_spec_list[index]->isa()) { - MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; - } - if (args_spec_list[index]->isa() && - !dyn_cast(args_spec_list[index])->BuildValue()->isa()) { - MS_EXCEPTION(TypeError) << "MakeSlice eval " << index - << " parameter is an AbstractScalar, but is not an int64 number."; + if (args_spec_list[index]->isa()) { + slice_args.push_back(args_spec_list[index]); + } else if (args_spec_list[index]->isa()) { + ValuePtr scalar_value = args_spec_list[index]->cast()->BuildValue(); + if (scalar_value->isa()) { + slice_args.push_back(args_spec_list[index]); + } else if (scalar_value->isa()) { + ValuePtr scalar_index = MakeValue(static_cast(scalar_value->cast()->value())); + slice_args.push_back(scalar_index->ToAbstract()); + } else { + MS_EXCEPTION(TypeError) << "MakeSlice eval " << index + << " the input scalar type should be int or bool, but got " << scalar_value->ToString(); + } + } else if (args_spec_list[index]->isa()) { + auto arg = args_spec_list[index]->cast(); + TypePtr tensor_dtype = arg->element()->BuildType(); + + auto value = arg->BuildValue()->cast(); + if (value == nullptr) { + MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must be a const tensor."; + } + if (value->DataSize() != 1) { + MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must contain only one element, but got " + << value->ToString() << " has " << value->DataSize() << " elements."; + } + + if (tensor_dtype->isa()) { + auto *bool_value = static_cast(value->data_c()); + slice_args.push_back(MakeValue((static_cast(*bool_value)))->ToAbstract()); + } else if (tensor_dtype->isa()) { + auto *int_value = static_cast(value->data_c()); + slice_args.push_back(MakeValue((*int_value))->ToAbstract()); + } else { + MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor type must be int or bool, but got " + << tensor_dtype->ToString(); + } + } else { + MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " inputs should scalar, None or Tensor, but got" + << args_spec_list[index]->ToString(); } } // Slice: start, end, step - return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); + return std::make_shared(slice_args[0], slice_args[1], slice_args[2]); } template @@ -134,7 +168,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index_value->ToString(); } - int64_t idx_v = GetValue(index_value); + auto idx_v = GetValue(index_value); std::size_t nelems = queue->elements().size(); if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) { MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", " @@ -162,7 +196,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " << index_value->ToString(); } - int64_t idx_v = GetValue(index_value); + auto idx_v = GetValue(index_value); AbstractBasePtrList elements = queue->elements(); std::size_t nelems = elements.size(); int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems); diff --git a/tests/ut/python/ops/test_tuple_slice.py b/tests/ut/python/ops/test_tuple_slice.py index 897646ee32..c0bda76e26 100644 --- a/tests/ut/python/ops/test_tuple_slice.py +++ b/tests/ut/python/ops/test_tuple_slice.py @@ -31,13 +31,15 @@ class NetWork_1(Cell): def __init__(self): super(NetWork_1, self).__init__() self.addN = P.AddN() + self.index_0 = Tensor(3) + self.index_1 = Tensor([5]) + self.index_3 = Tensor([True]) - def construct(self, tensor1, tensor2, tensor3, tensor4, tensor5, tensor6): - tensor_tuple = (tensor1, tensor2, tensor3, tensor4, tensor5, tensor6) + def construct(self, tensor_tuple): tensor_tuple_slice0 = tensor_tuple[:] - tensor_tuple_slice1 = tensor_tuple[:3] - tensor_tuple_slice2 = tensor_tuple[1:] - tensor_tuple_slice3 = tensor_tuple[2:5:1] + tensor_tuple_slice1 = tensor_tuple[:self.index_0] + tensor_tuple_slice2 = tensor_tuple[self.index_3:] + tensor_tuple_slice3 = tensor_tuple[2:self.index_1:True] sum0 = self.addN(tensor_tuple_slice0) sum1 = self.addN(tensor_tuple_slice1) sum2 = self.addN(tensor_tuple_slice2) @@ -52,13 +54,14 @@ class NetWork_2(Cell): def __init__(self): super(NetWork_2, self).__init__() self.addN = P.AddN() + self.step = Tensor([-1]) + self.index_0 = Tensor(-6) - def construct(self, tensor1, tensor2, tensor3, tensor4, tensor5, tensor6): - tensor_tuple = (tensor1, tensor2, tensor3, tensor4, tensor5, tensor6) - tensor_tuple_slice0 = tensor_tuple[::-1] + def construct(self, tensor_tuple): + tensor_tuple_slice0 = tensor_tuple[::self.step] tensor_tuple_slice1 = tensor_tuple[-1::-1] tensor_tuple_slice2 = tensor_tuple[:-4:-1] - tensor_tuple_slice3 = tensor_tuple[-6:3] + tensor_tuple_slice3 = tensor_tuple[self.index_0:3] tensor_tuple_slice4 = tensor_tuple[-1:-6:-2] sum0 = self.addN(tensor_tuple_slice0) sum1 = self.addN(tensor_tuple_slice1) @@ -69,17 +72,15 @@ class NetWork_2(Cell): return ret -class NetWork_3(Cell): +class NetWorkSliceStepZero(Cell): """ NetWork_3 definition """ def __init__(self): - super(NetWork_3, self).__init__() - self.addN = P.AddN() + super(NetWorkSliceStepZero, self).__init__() - def construct(self, tensor_tuple, start, stop, step=1): - tensor_tuple_slice0 = tensor_tuple[start:stop:step] - res = self.addN(tensor_tuple_slice0) - return res + def construct(self, tensor_tuple): + tensor_tuple_slice = tensor_tuple[0:3:0] + return tensor_tuple_slice class NetWorkOutOfBounds(Cell): @@ -87,45 +88,60 @@ class NetWorkOutOfBounds(Cell): def __init__(self): super(NetWorkOutOfBounds, self).__init__() - self.addN = P.AddN() def construct(self, tensor_tuple): return tensor_tuple[100] +class NetWorkTensorSizeGreaterThanTwo(Cell): + """ NetWork_3 definition """ + + def __init__(self): + super(NetWorkTensorSizeGreaterThanTwo, self).__init__() + self.index_0 = Tensor([2, 3]) + + def construct(self, tensor_tuple): + return tensor_tuple[1:self.index_0] + + +class NetWorkTensorDtypeFloat(Cell): + """ NetWork_3 definition """ + + def __init__(self): + super(NetWorkTensorDtypeFloat, self).__init__() + self.index_0 = Tensor([2.1]) + + def construct(self, tensor_tuple): + return tensor_tuple[1:self.index_0] + + test_cases = [ ('SlicePositive', { 'block': NetWork_1(), - 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.zeros([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.zeros([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32))], + 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)))], }), ('SliceNegative', { 'block': NetWork_2(), - 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.zeros([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.zeros([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32))], + 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)))], }), ] test_cases_for_verify_exception = [ - ('SliceStartCross', { - 'block': (NetWork_3(), {'exception': TypeError}), - 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.zeros([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32))], - }), ('SliceStepZero', { - 'block': (NetWork_3(), {'exception': TypeError}), - 'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), - Tensor(np.zeros([2, 3, 4], np.int32)), - Tensor(np.ones([2, 3, 4], np.int32))], + 'block': (NetWorkSliceStepZero(), {'exception': ValueError}), + 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)))], }), ('SliceOutOfBounds', { 'block': (NetWorkOutOfBounds(), {'exception': IndexError}), @@ -133,7 +149,18 @@ test_cases_for_verify_exception = [ Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32)))], }), - + ('SliceTensorSizeGreaterThanTwo', { + 'block': (NetWorkTensorSizeGreaterThanTwo(), {'exception': TypeError}), + 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)))], + }), + ('SliceTensorDtypeFloat', { + 'block': (NetWorkTensorDtypeFloat(), {'exception': TypeError}), + 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)), + Tensor(np.zeros([2, 3, 4], np.int32)), + Tensor(np.ones([2, 3, 4], np.int32)))], + }), ]