@@ -138,25 +138,23 @@ def _tensor_setitem_by_tensor_v1(data, index, value_tensor):
    Outputs:
        Tensor, element type and shape is same as data.
    """
    result = None
    index_dtype = F.dtype(index)
    index_shape = F.shape(index)
    is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
    if not is_bool:
        return mult_util.error_msg(
            "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
    data_shape = F.shape(data)
    if index_shape != data_shape:
        return mult_util.error_msg(
            "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape))
    size = F.size(value_tensor)
    if size != 1:
        return mult_util.error_msg(
            "When assign value is a tensor, its size should be 1, but current size is {}.", (size,))
    dtype = F.dtype(data)
    u_cast = F.cast(value_tensor, dtype)
    one_data = F.ones_like(data)
    u = F.tensor_mul(one_data, u_cast)
    return F.select(index, u, data)
    check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
    if check_result:
        data_shape = F.shape(data)
        data_shape = mult_util.check_equal(data_shape, index_shape,
                                           "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
        size = F.size(value_tensor)
        size = mult_util.check_equal(1, size,
                                     "When assign value is a tensor, its size should be {}, but current size is {}.")
        dtype = F.dtype(data)
        u_cast = F.cast(value_tensor, dtype)
        one_data = F.ones_like(data)
        u = F.tensor_mul(one_data, u_cast)
        result = F.select(index, u, data)
    return result
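
Both validation paths above reduce A[bool_index] = value_tensor to broadcasting the size-1 value over data and selecting with the boolean index. A minimal NumPy sketch of that select semantics (NumPy and the sample values below are illustrative only, not part of this patch):

import numpy as np

# NumPy illustration of the select-based assignment; not MindSpore code.
data = np.array([[1., 2.], [3., 4.]])
mask = np.array([[True, False], [False, True]])     # bool index, same shape as data
value = np.array(9.)                                # size-1 "value_tensor"

u = np.ones_like(data) * value.astype(data.dtype)   # broadcast the single value over data
result = np.where(mask, u, data)                    # F.select(index, u, data)
# result -> [[9., 2.], [3., 9.]]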


@setitem.register("Tensor", "Tensor", "Number")
@@ -179,16 +177,162 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
    Outputs:
        Tensor, element type and shape is same as data.
    """
    result = None
    index_dtype = F.dtype(index)
    index_shape = F.shape(index)
    is_bool = mult_util.is_same_type(index_dtype, mstype.bool_)
    if not is_bool:
        return mult_util.error_msg(
            "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,))
    shape = F.shape(data)
    if index_shape != shape:
        return mult_util.error_msg(
            "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape))
    dtype = F.dtype(data)
    u = F.fill(dtype, shape, value)
    return F.select(index, u, data)
    check_result = mult_util.check_tensor_setitem_index(mstype.tensor, index_dtype)
    if check_result:
        shape = F.shape(data)
        shape = mult_util.check_equal(
            shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
        dtype = F.dtype(data)
        u = F.fill(dtype, shape, value)
        result = F.select(index, u, data)
    return result
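
For a scalar value, the same reduction uses F.fill to materialize the value at data's shape before the select. A rough NumPy equivalent (illustrative only):

import numpy as np

# NumPy illustration of the fill-then-select path; not MindSpore code.
data = np.array([1., 2., 3.])
mask = np.array([True, False, True])
u = np.full(data.shape, 5.0, dtype=data.dtype)  # stands in for F.fill(dtype, shape, value)
result = np.where(mask, u, data)                # stands in for F.select(index, u, data)
# result -> [5., 2., 5.]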


@setitem.register("Tensor", "Slice", "Tensor")
def _tensor_setitem_with_slice_v3(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Slice] = U
        Restraint condition: A is a Tensor
                             Slice like "1:3"
                             U is a Tensor(size=1) or Tensor(size>1)

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Slice): Slice expression.
        value (Tensor): Assignment value.

    Outputs:
        Tensor, element type and shape is same as data.
    """
    return _tensor_assgin_tensor(data, input_slice, value)


@setitem.register("Tensor", "Tuple", "Tensor")
def _tensor_setitem_with_slice_v4(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Slice] = U
        Restraint condition: A is a Tensor
                             Slice like "1:3, ::, :4:-1"
                             U is a Tensor(size=1) or Tensor(size>1)

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Tuple(Slice)): Slice expression.
        value (Tensor): Assignment value.

    Outputs:
        Tensor, element type and shape is same as data.
    """
    return _tensor_assgin_tensor(data, input_slice, value)


def _tensor_assgin_tensor(data, input_slice, value):
    """Given a tensor value assign to tensor by slice"""
    # 1. condition
    result = None
    check_result = mult_util.check_tensor_setitem_index(input_slice)
    if check_result:
        data_shape = F.shape(data)
        data_size = F.size(data)
        data_dtype = F.dtype(data)
        indices = mult_util.slice2indices(input_slice, data_shape)
        indices_size = F.size(indices)
        indices_size = mult_util.check_indices(indices_size, input_slice)
        update = F.fill(data_dtype, (indices_size,), 1)
        condition_1d = F.scatter_nd(indices, update, (data_size,))
        condition_1d = F.cast(condition_1d, mstype.bool_)
        condition = F.reshape(condition_1d, data_shape)
        # 2. u
        value_fill = None
        value_size = F.size(value)

        value_size = mult_util.check_indices_value_size(indices_size, value_size)
        if value_size == 1:
            value_fill = F.fill(data_dtype, (indices_size,), 1)
            value = F.cast(value, data_dtype)
            value_fill = F.tensor_mul(value_fill, value)
        elif value_size > 1:
            value_fill = F.reshape(value, (indices_size,))
        value_1d = F.scatter_nd(indices, value_fill, (data_size,))
        u = F.reshape(value_1d, data_shape)
        # A[slice]= u -> A[B]=U -> select(B, U, A)
        result = F.select(condition, u, data)
    return result
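
The helper above turns the slice into flat indices, scatters ones to build a boolean condition, scatters the update values, and then selects. A small NumPy sketch of that pipeline, assuming the slice covers the flat positions listed in flat_idx (all names and values below are illustrative):

import numpy as np

# NumPy illustration of the scatter_nd + select pipeline; not MindSpore code.
data = np.arange(6.).reshape(2, 3)     # A
flat_idx = np.array([0, 1, 2])         # flat positions covered by the slice, e.g. A[0:1]
value = np.array([10., 11., 12.])      # U, already matching the slice size

cond_1d = np.zeros(data.size, dtype=bool)
cond_1d[flat_idx] = True               # scatter ones, then cast to bool
condition = cond_1d.reshape(data.shape)

u_1d = np.zeros(data.size)
u_1d[flat_idx] = value                 # scatter the update values
u = u_1d.reshape(data.shape)

result = np.where(condition, u, data)  # select(B, U, A)
# result -> [[10., 11., 12.], [3., 4., 5.]]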


@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Slice] = u
        Restraint condition: A is a Tensor.
                             Slice like "1:3"
                             u is a scalar

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Slice): slice expression.
        value (Number): Assignment value.

    Outputs:
        Tensor, element type and shape is same as data.
    """
    return _tensor_assgin_number(data, input_slice, value)


@setitem.register("Tensor", "Tuple", "Number")
def _tensor_setitem_with_slice_v2(data, input_slice, value):
    """
    Tensor assignment.

    Note:
        Syntax support: A[Slice] = u
        Restraint condition: A is a Tensor.
                             Slice like "1:3, ::, :4:-1"
                             u is a scalar

    Inputs:
        data (Tensor): Assigned tensor.
        input_slice (Tuple(Slice)): slice expression.
        value (Number): Assignment value.

    Outputs:
        Tensor, element type and shape is same as data.
    """
    return _tensor_assgin_number(data, input_slice, value)


def _tensor_assgin_number(data, input_slice, value):
    """Given a scalar assign to tensor by slice"""
    # 1. condition
    check_result = mult_util.check_tensor_setitem_index(input_slice)
    result = None
    if check_result:
        data_shape = F.shape(data)
        data_size = F.size(data)
        data_dtype = F.dtype(data)
        indices = mult_util.slice2indices(input_slice, data_shape)
        indices_size = F.size(indices)
        indices_size = mult_util.check_indices(indices_size, input_slice)
        update = F.fill(data_dtype, (indices_size,), 1)
        condition_1d = F.scatter_nd(indices, update, (data_size,))
        condition_1d = F.cast(condition_1d, mstype.bool_)
        condition = F.reshape(condition_1d, data_shape)
        # 2. u
        value_fill = F.fill(data_dtype, (indices_size,), value)
        value_1d = F.scatter_nd(indices, value_fill, (data_size,))
        u = F.reshape(value_1d, data_shape)
        # A[slice]= u -> A[B]=U -> select(B, U, A)
        result = F.select(condition, u, data)
    return result
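
The scalar variant follows the same scatter-and-select pattern, with the update vector filled from the scalar. For example, A[1:3] = 5.0 roughly corresponds to the following NumPy sketch (illustrative only):

import numpy as np

# NumPy illustration of the scalar slice-assignment path; not MindSpore code.
data = np.arange(5.)
flat_idx = np.array([1, 2])                  # flat positions covered by the 1:3 slice
value_fill = np.full(flat_idx.size, 5.0)     # stands in for F.fill(data_dtype, (indices_size,), value)

cond = np.zeros(data.size, dtype=bool)
cond[flat_idx] = True
u = np.zeros_like(data)
u[flat_idx] = value_fill
result = np.where(cond, u, data)             # select(B, U, A)
# result -> [0., 5., 5., 3., 4.]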