!8097 Remove redundant phi nodes

From: @ginfung
Reviewed-by: 
Signed-off-by:
pull/8097/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9c957072e2

@ -49,11 +49,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
if (vars_.count(var)) {
AnfNodePtr node = vars_[var];
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
return NewValueNode(GetValueNode(node));
} else {
return node;
auto iter = resolve_to_removable_phis_.find(node);
if (iter != resolve_to_removable_phis_.end()) {
return iter->second;
}
return node;
}
// get var from predecessor block ,if can't get the make a resolve node to it
if (matured_) {
@ -64,7 +64,13 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
return block->ReadVariable(var);
} else if (prev_blocks_.empty()) {
// get namespace and make Reslove
return MakeResolveSymbol(var);
auto it = var_to_resolve_.find(var);
if (it != var_to_resolve_.end()) {
return it->second;
}
auto tmp_node = MakeResolveSymbol(var);
var_to_resolve_[var] = tmp_node;
return tmp_node;
}
}
// If have more than one predecessor blocks then build a phi node.
@ -217,6 +223,7 @@ bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
// replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
WriteVariable(var, arg_node);
removable_phis_[phi] = arg_node;
resolve_to_removable_phis_[arg_node] = phi;
// The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
// recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node.
for (auto &prev : prev_blocks_) {

@ -101,12 +101,21 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// keeps all removable phis which will be removed in one pass.
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
// Keeps the map for the resolve node to the removable phi node.
// For the case that ReadVariable returns a phi node although this phi node
// generated in the prev block is identified as removable. The other blocks
// should find this phi node.
std::unordered_map<AnfNodePtr, ParameterPtr> resolve_to_removable_phis_;
// hold declared global variables in function
std::set<std::string> global_vars_;
// other depend need to insert before function return nodes.
// summary or some other node
std::vector<AnfNodePtr> auto_depends_;
// keeps the new made resolve symbol for the variable not found in vars_.
std::unordered_map<std::string, AnfNodePtr> var_to_resolve_;
};
} // namespace parse

@ -14,6 +14,7 @@
# ============================================================================
""" test control ops """
import numpy as np
import pytest
from mindspore import dtype as ms
from mindspore import Tensor
@ -1150,3 +1151,147 @@ def test_if_by_if_forward_all_const_branch():
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(0), dtype=ms.float32)
net(idx, end, x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_const_grad():
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
def construct(self, *inputs):
out = self.add(*inputs)
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
a = 1
b = 2
if a > 0:
b = 1
a += b
return grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE)
my_net = MyNet()
net = GradNet(my_net)
a = Tensor(np.array(0), dtype=ms.int32)
b = Tensor(np.array(1), dtype=ms.int32)
net(a, b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
def construct(self, *inputs):
out = self.add(*inputs)
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
a = 1
b = 2
if a > 0:
b = 1
if a < 0:
b = 0
if a == 0:
b = 3
a += b
return grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE)
my_net = MyNet()
net = GradNet(my_net)
a = Tensor(np.array(0), dtype=ms.int32)
b = Tensor(np.array(1), dtype=ms.int32)
net(a, b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_while_const_grad():
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
def construct(self, *inputs):
out = self.add(*inputs)
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
a = 1
while a > 1:
a = a - 1
return grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE)
my_net = MyNet()
net = GradNet(my_net)
a = Tensor(np.array(0), dtype=ms.int32)
b = Tensor(np.array(1), dtype=ms.int32)
net(a, b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
class MyNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
def construct(self, *inputs):
out = self.add(*inputs)
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
a = 1
b = 2
if a > 0:
b = 0
while a > 1:
a = a - 1
a += b
return grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE)
my_net = MyNet()
net = GradNet(my_net)
a = Tensor(np.array(0), dtype=ms.int32)
b = Tensor(np.array(1), dtype=ms.int32)
net(a, b)

@ -187,7 +187,7 @@ TEST_F(TestOptOpt, CSE) {
FuncGraphManagerPtr manager1 = Manage(test_graph1);
draw::Draw("opt_cse_before_1.dot", test_graph1);
ASSERT_EQ(manager1->all_nodes().size(), 10);
ASSERT_EQ(manager1->all_nodes().size(), 9);
auto cse = std::make_shared<CSE>();
ASSERT_TRUE(cse != nullptr);
@ -205,7 +205,7 @@ TEST_F(TestOptOpt, CSE) {
FuncGraphManagerPtr manager2 = Manage(test_graph2);
draw::Draw("opt_cse_before_2.dot", test_graph2);
ASSERT_EQ(manager2->all_nodes().size(), 22);
ASSERT_EQ(manager2->all_nodes().size(), 16);
is_changed = cse->Cse(test_graph2, manager2);
ASSERT_TRUE(is_changed);
ASSERT_EQ(manager2->all_nodes().size(), 12);

Loading…
Cancel
Save