From f9bd460c96ff93dcb95f01fdaa1a9bcc05fe9838 Mon Sep 17 00:00:00 2001
From: jiangjinsheng
Date: Wed, 13 May 2020 09:37:17 +0800
Subject: [PATCH] support vm for pack

---
 mindspore/ops/_op_impl/tbe/__init__.py |  2 +
 mindspore/ops/_op_impl/tbe/pack.py     | 57 ++++++++++++++++++++++++++
 mindspore/ops/_op_impl/tbe/unpack.py   | 56 +++++++++++++++++++++++++
 mindspore/ops/op_info_register.py      | 12 ++++++
 tests/ut/python/ops/test_array_ops.py  | 23 +++++++++++
 5 files changed, 150 insertions(+)
 create mode 100644 mindspore/ops/_op_impl/tbe/pack.py
 create mode 100644 mindspore/ops/_op_impl/tbe/unpack.py

diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index c8aa30f2c2..846a9c97c0 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -181,3 +181,5 @@ from .sgd import sgd_op_info
 from .lars_update import lars_update_op_info
 from .bn_training_update_v2 import _bn_training_update_v2_tbe
 from .square_sum_all import square_sum_all_op_info
+from .pack import _pack_tbe
+from .unpack import _unpack_tbe
diff --git a/mindspore/ops/_op_impl/tbe/pack.py b/mindspore/ops/_op_impl/tbe/pack.py
new file mode 100644
index 0000000000..fa1b1a2644
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/pack.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Pack op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+pack_op_info = TBERegOp("Pack") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("pack.so") \
+    .compute_cost(10) \
+    .kernel_name("pack") \
+    .partial_flag(True) \
+    .attr("axis", "optional", "int", "all") \
+    .input(0, "x", False, "dynamic", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I8_NDHWC, DataType.I8_NDHWC) \
+    .dtype_format(DataType.I16_NDHWC, DataType.I16_NDHWC) \
+    .dtype_format(DataType.I32_NDHWC, DataType.I32_NDHWC) \
+    .dtype_format(DataType.I64_NDHWC, DataType.I64_NDHWC) \
+    .dtype_format(DataType.U8_NDHWC, DataType.U8_NDHWC) \
+    .dtype_format(DataType.U16_NDHWC, DataType.U16_NDHWC) \
+    .dtype_format(DataType.U32_NDHWC, DataType.U32_NDHWC) \
+    .dtype_format(DataType.U64_NDHWC, DataType.U64_NDHWC) \
+    .dtype_format(DataType.F16_NDHWC, DataType.F16_NDHWC) \
+    .dtype_format(DataType.F32_NDHWC, DataType.F32_NDHWC) \
+    .dtype_format(DataType.BOOL_NDHWC, DataType.BOOL_NDHWC) \
+    .get_op_info()
+
+
+@op_info_register(pack_op_info)
+def _pack_tbe():
+    """Pack TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/unpack.py b/mindspore/ops/_op_impl/tbe/unpack.py
new file mode 100644
index 0000000000..314f81afa5
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/unpack.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Unpack op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+unpack_op_info = TBERegOp("Unpack") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("unpack.so") \
+    .compute_cost(10) \
+    .kernel_name("unpack") \
+    .partial_flag(True) \
+    .attr("num", "optional", "int", "all") \
+    .attr("axis", "required", "int", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "dynamic", "all") \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
+    .dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
+    .dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
+    .dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
+    .dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(unpack_op_info)
+def _unpack_tbe():
+    """Unpack TBE register"""
+    return
diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py
index 1411ccc6f1..3ca616f2dc 100644
--- a/mindspore/ops/op_info_register.py
+++ b/mindspore/ops/op_info_register.py
@@ -499,6 +499,7 @@ class DataType:
     BOOL_NCHW = ("bool", "NCHW")
     BOOL_NHWC = ("bool", "NHWC")
     BOOL_HWCN = ("bool", "HWCN")
+    BOOL_NDHWC = ("bool", "NDHWC")
 
     I8_None = ("int8", "")
     I8_Default = ("int8", "DefaultFormat")
@@ -509,6 +510,7 @@ class DataType:
     I8_NCHW = ("int8", "NCHW")
     I8_NHWC = ("int8", "NHWC")
     I8_HWCN = ("int8", "HWCN")
+    I8_NDHWC = ("int8", "NDHWC")
 
     U8_None = ("uint8", "")
     U8_Default = ("uint8", "DefaultFormat")
@@ -519,6 +521,7 @@ class DataType:
     U8_NCHW = ("uint8", "NCHW")
     U8_NHWC = ("uint8", "NHWC")
     U8_HWCN = ("uint8", "HWCN")
+    U8_NDHWC = ("uint8", "NDHWC")
 
     I16_None = ("int16", "")
     I16_Default = ("int16", "DefaultFormat")
@@ -529,6 +532,7 @@ class DataType:
     I16_NCHW = ("int16", "NCHW")
     I16_NHWC = ("int16", "NHWC")
     I16_HWCN = ("int16", "HWCN")
+    I16_NDHWC = ("int16", "NDHWC")
 
     U16_None = ("uint16", "")
     U16_Default = ("uint16", "DefaultFormat")
@@ -539,6 +543,7 @@ class DataType:
     U16_NCHW = ("uint16", "NCHW")
     U16_NHWC = ("uint16", "NHWC")
     U16_HWCN = ("uint16", "HWCN")
+    U16_NDHWC = ("uint16", "NDHWC")
 
     I32_None = ("int32", "")
     I32_Default = ("int32", "DefaultFormat")
@@ -549,6 +554,7 @@ class DataType:
     I32_NCHW = ("int32", "NCHW")
     I32_NHWC = ("int32", "NHWC")
     I32_HWCN = ("int32", "HWCN")
+    I32_NDHWC = ("int32", "NDHWC")
 
     U32_None = ("uint32", "")
     U32_Default = ("uint32", "DefaultFormat")
@@ -559,6 +565,7 @@ class DataType:
     U32_NCHW = ("uint32", "NCHW")
     U32_NHWC = ("uint32", "NHWC")
     U32_HWCN = ("uint32", "HWCN")
+    U32_NDHWC = ("uint32", "NDHWC")
 
     I64_None = ("int64", "")
     I64_Default = ("int64", "DefaultFormat")
@@ -569,6 +576,7 @@ class DataType:
     I64_NCHW = ("int64", "NCHW")
     I64_NHWC = ("int64", "NHWC")
     I64_HWCN = ("int64", "HWCN")
+    I64_NDHWC = ("int64", "NDHWC")
 
     U64_None = ("uint64", "")
     U64_Default = ("uint64", "DefaultFormat")
@@ -579,6 +587,7 @@ class DataType:
     U64_NCHW = ("uint64", "NCHW")
     U64_NHWC = ("uint64", "NHWC")
     U64_HWCN = ("uint64", "HWCN")
+    U64_NDHWC = ("uint64", "NDHWC")
 
     F16_None = ("float16", "")
     F16_Default = ("float16", "DefaultFormat")
@@ -589,6 +598,7 @@ class DataType:
     F16_NCHW = ("float16", "NCHW")
     F16_NHWC = ("float16", "NHWC")
     F16_HWCN = ("float16", "HWCN")
+    F16_NDHWC = ("float16", "NDHWC")
 
     F32_None = ("float32", "")
     F32_Default = ("float32", "DefaultFormat")
@@ -599,6 +609,7 @@ class DataType:
     F32_NCHW = ("float32", "NCHW")
     F32_NHWC = ("float32", "NHWC")
     F32_HWCN = ("float32", "HWCN")
+    F32_NDHWC = ("float32", "NDHWC")
 
     F64_None = ("float64", "")
     F64_Default = ("float64", "DefaultFormat")
@@ -609,3 +620,4 @@ class DataType:
     F64_NCHW = ("float64", "NCHW")
     F64_NHWC = ("float64", "NHWC")
     F64_HWCN = ("float64", "HWCN")
+    F64_NDHWC = ("float64", "NDHWC")
diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py
index 083d5b1c31..ebb3db3b26 100644
--- a/tests/ut/python/ops/test_array_ops.py
+++ b/tests/ut/python/ops/test_array_ops.py
@@ -227,6 +227,23 @@ class SpaceToBatchNet(Cell):
         return self.space_to_batch(x)
 
 
+class PackNet(Cell):
+    def __init__(self):
+        super(PackNet, self).__init__()
+        self.pack = P.Pack()
+
+    def construct(self, x):
+        return self.pack((x, x))
+
+
+class UnpackNet(Cell):
+    def __init__(self):
+        super(UnpackNet, self).__init__()
+        self.unpack = P.Unpack()
+
+    def construct(self, x):
+        return self.unpack(x)
+
 test_case_array_ops = [
     ('CustNet1', {
         'block': CustNet1(),
@@ -249,6 +266,12 @@ test_case_array_ops = [
     ('SpaceToBatchNet', {
         'block': SpaceToBatchNet(),
         'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
+    ('PackNet', {
+        'block': PackNet(),
+        'desc_inputs': [Tensor(np.array([[[1, 2], [3, 4]]]).astype(np.float16))]}),
+    ('UnpackNet', {
+        'block': UnpackNet(),
+        'desc_inputs': [Tensor(np.array([[1, 2], [3, 4]]).astype(np.float16))]}),
 ]
 
 test_case_lists = [test_case_array_ops]