!11645 [Numpy-Native] Fix bug: np.column_stack raises RuntimeError in graph mode

From: @wangrao124
Reviewed-by: @liangchenghui,@zhunaipan
Signed-off-by: @liangchenghui
pull/11645/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d7f8743486

@ -497,8 +497,6 @@ def column_stack(tup):
return tup
if not _check_is_list(tup) and not _check_is_tuple(tup):
_raise_type_error("Tensor or, list or tuple of tensors are required, but got ", tup)
if not tup:
_raise_value_error("Need at least one tensor to concatenate.")
trans_tup = ()
for tensor in tup:
@ -507,7 +505,9 @@ def column_stack(tup):
if tensor.ndim == 1:
tensor = F.expand_dims(tensor, 1)
trans_tup += (tensor,)
return P.Concat(axis=1)(trans_tup)
if not trans_tup:
_raise_value_error("Need at least one tensor to concatenate.")
return P.Concat(1)(trans_tup)
def vstack(tup):
@ -545,15 +545,15 @@ def vstack(tup):
return tup
if not _check_is_list(tup) and not _check_is_tuple(tup):
_raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup)
if not tup:
_raise_value_error("Need at least one tensor to concatenate.")
trans_tup = ()
for tensor in tup:
if tensor.ndim <= 1:
tensor = _expand(tensor, 2, 0)
trans_tup += (tensor,)
return P.Concat(axis=0)(trans_tup)
if not trans_tup:
_raise_value_error("Need at least one tensor to concatenate.")
return P.Concat(0)(trans_tup)
def hstack(tup):
@ -590,19 +590,18 @@ def hstack(tup):
if _check_is_tensor(F.typeof(tup)):
return tup
if not _check_is_list(tup) and not _check_is_tuple(tup):
_raise_type_error(f"Tensor or, list or tuple of tensors are required, but got", tup)
if not tup:
_raise_value_error("Need at least one tensor to concatenate.")
_raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup)
tuple_of_tensor = ()
for tensor in tup:
if tensor.ndim < 1:
tensor = F.expand_dims(tensor, 0)
tuple_of_tensor += (tensor,)
if not tuple_of_tensor:
_raise_value_error("Need at least one tensor to concatenate.")
if tuple_of_tensor[0].ndim <= 1:
return P.Concat(axis=0)(tuple_of_tensor)
return P.Concat(axis=1)(tuple_of_tensor)
return P.Concat(0)(tuple_of_tensor)
return P.Concat(1)(tuple_of_tensor)
def dstack(tup):
@ -641,8 +640,6 @@ def dstack(tup):
return tup
if not _check_is_list(tup) and not _check_is_tuple(tup):
_raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup)
if not tup:
_raise_value_error("Need at least one tensor to concatenate.")
trans_tup = ()
for tensor in tup:
@ -651,7 +648,9 @@ def dstack(tup):
if tensor.ndim == 2:
tensor = F.expand_dims(tensor, 2)
trans_tup += (tensor,)
return P.Concat(axis=2)(trans_tup)
if not trans_tup:
_raise_value_error("Need at least one tensor to concatenate.")
return P.Concat(2)(trans_tup)
def where(condition, x=None, y=None):

Loading…
Cancel
Save