diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
index 255b77af88..113da33f89 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc
@@ -49,11 +49,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
   if (vars_.count(var)) {
     AnfNodePtr node = vars_[var];
     MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<ValueNode>()) {
-      return NewValueNode(GetValueNode(node));
-    } else {
-      return node;
+    auto iter = resolve_to_removable_phis_.find(node);
+    if (iter != resolve_to_removable_phis_.end()) {
+      return iter->second;
     }
+    return node;
   }
   // get var from predecessor block ,if can't get the make a resolve node to it
   if (matured_) {
@@ -64,7 +64,13 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
       return block->ReadVariable(var);
     } else if (prev_blocks_.empty()) {
       // get namespace and make Reslove
-      return MakeResolveSymbol(var);
+      auto it = var_to_resolve_.find(var);
+      if (it != var_to_resolve_.end()) {
+        return it->second;
+      }
+      auto tmp_node = MakeResolveSymbol(var);
+      var_to_resolve_[var] = tmp_node;
+      return tmp_node;
     }
   }
   // If have more than one predecessor blocks then build a phi node.
@@ -217,6 +223,7 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
     // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
     WriteVariable(var, arg_node);
     removable_phis_[phi] = arg_node;
+    resolve_to_removable_phis_[arg_node] = phi;
     // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
     // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node.
     for (auto &prev : prev_blocks_) {
diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h
index 70476791d8..d7efba824b 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h
@@ -101,12 +101,21 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   // keeps all removable phis which will be removed in one pass.
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
 
+  // Maps a resolve node to the removable phi node it replaces.
+  // ReadVariable should return the phi node even though the phi node
+  // generated in the prev block has been identified as removable, so that
+  // the other blocks can still find this phi node.
+  std::unordered_map<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_;
+
   // hold declared global variables in function
   std::set<std::string> global_vars_;
 
   // other depend need to insert before function return nodes.
   // summary or some other node
   std::vector<AnfNodePtr> auto_depends_;
+
+  // Caches the newly made resolve symbol for a variable not found in vars_.
+  std::unordered_map<std::string, AnfNodePtr> var_to_resolve_;
 };
 }  // namespace parse
diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py
index 68f6b1f30d..140a572112 100644
--- a/tests/st/control/test_cont_grad.py
+++ b/tests/st/control/test_cont_grad.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test control ops """
 import numpy as np
+import pytest
 
 from mindspore import dtype as ms
 from mindspore import Tensor
@@ -1150,3 +1151,147 @@ def test_if_by_if_forward_all_const_branch():
     end = Tensor(np.array(3), dtype=ms.float32)
     x = Tensor(np.array(0), dtype=ms.float32)
     net(idx, end, x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_if_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            b = 2
+            if a > 0:
+                b = 1
+            a += b
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_if_by_if_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            b = 2
+            if a > 0:
+                b = 1
+            if a < 0:
+                b = 0
+            if a == 0:
+                b = 3
+            a += b
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_while_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            while a > 1:
+                a = a - 1
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_if_by_while_const_grad():
+    class MyNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.TensorAdd()
+
+        def construct(self, *inputs):
+            out = self.add(*inputs)
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.weights = ParameterTuple(net.trainable_params())
+
+        def construct(self, *inputs):
+            a = 1
+            b = 2
+            if a > 0:
+                b = 0
+            while a > 1:
+                a = a - 1
+            a += b
+            return grad_by_list(self.net, self.weights)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE)
+    my_net = MyNet()
+    net = GradNet(my_net)
+    a = Tensor(np.array(0), dtype=ms.int32)
+    b = Tensor(np.array(1), dtype=ms.int32)
+    net(a, b)
diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc
index df1996d3e5..4ab829dbd8 100644
--- a/tests/ut/cpp/optimizer/opt_test.cc
+++ b/tests/ut/cpp/optimizer/opt_test.cc
@@ -187,7 +187,7 @@ TEST_F(TestOptOpt, CSE) {
   FuncGraphManagerPtr manager1 = Manage(test_graph1);
   draw::Draw("opt_cse_before_1.dot", test_graph1);
 
-  ASSERT_EQ(manager1->all_nodes().size(), 10);
+  ASSERT_EQ(manager1->all_nodes().size(), 9);
 
   auto cse = std::make_shared<CSE>();
   ASSERT_TRUE(cse != nullptr);
@@ -205,7 +205,7 @@ TEST_F(TestOptOpt, CSE) {
   FuncGraphManagerPtr manager2 = Manage(test_graph2);
   draw::Draw("opt_cse_before_2.dot", test_graph2);
 
-  ASSERT_EQ(manager2->all_nodes().size(), 22);
+  ASSERT_EQ(manager2->all_nodes().size(), 16);
 
   is_changed = cse->Cse(test_graph2, manager2);
   ASSERT_TRUE(is_changed);
   ASSERT_EQ(manager2->all_nodes().size(), 12);
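
Note: the heart of this change is a lookup-before-create cache in ReadVariable. Below is a minimal standalone sketch of that pattern, not MindSpore code; Node, NodePtr, Block, and ReadUndefinedVariable are illustrative stand-ins for the real AnfNodePtr / FunctionBlock::ReadVariable API.

// Standalone sketch of the var_to_resolve_ caching pattern (placeholder types).
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
};
using NodePtr = std::shared_ptr<Node>;

class Block {
 public:
  // Mirrors the new branch in ReadVariable: reuse the cached resolve node for
  // a name instead of creating a fresh node on every read.
  NodePtr ReadUndefinedVariable(const std::string &var) {
    auto it = var_to_resolve_.find(var);
    if (it != var_to_resolve_.end()) {
      return it->second;
    }
    auto node = MakeResolveSymbol(var);
    var_to_resolve_[var] = node;
    return node;
  }

 private:
  // Stand-in for the real MakeResolveSymbol, which builds a resolve node.
  NodePtr MakeResolveSymbol(const std::string &var) { return std::make_shared<Node>("resolve::" + var); }

  std::unordered_map<std::string, NodePtr> var_to_resolve_;
};

int main() {
  Block block;
  auto first = block.ReadUndefinedVariable("grad_by_list");
  auto second = block.ReadUndefinedVariable("grad_by_list");
  // Both reads share one node; without the cache, each read would produce a
  // new resolve node and the graph would grow accordingly.
  std::cout << std::boolalpha << (first == second) << std::endl;  // true
  return 0;
}

Because repeated reads of the same not-yet-defined name now share a single node, the parser emits fewer duplicate resolve nodes, which appears consistent with the lowered pre-CSE node counts asserted in opt_test.cc above.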