From 70cc548e32f713fa6710c5f5baf1b0de9e3f0b5e Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Sun, 7 Feb 2021 20:12:59 +0800 Subject: [PATCH] Fix calling recompute api after compiling --- mindspore/nn/cell.py | 22 +++--- tests/ut/python/optimizer/test_recompute.py | 82 +++++++++++++++++++++ 2 files changed, 95 insertions(+), 9 deletions(-) create mode 100644 tests/ut/python/optimizer/test_recompute.py diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index ba37f5bcdc..deab42dd89 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -944,11 +944,8 @@ class Cell(Cell_): """Sets the name on the first time.""" if self._scope is None: self._scope = name - elif self._scope == 'recompute': - if name is None: - self._scope = None - elif name != 'recompute': - self._scope = self._scope + '_' + name + elif self._scope == 'recompute_': + self._scope = self._scope + name def _children_scope_recursive(self, parent_prefix='Default'): """Generates the scope of each layer of the network recursively.""" @@ -1129,6 +1126,16 @@ class Cell(Cell_): param.comm_fusion = fusion_type return self + def _set_recompute_scope(self, mode): + prefix = 'recompute_' + if mode is True: + if self._scope is None: + self._scope = prefix + elif not self._scope.startswith(prefix): + self._scope = prefix + self._scope + elif not self._scope is None and self._scope.startswith(prefix): + self._scope = self._scope[len(prefix):] + def recompute(self, mode=True): """ Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad @@ -1137,10 +1144,7 @@ class Cell(Cell_): mode (bool): Specifies whether the cell is recomputed. Default: True. """ Validator.check_bool(mode) - if mode is True: - self._set_scope("recompute") - else: - self._set_scope(None) + self._set_recompute_scope(mode) for cell in self.cells(): cell.recompute(mode) diff --git a/tests/ut/python/optimizer/test_recompute.py b/tests/ut/python/optimizer/test_recompute.py new file mode 100644 index 0000000000..28bbb38de8 --- /dev/null +++ b/tests/ut/python/optimizer/test_recompute.py @@ -0,0 +1,82 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore.context as context +import mindspore.nn as nn + +context.set_context(mode=context.GRAPH_MODE) +recompute_prefix = 'recompute_' + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def construct(self, input_x): + output = self.pool(input_x) + return output + + +def test_set_recompute_true(): + net = Net() + net.pool.recompute() + assert net.pool.get_scope() == recompute_prefix + + +def test_set_recompute_false(): + net = Net() + net.pool.recompute(False) + assert net.pool.get_scope() is None + + +def test_set_recompute_true_twice(): + net = Net() + net.pool.recompute() + net.pool.recompute() + assert net.pool.get_scope() == recompute_prefix + + +def test_set_recompute_false_twice(): + net = Net() + net.pool.recompute(False) + net.pool.recompute(False) + assert net.pool.get_scope() is None + + +def test_reset_recompute1(): + net = Net() + net.pool.recompute(True) + net.pool.recompute(False) + assert net.pool.get_scope() == "" + + +def test_reset_recompute2(): + net = Net() + net.pool.recompute(False) + net.pool.recompute(True) + assert net.pool.get_scope() == recompute_prefix + + +def test_set_scope_and_set_recompute_repeatedly(): + net = Net() + net.pool.recompute(True) + assert net.pool.get_scope() == recompute_prefix + net.pool.recompute(False) + assert net.pool.get_scope() == "" + net.pool.recompute(True) + assert net.pool.get_scope() == recompute_prefix + net.pool.recompute(False) + assert net.pool.get_scope() == ""