optimizer error info for not supported np.dtype and modify example of TopK

pull/8373/head
buxue 4 years ago
parent c214ed2d0f
commit 885c60d1c8

@ -69,8 +69,8 @@ class Tensor(Tensor_):
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is "
f"{input_data.dtype} that is not supported to initialize a Tensor.")
raise TypeError(f"For Tensor, the input_data is a numpy array whose value is {input_data} and "
f"data type is {input_data.dtype} that is not supported to initialize a Tensor.")
if isinstance(input_data, (tuple, list)):
if np.array(input_data).dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")

@ -1741,8 +1741,8 @@ class TopK(PrimitiveWithInfer):
>>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
>>> k = 3
>>> values, indices = topk(input_x, k)
>>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16)
>>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32)
>>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16).all()
>>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32).all()
"""
@prim_attr_register

Loading…
Cancel
Save