From 6a2c1df5e3b4723bf9db700f4e92fc7f0dd3df70 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Mon, 24 Aug 2020 09:57:02 +0800 Subject: [PATCH] sync aicpu ops to open from ms-incubator: UniqueWithPad, EditDistance, TransData --- mindspore/ops/_op_impl/aicpu/__init__.py | 3 + mindspore/ops/_op_impl/aicpu/edit_distance.py | 56 +++++++++++++++++++ mindspore/ops/_op_impl/aicpu/trans_data.py | 34 +++++++++++ .../ops/_op_impl/aicpu/unique_with_pad.py | 32 +++++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 35 ++++++++++++ tests/st/ops/ascend/test_edit_distance.py | 48 ++++++++++++++++ tests/st/ops/ascend/test_unique_with_pad.py | 44 +++++++++++++++ 8 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/aicpu/edit_distance.py create mode 100644 mindspore/ops/_op_impl/aicpu/trans_data.py create mode 100644 mindspore/ops/_op_impl/aicpu/unique_with_pad.py create mode 100644 tests/st/ops/ascend/test_edit_distance.py create mode 100644 tests/st/ops/ascend/test_unique_with_pad.py diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 47b04d558c..54bca3da9f 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -19,6 +19,8 @@ from .embedding_lookup import _embedding_lookup_aicpu from .padding import _padding_aicpu from .gather import _gather_aicpu from .identity import _identity_aicpu +from .edit_distance import _edit_distance_aicpu +from .unique_with_pad import _unique_with_pad_aicpu from .dropout_genmask import _dropout_genmask_aicpu from .get_next import _get_next_aicpu from .print_tensor import _print_aicpu @@ -56,3 +58,4 @@ from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu from .fused_sparse_proximal_adagrad import _fused_sparse_proximal_adagrad_aicpu from .meshgrid import _meshgrid_aicpu +from .trans_data import _trans_data_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/edit_distance.py b/mindspore/ops/_op_impl/aicpu/edit_distance.py new file mode 100644 index 0000000000..a37cd1be3d --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/edit_distance.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""EditDistance op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +edit_distance_op_info = AiCPURegOp("EditDistance") \ + .fusion_type("OPAQUE") \ + .input(0, "hypothesis_indices", "required") \ + .input(1, "hypothesis_values", "required") \ + .input(2, "hypothesis_shape", "required") \ + .input(3, "truth_indices", "required") \ + .input(4, "truth_values", "required") \ + .input(5, "truth_shape", "required") \ + .output(0, "y", "required") \ + .attr("normalize", "bool") \ + .dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,) \ + .dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, \ + DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F32_Default,) \ + .get_op_info() + +@op_info_register(edit_distance_op_info) +def _edit_distance_aicpu(): + """EditDistance AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/trans_data.py b/mindspore/ops/_op_impl/aicpu/trans_data.py new file mode 100644 index 0000000000..8f3fe0503d --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/trans_data.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TransData op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +trans_data_op_info = AiCPURegOp("TransData") \ + .fusion_type("OPAQUE") \ + .input(0, "src", "required") \ + .output(0, "dst", "required") \ + .attr("src_format", "str") \ + .attr("dst_format", "str") \ + .dtype_format(DataType.U16_NCHW, DataType.U16_5HD) \ + .dtype_format(DataType.U16_5HD, DataType.U16_NCHW) \ + .dtype_format(DataType.U16_Default, DataType.U16_5HD) \ + .dtype_format(DataType.U16_5HD, DataType.U16_Default) \ + .get_op_info() + +@op_info_register(trans_data_op_info) +def _trans_data_aicpu(): + """TransData aicpu register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/unique_with_pad.py b/mindspore/ops/_op_impl/aicpu/unique_with_pad.py new file mode 100644 index 0000000000..3828ca1764 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/unique_with_pad.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UniqueWithPad op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +unique_with_pad_op_info = AiCPURegOp("UniqueWithPad") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "pad_num", "required") \ + .output(0, "y", "required") \ + .output(1, "idx", "required") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .get_op_info() + +@op_info_register(unique_with_pad_op_info) +def _unique_with_pad_aicpu(): + """UniqueWithPad AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index a504f53e3e..789906fa81 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, - Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, + Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, UniqueWithPad, ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, @@ -156,6 +156,7 @@ __all__ = [ 'Padding', 'GatherD', 'Identity', + 'UniqueWithPad', 'Concat', 'Pack', 'Unpack', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index dbc8f9b8a9..956d4a669e 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -747,6 +747,41 @@ class Padding(PrimitiveWithInfer): return out +class UniqueWithPad(PrimitiveWithInfer): + """ + Return unique elements and relative indexes in 1-D tensor, fill with pad num. + + Inputs: + - **x** (Tensor) - The tensor need to be unique. Must be 1-D vector with types: int32, int64. + - **pad_num** (int) - Pad num. + + Outputs: + tuple(Tensor), tuple of 2 tensors, y and idx. + - y (Tensor) - The unique elements filled with pad_num, the shape and type same as x. + - idx (Tensor) - The index of each value of x in the unique output y, the shape and type same as x. + + Examples: + >>> x = Tensor(np.array([1, 1, 5, 5, 4, 4, 3, 3, 2, 2,]), mindspore.int32) + >>> pad_num = 8 + >>> out = P.UniqueWithPad()(x, pad_num) + ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) + """ + @prim_attr_register + def __init__(self): + """init UniqueWithPad""" + + def __infer__(self, x, pad_num): + validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name) + validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) + x_shape = list(x['shape']) + validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) + out_shape = x_shape + out = {'shape': (out_shape, out_shape), + 'dtype': (x['dtype'], x['dtype']), + 'value': None} + return out + + class Split(PrimitiveWithInfer): """ Splits input tensor into output_num of tensors along the given axis and output numbers. diff --git a/tests/st/ops/ascend/test_edit_distance.py b/tests/st/ops/ascend/test_edit_distance.py new file mode 100644 index 0000000000..651ebfce31 --- /dev/null +++ b/tests/st/ops/ascend/test_edit_distance.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend") + + +class EditDistance(nn.Cell): + def __init__(self, hypothesis_shape, truth_shape, normalize=True): + super(EditDistance, self).__init__() + self.edit_distance = P.EditDistance(normalize) + self.hypothesis_shape = hypothesis_shape + self.truth_shape = truth_shape + + def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values): + return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape, + truth_indices, truth_values, self.truth_shape) + +def test_edit_distance(): + h1, h2, h3 = np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]), np.array([1, 2, 3]), np.array([2, 2, 2]) + t1, t2, t3 = np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]), np.array([1, 2, 3, 1]), np.array([2, 2, 2]) + hypothesis_indices = Tensor(h1.astype(np.int64)) + hypothesis_values = Tensor(h2.astype(np.int64)) + hypothesis_shape = Tensor(h3.astype(np.int64)) + truth_indices = Tensor(t1.astype(np.int64)) + truth_values = Tensor(t2.astype(np.int64)) + truth_shape = Tensor(t3.astype(np.int64)) + edit_distance = EditDistance(hypothesis_shape, truth_shape) + out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values) + print(out) diff --git a/tests/st/ops/ascend/test_unique_with_pad.py b/tests/st/ops/ascend/test_unique_with_pad.py new file mode 100644 index 0000000000..e42739d572 --- /dev/null +++ b/tests/st/ops/ascend/test_unique_with_pad.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, pad_num): + super(Net, self).__init__() + self.unique_with_pad = P.UniqueWithPad() + self.pad_num = pad_num + + def construct(self, x): + return self.unique_with_pad(x, self.pad_num) + + +def test_unique_with_pad(): + x = Tensor(np.array([1, 1, 5, 5, 4, 4, 3, 3, 2, 2]), mstype.int32) + pad_num = 8 + unique_with_pad = Net(pad_num) + out = unique_with_pad(x) + expect_val = ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) + assert(out[0].asnumpy() == expect_val[0]).all() + assert(out[1].asnumpy() == expect_val[1]).all()