make slice support tensor

pull/12249/head
buxue 4 years ago
parent 3805f0dfeb
commit 416017bf65

@ -19,6 +19,7 @@
#include "abstract/infer_functions.h" #include "abstract/infer_functions.h"
#include "abstract/utils.h" #include "abstract/utils.h"
#include "abstract/param_validator.h" #include "abstract/param_validator.h"
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
@ -102,19 +103,52 @@ AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr
// Inputs: three scalars whose value is an int32 number. // Inputs: three scalars whose value is an int32 number.
CheckArgsSize(primitive->name(), args_spec_list, 3); CheckArgsSize(primitive->name(), args_spec_list, 3);
size_t args_size = args_spec_list.size(); size_t args_size = args_spec_list.size();
AbstractBasePtrList slice_args;
for (size_t index = 0; index < args_size; index++) { for (size_t index = 0; index < args_size; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]); MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) { if (args_spec_list[index]->isa<AbstractNone>()) {
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; slice_args.push_back(args_spec_list[index]);
} } else if (args_spec_list[index]->isa<AbstractScalar>()) {
if (args_spec_list[index]->isa<AbstractScalar>() && ValuePtr scalar_value = args_spec_list[index]->cast<AbstractScalarPtr>()->BuildValue();
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int64Imm>()) { if (scalar_value->isa<IntergerImm>()) {
slice_args.push_back(args_spec_list[index]);
} else if (scalar_value->isa<BoolImm>()) {
ValuePtr scalar_index = MakeValue(static_cast<int64_t>(scalar_value->cast<BoolImmPtr>()->value()));
slice_args.push_back(scalar_index->ToAbstract());
} else {
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
<< " parameter is an AbstractScalar, but is not an int64 number."; << " the input scalar type should be int or bool, but got " << scalar_value->ToString();
}
} else if (args_spec_list[index]->isa<AbstractTensor>()) {
auto arg = args_spec_list[index]->cast<AbstractTensorPtr>();
TypePtr tensor_dtype = arg->element()->BuildType();
auto value = arg->BuildValue()->cast<tensor::TensorPtr>();
if (value == nullptr) {
MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must be a const tensor.";
}
if (value->DataSize() != 1) {
MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor must contain only one element, but got "
<< value->ToString() << " has " << value->DataSize() << " elements.";
}
if (tensor_dtype->isa<Bool>()) {
auto *bool_value = static_cast<bool *>(value->data_c());
slice_args.push_back(MakeValue((static_cast<int64_t>(*bool_value)))->ToAbstract());
} else if (tensor_dtype->isa<Int>()) {
auto *int_value = static_cast<int64_t *>(value->data_c());
slice_args.push_back(MakeValue((*int_value))->ToAbstract());
} else {
MS_EXCEPTION(TypeError) << "MakeSlice eval the input tensor type must be int or bool, but got "
<< tensor_dtype->ToString();
}
} else {
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " inputs should scalar, None or Tensor, but got"
<< args_spec_list[index]->ToString();
} }
} }
// Slice: start, end, step // Slice: start, end, step
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); return std::make_shared<AbstractSlice>(slice_args[0], slice_args[1], slice_args[2]);
} }
template <typename T> template <typename T>
@ -134,7 +168,7 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
<< index_value->ToString(); << index_value->ToString();
} }
int64_t idx_v = GetValue<int64_t>(index_value); auto idx_v = GetValue<int64_t>(index_value);
std::size_t nelems = queue->elements().size(); std::size_t nelems = queue->elements().size();
if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) { if (idx_v >= SizeToLong(nelems) || idx_v < -SizeToLong(nelems)) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", " MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToLong(nelems) << ", "
@ -162,7 +196,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got " MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int64 number, but got "
<< index_value->ToString(); << index_value->ToString();
} }
int64_t idx_v = GetValue<int64_t>(index_value); auto idx_v = GetValue<int64_t>(index_value);
AbstractBasePtrList elements = queue->elements(); AbstractBasePtrList elements = queue->elements();
std::size_t nelems = elements.size(); std::size_t nelems = elements.size();
int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems); int64_t idx_t = idx_v >= 0 ? idx_v : idx_v + SizeToLong(nelems);

@ -31,13 +31,15 @@ class NetWork_1(Cell):
def __init__(self): def __init__(self):
super(NetWork_1, self).__init__() super(NetWork_1, self).__init__()
self.addN = P.AddN() self.addN = P.AddN()
self.index_0 = Tensor(3)
self.index_1 = Tensor([5])
self.index_3 = Tensor([True])
def construct(self, tensor1, tensor2, tensor3, tensor4, tensor5, tensor6): def construct(self, tensor_tuple):
tensor_tuple = (tensor1, tensor2, tensor3, tensor4, tensor5, tensor6)
tensor_tuple_slice0 = tensor_tuple[:] tensor_tuple_slice0 = tensor_tuple[:]
tensor_tuple_slice1 = tensor_tuple[:3] tensor_tuple_slice1 = tensor_tuple[:self.index_0]
tensor_tuple_slice2 = tensor_tuple[1:] tensor_tuple_slice2 = tensor_tuple[self.index_3:]
tensor_tuple_slice3 = tensor_tuple[2:5:1] tensor_tuple_slice3 = tensor_tuple[2:self.index_1:True]
sum0 = self.addN(tensor_tuple_slice0) sum0 = self.addN(tensor_tuple_slice0)
sum1 = self.addN(tensor_tuple_slice1) sum1 = self.addN(tensor_tuple_slice1)
sum2 = self.addN(tensor_tuple_slice2) sum2 = self.addN(tensor_tuple_slice2)
@ -52,13 +54,14 @@ class NetWork_2(Cell):
def __init__(self): def __init__(self):
super(NetWork_2, self).__init__() super(NetWork_2, self).__init__()
self.addN = P.AddN() self.addN = P.AddN()
self.step = Tensor([-1])
self.index_0 = Tensor(-6)
def construct(self, tensor1, tensor2, tensor3, tensor4, tensor5, tensor6): def construct(self, tensor_tuple):
tensor_tuple = (tensor1, tensor2, tensor3, tensor4, tensor5, tensor6) tensor_tuple_slice0 = tensor_tuple[::self.step]
tensor_tuple_slice0 = tensor_tuple[::-1]
tensor_tuple_slice1 = tensor_tuple[-1::-1] tensor_tuple_slice1 = tensor_tuple[-1::-1]
tensor_tuple_slice2 = tensor_tuple[:-4:-1] tensor_tuple_slice2 = tensor_tuple[:-4:-1]
tensor_tuple_slice3 = tensor_tuple[-6:3] tensor_tuple_slice3 = tensor_tuple[self.index_0:3]
tensor_tuple_slice4 = tensor_tuple[-1:-6:-2] tensor_tuple_slice4 = tensor_tuple[-1:-6:-2]
sum0 = self.addN(tensor_tuple_slice0) sum0 = self.addN(tensor_tuple_slice0)
sum1 = self.addN(tensor_tuple_slice1) sum1 = self.addN(tensor_tuple_slice1)
@ -69,17 +72,15 @@ class NetWork_2(Cell):
return ret return ret
class NetWork_3(Cell): class NetWorkSliceStepZero(Cell):
""" NetWork_3 definition """ """ NetWork_3 definition """
def __init__(self): def __init__(self):
super(NetWork_3, self).__init__() super(NetWorkSliceStepZero, self).__init__()
self.addN = P.AddN()
def construct(self, tensor_tuple, start, stop, step=1): def construct(self, tensor_tuple):
tensor_tuple_slice0 = tensor_tuple[start:stop:step] tensor_tuple_slice = tensor_tuple[0:3:0]
res = self.addN(tensor_tuple_slice0) return tensor_tuple_slice
return res
class NetWorkOutOfBounds(Cell): class NetWorkOutOfBounds(Cell):
@ -87,45 +88,60 @@ class NetWorkOutOfBounds(Cell):
def __init__(self): def __init__(self):
super(NetWorkOutOfBounds, self).__init__() super(NetWorkOutOfBounds, self).__init__()
self.addN = P.AddN()
def construct(self, tensor_tuple): def construct(self, tensor_tuple):
return tensor_tuple[100] return tensor_tuple[100]
class NetWorkTensorSizeGreaterThanTwo(Cell):
""" NetWork_3 definition """
def __init__(self):
super(NetWorkTensorSizeGreaterThanTwo, self).__init__()
self.index_0 = Tensor([2, 3])
def construct(self, tensor_tuple):
return tensor_tuple[1:self.index_0]
class NetWorkTensorDtypeFloat(Cell):
""" NetWork_3 definition """
def __init__(self):
super(NetWorkTensorDtypeFloat, self).__init__()
self.index_0 = Tensor([2.1])
def construct(self, tensor_tuple):
return tensor_tuple[1:self.index_0]
test_cases = [ test_cases = [
('SlicePositive', { ('SlicePositive', {
'block': NetWork_1(), 'block': NetWork_1(),
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))], Tensor(np.ones([2, 3, 4], np.int32)))],
}), }),
('SliceNegative', { ('SliceNegative', {
'block': NetWork_2(), 'block': NetWork_2(),
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)), Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))], Tensor(np.ones([2, 3, 4], np.int32)))],
}), }),
] ]
test_cases_for_verify_exception = [ test_cases_for_verify_exception = [
('SliceStartCross', {
'block': (NetWork_3(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))],
}),
('SliceStepZero', { ('SliceStepZero', {
'block': (NetWork_3(), {'exception': TypeError}), 'block': (NetWorkSliceStepZero(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)), 'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))], Tensor(np.ones([2, 3, 4], np.int32)))],
}), }),
('SliceOutOfBounds', { ('SliceOutOfBounds', {
'block': (NetWorkOutOfBounds(), {'exception': IndexError}), 'block': (NetWorkOutOfBounds(), {'exception': IndexError}),
@ -133,7 +149,18 @@ test_cases_for_verify_exception = [
Tensor(np.zeros([2, 3, 4], np.int32)), Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))], Tensor(np.ones([2, 3, 4], np.int32)))],
}), }),
('SliceTensorSizeGreaterThanTwo', {
'block': (NetWorkTensorSizeGreaterThanTwo(), {'exception': TypeError}),
'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
}),
('SliceTensorDtypeFloat', {
'block': (NetWorkTensorDtypeFloat(), {'exception': TypeError}),
'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
}),
] ]

Loading…
Cancel
Save