From 30b93ecbf89b057182f863a309ea31a91ace4abf Mon Sep 17 00:00:00 2001
From: zhaozhenlong <zhaozhenlong@huawei.com>
Date: Mon, 25 May 2020 11:52:38 +0800
Subject: [PATCH] use reshape as flatten grad

---
 mindspore/ccsrc/pre_activate/common/helper.cc |  3 +-
 mindspore/ops/_op_impl/tbe/__init__.py        |  1 +
 mindspore/ops/_op_impl/tbe/flatten_grad.py    | 34 +++++++++++++++++++
 tests/ut/python/ops/test_ops.py               | 17 +++++++++-
 4 files changed, 53 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ops/_op_impl/tbe/flatten_grad.py

diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc
index 649e2746b5..4cda390fbb 100644
--- a/mindspore/ccsrc/pre_activate/common/helper.cc
+++ b/mindspore/ccsrc/pre_activate/common/helper.cc
@@ -385,7 +385,8 @@ bool IsNopNode(const AnfNodePtr &node) {
     return false;
   }
   static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
-                                                      prim::kPrimSqueeze->name(), prim::kPrimFlatten->name()};
+                                                      prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
+                                                      kFlattenGradOpName};
   if (node == nullptr || !node->isa<CNode>()) {
     return false;
   }
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 93392adcb0..ccbd301589 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -197,3 +197,4 @@ from .cum_sum import _cum_sum_tbe
 from .apply_rms_prop import _apply_rms_prop_tbe
 from .cumprod import _cumprop_tbe
 from .reduce_prod import _reduce_prod_tbe
+from .flatten_grad import _flatten_grad_tbe
diff --git a/mindspore/ops/_op_impl/tbe/flatten_grad.py b/mindspore/ops/_op_impl/tbe/flatten_grad.py
new file mode 100644
index 0000000000..43046bb619
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/flatten_grad.py
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Reshape op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+flatten_grad_op_info = TBERegOp("FlattenGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("reshape.so") \
+    .compute_cost(10) \
+    .kernel_name("reshape") \
+    .partial_flag(True) \
+    .attr("shape", "required", "listInt", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+@op_info_register(flatten_grad_op_info)
+def _flatten_grad_tbe():
+    """Reshape TBE register"""
+    return
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 7b01809333..d702598036 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -121,6 +121,16 @@ class NetForFlatten0D(nn.Cell):
         return self.flatten(x)
 
 
+class NetForFlattenComposed(nn.Cell):
+    # make flatten op together with other ops for testing flatten grad
+    def __init__(self):
+        super(NetForFlattenComposed, self).__init__()
+        self.flatten = P.Flatten()
+
+    def construct(self, x, y):
+        return self.flatten(x+x) + y
+
+
 class ArgmaxNet(nn.Cell):
     def __init__(self):
         super(ArgmaxNet, self).__init__()
@@ -695,7 +705,7 @@ test_case_nn_ops = [
     ('Flatten', {
         'block': P.Flatten(),
         'desc_inputs': [[128, 32, 32, 64]],
-        'desc_bprop': [[128 * 32 * 8 * 16]]}),
+        'desc_bprop': [[128, 65536]]}),
     ('LogSoftmax', {
         'block': P.LogSoftmax(),
         'desc_inputs': [[64, 2]],
@@ -893,6 +903,11 @@ test_case_nn_ops = [
         'desc_inputs': [Tensor(np.ones([8]).astype(np.int32)), Tensor(np.ones([8, 3]).astype(np.int32))],
         'desc_bprop': [Tensor(np.ones([8, 3]).astype(np.int32))],
         'skip': ['backward']}),
+    ('Flatten_3', {
+        'block': NetForFlattenComposed(),
+        'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],
+        'desc_bprop': [Tensor(np.ones([2, 12]).astype(np.int32))],
+        'skip': []}),
    ('ArgmaxNet', {
        'block': ArgmaxNet(),
        'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],