From 6a2c1df5e3b4723bf9db700f4e92fc7f0dd3df70 Mon Sep 17 00:00:00 2001
From: yanzhenxiang2020 <yanzhenxiang@huawei.com>
Date: Mon, 24 Aug 2020 09:57:02 +0800
Subject: [PATCH] sync aicpu ops to open from ms-incubator: UniqueWithPad,
 EditDistance, TransData

---
 mindspore/ops/_op_impl/aicpu/__init__.py      |  3 +
 mindspore/ops/_op_impl/aicpu/edit_distance.py | 56 +++++++++++++++++++
 mindspore/ops/_op_impl/aicpu/trans_data.py    | 34 +++++++++++
 .../ops/_op_impl/aicpu/unique_with_pad.py     | 32 +++++++++++
 mindspore/ops/operations/__init__.py          |  3 +-
 mindspore/ops/operations/array_ops.py         | 35 ++++++++++++
 tests/st/ops/ascend/test_edit_distance.py     | 48 ++++++++++++++++
 tests/st/ops/ascend/test_unique_with_pad.py   | 44 +++++++++++++++
 8 files changed, 254 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ops/_op_impl/aicpu/edit_distance.py
 create mode 100644 mindspore/ops/_op_impl/aicpu/trans_data.py
 create mode 100644 mindspore/ops/_op_impl/aicpu/unique_with_pad.py
 create mode 100644 tests/st/ops/ascend/test_edit_distance.py
 create mode 100644 tests/st/ops/ascend/test_unique_with_pad.py

diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py
index 47b04d558c..54bca3da9f 100644
--- a/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -19,6 +19,8 @@ from .embedding_lookup import _embedding_lookup_aicpu
 from .padding import _padding_aicpu
 from .gather import _gather_aicpu
 from .identity import _identity_aicpu
+from .edit_distance import _edit_distance_aicpu
+from .unique_with_pad import _unique_with_pad_aicpu
 from .dropout_genmask import _dropout_genmask_aicpu
 from .get_next import _get_next_aicpu
 from .print_tensor import _print_aicpu
@@ -56,3 +58,4 @@ from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
 from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
 from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu
 from .meshgrid import _meshgrid_aicpu
+from .trans_data import _trans_data_aicpu
diff --git a/mindspore/ops/_op_impl/aicpu/edit_distance.py b/mindspore/ops/_op_impl/aicpu/edit_distance.py
new file mode 100644
index 0000000000..a37cd1be3d
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/edit_distance.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""EditDistance op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+edit_distance_op_info = AiCPURegOp("EditDistance") \
+    .fusion_type("OPAQUE") \
+    .input(0, "hypothesis_indices", "required") \
+    .input(1, "hypothesis_values", "required") \
+    .input(2, "hypothesis_shape", "required") \
+    .input(3, "truth_indices", "required") \
+    .input(4, "truth_values", "required") \
+    .input(5, "truth_shape", "required") \
+    .output(0, "y", "required") \
+    .attr("normalize", "bool") \
+    .dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, \
+    DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F32_Default,) \
+    .get_op_info()
+
+@op_info_register(edit_distance_op_info)
+def _edit_distance_aicpu():
+    """EditDistance AiCPU register"""
+    return
diff --git a/mindspore/ops/_op_impl/aicpu/trans_data.py b/mindspore/ops/_op_impl/aicpu/trans_data.py
new file mode 100644
index 0000000000..8f3fe0503d
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/trans_data.py
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""TransData op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+trans_data_op_info = AiCPURegOp("TransData") \
+    .fusion_type("OPAQUE") \
+    .input(0, "src", "required") \
+    .output(0, "dst", "required") \
+    .attr("src_format", "str") \
+    .attr("dst_format", "str") \
+    .dtype_format(DataType.U16_NCHW, DataType.U16_5HD) \
+    .dtype_format(DataType.U16_5HD, DataType.U16_NCHW) \
+    .dtype_format(DataType.U16_Default, DataType.U16_5HD) \
+    .dtype_format(DataType.U16_5HD, DataType.U16_Default) \
+    .get_op_info()
+
+@op_info_register(trans_data_op_info)
+def _trans_data_aicpu():
+    """TransData aicpu register"""
+    return
diff --git a/mindspore/ops/_op_impl/aicpu/unique_with_pad.py b/mindspore/ops/_op_impl/aicpu/unique_with_pad.py
new file mode 100644
index 0000000000..3828ca1764
--- /dev/null
+++ b/mindspore/ops/_op_impl/aicpu/unique_with_pad.py
@@ -0,0 +1,32 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""UniqueWithPad op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+unique_with_pad_op_info = AiCPURegOp("UniqueWithPad") \
+    .fusion_type("OPAQUE") \
+    .input(0, "x", "required") \
+    .input(1, "pad_num", "required") \
+    .output(0, "y", "required") \
+    .output(1, "idx", "required") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+    .get_op_info()
+
+@op_info_register(unique_with_pad_op_info)
+def _unique_with_pad_aicpu():
+    """UniqueWithPad AiCPU register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index a504f53e3e..789906fa81 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                        Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
+                        Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, UniqueWithPad,
                         ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
@@ -156,6 +156,7 @@ __all__ = [
     'Padding',
     'GatherD',
     'Identity',
+    'UniqueWithPad',
     'Concat',
     'Pack',
     'Unpack',
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index dbc8f9b8a9..956d4a669e 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -747,6 +747,41 @@ class Padding(PrimitiveWithInfer):
         return out
 
 
+class UniqueWithPad(PrimitiveWithInfer):
+    """
+    Return unique elements and relative indexes in 1-D tensor, fill with pad num.
+
+    Inputs:
+        - **x** (Tensor) - The tensor need to be unique. Must be 1-D vector with types: int32, int64.
+        - **pad_num** (int) - Pad num.
+
+    Outputs:
+        tuple(Tensor), tuple of 2 tensors, y and idx.
+        - y (Tensor) - The unique elements filled with pad_num, the shape and type same as x.
+        - idx (Tensor) - The index of each value of x in the unique output y, the shape and type same as x.
+
+    Examples:
+        >>> x = Tensor(np.array([1, 1, 5, 5, 4, 4, 3, 3, 2, 2,]), mindspore.int32)
+        >>> pad_num = 8
+        >>> out = P.UniqueWithPad()(x, pad_num)
+        ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
+    """
+    @prim_attr_register
+    def __init__(self):
+        """init UniqueWithPad"""
+
+    def __infer__(self, x, pad_num):
+        validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name)
+        validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
+        x_shape = list(x['shape'])
+        validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
+        out_shape = x_shape
+        out = {'shape': (out_shape, out_shape),
+               'dtype': (x['dtype'], x['dtype']),
+               'value': None}
+        return out
+
+
 class Split(PrimitiveWithInfer):
     """
     Splits input tensor into output_num of tensors along the given axis and output numbers.
diff --git a/tests/st/ops/ascend/test_edit_distance.py b/tests/st/ops/ascend/test_edit_distance.py
new file mode 100644
index 0000000000..651ebfce31
--- /dev/null
+++ b/tests/st/ops/ascend/test_edit_distance.py
@@ -0,0 +1,48 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE,
+                    device_target="Ascend")
+
+
+class EditDistance(nn.Cell):
+    def __init__(self, hypothesis_shape, truth_shape, normalize=True):
+        super(EditDistance, self).__init__()
+        self.edit_distance = P.EditDistance(normalize)
+        self.hypothesis_shape = hypothesis_shape
+        self.truth_shape = truth_shape
+
+    def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
+        return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
+                                  truth_indices, truth_values, self.truth_shape)
+
+def test_edit_distance():
+    h1, h2, h3 = np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]), np.array([1, 2, 3]), np.array([2, 2, 2])
+    t1, t2, t3 = np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]), np.array([1, 2, 3, 1]), np.array([2, 2, 2])
+    hypothesis_indices = Tensor(h1.astype(np.int64))
+    hypothesis_values = Tensor(h2.astype(np.int64))
+    hypothesis_shape = Tensor(h3.astype(np.int64))
+    truth_indices = Tensor(t1.astype(np.int64))
+    truth_values = Tensor(t2.astype(np.int64))
+    truth_shape = Tensor(t3.astype(np.int64))
+    edit_distance = EditDistance(hypothesis_shape, truth_shape)
+    out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
+    print(out)
diff --git a/tests/st/ops/ascend/test_unique_with_pad.py b/tests/st/ops/ascend/test_unique_with_pad.py
new file mode 100644
index 0000000000..e42739d572
--- /dev/null
+++ b/tests/st/ops/ascend/test_unique_with_pad.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE,
+                    device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, pad_num):
+        super(Net, self).__init__()
+        self.unique_with_pad = P.UniqueWithPad()
+        self.pad_num = pad_num
+
+    def construct(self, x):
+        return self.unique_with_pad(x, self.pad_num)
+
+
+def test_unique_with_pad():
+    x = Tensor(np.array([1, 1, 5, 5, 4, 4, 3, 3, 2, 2]), mstype.int32)
+    pad_num = 8
+    unique_with_pad = Net(pad_num)
+    out = unique_with_pad(x)
+    expect_val = ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
+    assert(out[0].asnumpy() == expect_val[0]).all()
+    assert(out[1].asnumpy() == expect_val[1]).all()