Following the design * https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md Just written `Program`, `Block` and unittest of program.revert-4814-Add_sequence_project_op
parent
cc1860c10e
commit
3c39df197e
@ -0,0 +1,45 @@
|
||||
import paddle.v2.framework.core as core
|
||||
|
||||
|
||||
class Block(object):
|
||||
def __init__(self, program, idx):
|
||||
self.proto = program.proto.block(idx)
|
||||
self.vars = dict() # var_name --> var
|
||||
self.ops = list() # operator list
|
||||
self.program = program
|
||||
|
||||
@property
|
||||
def parent_idx(self):
|
||||
return self.proto.parent
|
||||
|
||||
@property
|
||||
def idx(self):
|
||||
return self.proto.id
|
||||
|
||||
|
||||
class Program(object):
|
||||
def __init__(self):
|
||||
self.proto = core.ProgramDesc.instance()
|
||||
assert self.proto.num_blocks() == 1
|
||||
self.blocks = [Block(self, 0)]
|
||||
self.current_block_idx = 0
|
||||
|
||||
def global_block(self):
|
||||
return self.blocks[0]
|
||||
|
||||
def current_block(self):
|
||||
return self.blocks[self.current_block_idx]
|
||||
|
||||
def create_block(self):
|
||||
new_block_idx = len(self.blocks)
|
||||
self.proto.append_block(self.current_block().proto)
|
||||
self.current_block_idx = new_block_idx
|
||||
self.blocks.append(Block(self, self.current_block_idx))
|
||||
return self.current_block()
|
||||
|
||||
def rollback(self):
|
||||
self.current_block_idx = self.current_block().parent_idx
|
||||
|
||||
|
||||
# program is a global instance.
|
||||
g_program = Program()
|
@ -0,0 +1,36 @@
|
||||
import unittest
|
||||
from paddle.v2.framework.graph import g_program
|
||||
|
||||
|
||||
class TestProgram(unittest.TestCase):
|
||||
def test_program(self):
|
||||
b = g_program.current_block()
|
||||
self.assertEqual(-1, b.parent_idx)
|
||||
self.assertEqual(0, b.idx)
|
||||
|
||||
b = g_program.create_block()
|
||||
self.assertEqual(1, b.idx)
|
||||
self.assertEqual(0, b.parent_idx)
|
||||
|
||||
b = g_program.create_block()
|
||||
self.assertEqual(2, b.idx)
|
||||
self.assertEqual(1, b.parent_idx)
|
||||
|
||||
g_program.rollback()
|
||||
|
||||
b = g_program.current_block()
|
||||
self.assertEqual(1, b.idx)
|
||||
self.assertEqual(0, b.parent_idx)
|
||||
|
||||
b = g_program.create_block()
|
||||
self.assertEqual(3, b.idx)
|
||||
self.assertEqual(1, b.parent_idx)
|
||||
|
||||
g_program.rollback()
|
||||
b = g_program.current_block()
|
||||
self.assertEqual(1, b.idx)
|
||||
self.assertEqual(0, b.parent_idx)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue