expand dims of scalar Tensor index

pull/12716/head
yepei6 4 years ago
parent 692d158f5c
commit e7a3f68d29

@ -510,6 +510,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
if not F.shape(index):
index = F.expand_dims(index, 0)
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)

Loading…
Cancel
Save