|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
from paddle.fluid.framework import Program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestOpDesc(unittest.TestCase):
|
|
|
|
@ -187,32 +188,46 @@ class TestBlockDesc(unittest.TestCase):
|
|
|
|
|
self.assertEqual(all_ops, [op0, op1, op2])
|
|
|
|
|
|
|
|
|
|
def test_remove_op(self):
|
|
|
|
|
prog = core.ProgramDesc()
|
|
|
|
|
program = Program()
|
|
|
|
|
prog = program.desc
|
|
|
|
|
self.assertIsNotNone(prog)
|
|
|
|
|
block = prog.block(0)
|
|
|
|
|
self.assertIsNotNone(block)
|
|
|
|
|
|
|
|
|
|
op0 = block.append_op()
|
|
|
|
|
op1 = block.append_op()
|
|
|
|
|
op2 = block.append_op()
|
|
|
|
|
op0.set_type("test")
|
|
|
|
|
op1.set_type("test")
|
|
|
|
|
op2.set_type("test")
|
|
|
|
|
|
|
|
|
|
var0 = block.var("var0")
|
|
|
|
|
var1 = block.var("var1")
|
|
|
|
|
var2 = block.var("var2")
|
|
|
|
|
var3 = block.var("var3")
|
|
|
|
|
var4 = block.var("var4")
|
|
|
|
|
var5 = block.var("var5")
|
|
|
|
|
|
|
|
|
|
op0.set_input("X", ["var0"])
|
|
|
|
|
op0.set_output("Y", ["var0"])
|
|
|
|
|
op1.set_input("X", ["var1", "var2"])
|
|
|
|
|
op1.set_output("Y", ["var3", "var4"])
|
|
|
|
|
op2.set_input("X", ["var1"])
|
|
|
|
|
op2.set_output("Y", ["var4", "var5"])
|
|
|
|
|
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
# remove op1, its input var2 and output var3 will be removed at the same time,
|
|
|
|
|
# but its input var1 and output var4 will not be removed since they are used for op2.
|
|
|
|
|
block.remove_op(0, 1)
|
|
|
|
|
block.remove_op(1, 2)
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
all_ops = []
|
|
|
|
|
for idx in xrange(0, block.op_size()):
|
|
|
|
|
all_ops.append(block.op(idx))
|
|
|
|
|
self.assertEqual(all_ops, [op2])
|
|
|
|
|
self.assertEqual(all_ops, [op0, op2])
|
|
|
|
|
all_vars = block.all_vars()
|
|
|
|
|
self.assertEqual(set(all_vars), {var1, var4, var5})
|
|
|
|
|
self.assertEqual(set(all_vars), {var0, var1, var4, var5})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|