From 79a3f0821a3b168d809ba2b50e9e73a9e1e995f5 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Thu, 7 May 2020 10:49:30 +0800 Subject: [PATCH] add vm for batch_to_space and space_to_batch --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 2 ++ mindspore/ops/_op_impl/tbe/__init__.py | 2 ++ mindspore/ops/_op_impl/tbe/batch_to_space.py | 37 ++++++++++++++++++++ mindspore/ops/_op_impl/tbe/space_to_batch.py | 37 ++++++++++++++++++++ tests/ut/python/ops/test_array_ops.py | 16 +++++++++ 5 files changed, 94 insertions(+) create mode 100644 mindspore/ops/_op_impl/tbe/batch_to_space.py create mode 100644 mindspore/ops/_op_impl/tbe/space_to_batch.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 005c290aba..11d2185c32 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -75,6 +75,8 @@ static std::map tbe_func_adapter_map = { {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, {"pad", "pad_d"}, + {"space_to_batch", "space_to_batch_d"}, + {"batch_to_space", "batch_to_space_d"}, {"adam", "apply_adam_d"}}; void TbeAdapter::NormalizeFuncName(std::string *func_name) { diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 73afef73a1..f24afff71b 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -154,3 +154,5 @@ from .scatter_nd_update import _scatter_nd_update_tbe from .avg_pool import _avg_pool_tbe from .avg_pool_grad import _avg_pool_grad_tbe from .ones_like import _ones_like_tbe +from .batch_to_space import _batch_to_space_tbe +from .space_to_batch import _space_to_batch_tbe diff --git a/mindspore/ops/_op_impl/tbe/batch_to_space.py b/mindspore/ops/_op_impl/tbe/batch_to_space.py new file mode 100644 index 0000000000..5d0bdc1de3 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/batch_to_space.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BatchToSpace op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +batch_to_space_op_info = TBERegOp("BatchToSpace") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("batch_to_space_d.so") \ + .compute_cost(10) \ + .kernel_name("batch_to_space_d") \ + .partial_flag(True) \ + .attr("block_size", "required", "int", "all") \ + .attr("crops", "required", "listListInt", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(batch_to_space_op_info) +def _batch_to_space_tbe(): + """BatchToSpace TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/space_to_batch.py b/mindspore/ops/_op_impl/tbe/space_to_batch.py new file mode 100644 index 0000000000..d7c31edcbf --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/space_to_batch.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SpaceToBatch op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +space_to_batch_op_info = TBERegOp("SpaceToBatch") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("space_to_batch_d.so") \ + .compute_cost(10) \ + .kernel_name("space_to_batch_d") \ + .partial_flag(True) \ + .attr("block_size", "required", "int", "all") \ + .attr("paddings", "required", "listListInt", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(space_to_batch_op_info) +def _space_to_batch_tbe(): + """SpaceToBatch TBE register""" + return diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index 61b8d48fea..35f262699d 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -95,6 +95,22 @@ def test_select(): expect = np.array([[1, 8, 9], [10, 5, 6]]) assert np.all(output.asnumpy() == expect) +def test_batch_to_space(): + block_size = 2 + crops = [[0, 0], [0, 0]] + batch_to_space = P.BatchToSpace(block_size, crops) + input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]).astype(np.float16)) + output = batch_to_space(input_x) + assert output.shape() == (1, 1, 2, 2) + +def test_space_to_batch(): + block_size = 2 + paddings = [[0, 0], [0, 0]] + space_to_batch = P.SpaceToBatch(block_size, paddings) + input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16)) + output = space_to_batch(input_x) + assert output.shape() == (4, 1, 1, 1) + def test_argmin_invalid_output_type(): P.Argmin(-1, mstype.int64) P.Argmin(-1, mstype.int32)