!9429 update supported platforms for op EmbeddingLookup, ELU and fix bugs of op SparseGatherV2, ArgMaxWithValue

From: @lihongkang1
Reviewed-by: @liangchenghui,@youui
Signed-off-by: @liangchenghui
pull/9429/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9605102dd1

@ -151,6 +151,9 @@ class ELU(Cell):
Outputs:
Tensor, with the same type and shape as the `input_data`.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
>>> elu = nn.ELU()

@ -167,7 +167,7 @@ class EmbeddingLookup(Cell):
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``CPU``
Examples:
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)

@ -205,7 +205,7 @@ class TrainOneStepCell(Cell):
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>>class MyWithLossCell(nn.Cell):
>>> class MyWithLossCell(nn.cell):
>>> def __init__(self, backbone, loss_fn):
>>> super(MyWithLossCell, self).__init__(auto_prefix=False)
>>> self._backbone = backbone

@ -27,36 +27,26 @@ sparse_gather_v2_op_info = TBERegOp("SparseGatherV2") \
.input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \
.dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \
.dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
.get_op_info()

@ -833,6 +833,10 @@ class SparseGatherV2(GatherV2):
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> out = ops.SparseGatherV2()(input_params, input_indices, axis)
>>> print(out)
[[2. 7.]
[4. 54.]
[2. 55.]]
"""
@ -1642,6 +1646,8 @@ class ArgMaxWithValue(PrimitiveWithInfer):
Examples:
>>> input_x = Tensor(np.random.rand(5), mindspore.float32)
>>> index, output = ops.ArgMaxWithValue()(input_x)
>>> print(index, output)
2 0.87173676
"""
@prim_attr_register

Loading…
Cancel
Save