merge getitem the by otehr types to by tuple(success)

pull/12714/head
yepei6 4 years ago
parent 130f4c2810
commit 57373ed30a

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -173,7 +173,7 @@ def _tensor_getitem_by_none(data, none_index):
Outputs:
Tensor, element type is as same as the element type of data.
"""
return F.expand_dims(data, 0)
return compile_utils.tensor_index_by_tuple(data, (none_index,))
@getitem.register("Tensor", "Slice")
@ -188,7 +188,7 @@ def _tensor_getitem_by_slice(data, slice_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return compile_utils.tensor_index_by_slice(data, slice_index)
return compile_utils.tensor_index_by_tuple(data, (slice_index,))
@getitem.register("Tensor", "Tensor")
@ -203,7 +203,7 @@ def _tensor_getitem_by_tensor(data, tensor_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return compile_utils.tensor_index_by_tensor(data, tensor_index)
return compile_utils.tensor_index_by_tuple(data, (tensor_index,))
@getitem.register("Tensor", "Ellipsis")
@ -218,7 +218,7 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Outputs:
Tensor, same as data.
"""
return data
return compile_utils.tensor_index_by_tuple(data, (ellipsis_index,))
@getitem.register("Tensor", "List")

@ -3011,7 +3011,7 @@ class StridedSlice(PrimitiveWithInfer):
continue
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
raise IndexError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
f"but got stride: {stride}, begin: {begin}.")
j += 1

@ -155,7 +155,7 @@ class TensorGetItemByThreeTensors(Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def Xtest_getitem_by_tensors():
"""This testcase may encounter a sync stream error occassionally"""
"""This testcase may encounter a sync stream error occasionally"""
net = TensorGetItemByThreeTensors()
input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
@ -1024,7 +1024,7 @@ def Xtest_tensor_slice_reduce_out_of_bounds_neg():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError) as ex:
with pytest.raises(IndexError) as ex:
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
ex.value)
@ -1042,7 +1042,7 @@ def Xtest_tensor_slice_reduce_out_of_bounds_positive():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError) as ex:
with pytest.raises(IndexError) as ex:
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)

@ -1160,7 +1160,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net(input_tensor)
@ -1176,5 +1176,5 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net(input_tensor)

Loading…
Cancel
Save