!12248 Fix the bug of calling recompute api after compiled

From: @ginfung
Reviewed-by: 
Signed-off-by:
pull/12248/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 166f7e7809

@ -944,11 +944,8 @@ class Cell(Cell_):
"""Sets the name on the first time."""
if self._scope is None:
self._scope = name
elif self._scope == 'recompute':
if name is None:
self._scope = None
elif name != 'recompute':
self._scope = self._scope + '_' + name
elif self._scope == 'recompute_':
self._scope = self._scope + name
def _children_scope_recursive(self, parent_prefix='Default'):
"""Generates the scope of each layer of the network recursively."""
@ -1129,6 +1126,16 @@ class Cell(Cell_):
param.comm_fusion = fusion_type
return self
def _set_recompute_scope(self, mode):
prefix = 'recompute_'
if mode is True:
if self._scope is None:
self._scope = prefix
elif not self._scope.startswith(prefix):
self._scope = prefix + self._scope
elif not self._scope is None and self._scope.startswith(prefix):
self._scope = self._scope[len(prefix):]
def recompute(self, mode=True):
"""
Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive feeds into a grad
@ -1137,10 +1144,7 @@ class Cell(Cell_):
mode (bool): Specifies whether the cell is recomputed. Default: True.
"""
Validator.check_bool(mode)
if mode is True:
self._set_scope("recompute")
else:
self._set_scope(None)
self._set_recompute_scope(mode)
for cell in self.cells():
cell.recompute(mode)

@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
context.set_context(mode=context.GRAPH_MODE)
recompute_prefix = 'recompute_'
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def construct(self, input_x):
output = self.pool(input_x)
return output
def test_set_recompute_true():
net = Net()
net.pool.recompute()
assert net.pool.get_scope() == recompute_prefix
def test_set_recompute_false():
net = Net()
net.pool.recompute(False)
assert net.pool.get_scope() is None
def test_set_recompute_true_twice():
net = Net()
net.pool.recompute()
net.pool.recompute()
assert net.pool.get_scope() == recompute_prefix
def test_set_recompute_false_twice():
net = Net()
net.pool.recompute(False)
net.pool.recompute(False)
assert net.pool.get_scope() is None
def test_reset_recompute1():
net = Net()
net.pool.recompute(True)
net.pool.recompute(False)
assert net.pool.get_scope() == ""
def test_reset_recompute2():
net = Net()
net.pool.recompute(False)
net.pool.recompute(True)
assert net.pool.get_scope() == recompute_prefix
def test_set_scope_and_set_recompute_repeatedly():
net = Net()
net.pool.recompute(True)
assert net.pool.get_scope() == recompute_prefix
net.pool.recompute(False)
assert net.pool.get_scope() == ""
net.pool.recompute(True)
assert net.pool.get_scope() == recompute_prefix
net.pool.recompute(False)
assert net.pool.get_scope() == ""
Loading…
Cancel
Save