From 79333916ed0c204d3ef31396b5dcc4b82c615e13 Mon Sep 17 00:00:00 2001 From: lihongkang <[lihongkang1@huawei.com]> Date: Thu, 3 Dec 2020 15:22:14 +0800 Subject: [PATCH] fix bugs --- mindspore/nn/layer/activation.py | 3 ++ mindspore/nn/layer/embedding.py | 2 +- mindspore/nn/wrap/cell_wrapper.py | 2 +- .../ops/_op_impl/tbe/sparse_gather_v2.py | 40 +++++++------------ mindspore/ops/operations/array_ops.py | 6 +++ 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 97bf36db28..81b5b8ac59 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -151,6 +151,9 @@ class ELU(Cell): Outputs: Tensor, with the same type and shape as the `input_data`. + Supported Platforms: + ``Ascend`` ``GPU`` + Examples: >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32) >>> elu = nn.ELU() diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index ce777229ca..e89c756924 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -167,7 +167,7 @@ class EmbeddingLookup(Cell): Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` + ``Ascend`` ``CPU`` Examples: >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 87b4037e8d..c330bea00d 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -205,7 +205,7 @@ class TrainOneStepCell(Cell): >>> train_net = nn.TrainOneStepCell(loss_net, optim) >>> >>> #2) Using user-defined WithLossCell - >>>class MyWithLossCell(nn.Cell): + >>> class MyWithLossCell(nn.cell): >>> def __init__(self, backbone, loss_fn): >>> super(MyWithLossCell, self).__init__(auto_prefix=False) >>> self._backbone = backbone diff --git a/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py b/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py index b824836312..5f89c06b7f 100644 --- a/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +++ b/mindspore/ops/_op_impl/tbe/sparse_gather_v2.py @@ -27,36 +27,26 @@ sparse_gather_v2_op_info = TBERegOp("SparseGatherV2") \ .input(0, "x", False, "required", "all") \ .input(1, "indices", False, "required", "all") \ .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ - .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ - .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I8_5HD) \ - .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I8_5HD) \ - .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \ - .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I8_FracZ) \ .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ - .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ - .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.U8_5HD) \ - .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.U8_5HD) \ - .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \ - .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.U8_FracZ) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ - .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ - .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD) \ - .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ - .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ) \ - .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \ .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.F16_5HD) \ - .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \ - .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.F16_FracZ) \ - .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ - .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \ .get_op_info() diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 399972ddf4..de2b050db8 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -833,6 +833,10 @@ class SparseGatherV2(GatherV2): >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32) >>> axis = 1 >>> out = ops.SparseGatherV2()(input_params, input_indices, axis) + >>> print(out) + [[2. 7.] + [4. 54.] + [2. 55.]] """ @@ -1642,6 +1646,8 @@ class ArgMaxWithValue(PrimitiveWithInfer): Examples: >>> input_x = Tensor(np.random.rand(5), mindspore.float32) >>> index, output = ops.ArgMaxWithValue()(input_x) + >>> print(index, output) + 2 0.87173676 """ @prim_attr_register