|
|
|
@ -308,6 +308,9 @@ class Block(object):
|
|
|
|
|
def create_var(self, *args, **kwargs):
|
|
|
|
|
return Variable(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def has_var(self, name):
|
|
|
|
|
return name in self.vars
|
|
|
|
|
|
|
|
|
|
def create_parameter(self, *args, **kwargs):
|
|
|
|
|
global_block = self.program.global_block()
|
|
|
|
|
return Parameter(global_block, *args, **kwargs)
|
|
|
|
@ -324,6 +327,43 @@ class Block(object):
|
|
|
|
|
self.ops.appendleft(op)
|
|
|
|
|
return op
|
|
|
|
|
|
|
|
|
|
def sync_with_cpp(self):
|
|
|
|
|
# sync variables from cpp
|
|
|
|
|
for var in self.desc.all_vars():
|
|
|
|
|
if not self.has_var(var.name()):
|
|
|
|
|
self.create_var(name=var.name(), desc=var, type=var.type())
|
|
|
|
|
|
|
|
|
|
# sync operators from cpp
|
|
|
|
|
ops_in_cpp = self.desc.all_ops()
|
|
|
|
|
first_op_in_python = self.ops[0].desc
|
|
|
|
|
last_op_in_python = self.ops[len(self.ops) - 1].desc
|
|
|
|
|
start_index = None
|
|
|
|
|
end_index = None
|
|
|
|
|
for index in range(len(ops_in_cpp)):
|
|
|
|
|
if first_op_in_python == ops_in_cpp[index]:
|
|
|
|
|
start_index = index
|
|
|
|
|
if last_op_in_python == ops_in_cpp[index]:
|
|
|
|
|
end_index = index
|
|
|
|
|
assert start_index is not None
|
|
|
|
|
assert end_index is not None
|
|
|
|
|
assert start_index <= end_index
|
|
|
|
|
|
|
|
|
|
# sync ops append to the head of cpp_ops
|
|
|
|
|
for index in range((start_index - 1 - 1), -1, -1):
|
|
|
|
|
op_desc = ops_in_cpp[index]
|
|
|
|
|
op = Operator(self, op_desc)
|
|
|
|
|
self.ops.appendleft(op)
|
|
|
|
|
|
|
|
|
|
# sync ops append to the end of cpp_ops
|
|
|
|
|
for index in range((end_index + 1), len(ops_in_cpp)):
|
|
|
|
|
op_desc = ops_in_cpp[index]
|
|
|
|
|
op = Operator(self, op_desc)
|
|
|
|
|
self.ops.append(op)
|
|
|
|
|
|
|
|
|
|
assert len(self.ops) == len(ops_in_cpp)
|
|
|
|
|
for index in range(len(self.ops)):
|
|
|
|
|
assert self.ops[index].desc == ops_in_cpp[index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Program(object):
|
|
|
|
|
@classmethod
|
|
|
|
@ -354,6 +394,12 @@ class Program(object):
|
|
|
|
|
def current_block(self):
|
|
|
|
|
return self.blocks[self.current_block_idx]
|
|
|
|
|
|
|
|
|
|
def append_backward(self, target, no_grad_set):
|
|
|
|
|
assert isinstance(target, Variable)
|
|
|
|
|
param_to_grad_info = self.desc.append_backward(target.desc, no_grad_set)
|
|
|
|
|
self.sync_with_cpp()
|
|
|
|
|
return param_to_grad_info
|
|
|
|
|
|
|
|
|
|
def create_block(self):
|
|
|
|
|
new_block_idx = len(self.blocks)
|
|
|
|
|
self.desc.append_block(self.current_block().desc)
|
|
|
|
@ -364,6 +410,12 @@ class Program(object):
|
|
|
|
|
def rollback(self):
|
|
|
|
|
self.current_block_idx = self.current_block().parent_idx
|
|
|
|
|
|
|
|
|
|
def sync_with_cpp(self):
|
|
|
|
|
for block_idx in range(len(self.blocks), self.desc.num_blocks()):
|
|
|
|
|
self.blocks.append(Block(self, block_idx))
|
|
|
|
|
for block in self.blocks:
|
|
|
|
|
block.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Parameter(Variable):
|
|
|
|
|
def __init__(self, block, shape, dtype, **kwargs):
|
|
|
|
|