Merge pull request #9765 from jacquesqiao/add-insert-op

add insert_op for block
fea/docker_cudnn7
Tao Luo 7 years ago committed by GitHub
commit 9100424048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -659,7 +659,7 @@ class Block(object):
def __init__(self, program, idx):
self.desc = program.desc.block(idx)
self.vars = dict() # var_name --> var
self.ops = collections.deque() # operator list
self.ops = list() # operator list
self.program = program
self.removed_vars = dict()
@ -831,6 +831,13 @@ class Block(object):
self.ops.append(op)
return op
def insert_op(self, index, *args, **kwargs):
self.sync_with_cpp()
op_desc = self.desc.insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
return op
def delete_ops(self, ops):
# remove from cpp
# FIXME(typhoonzero): remove only the first occurrence.
@ -842,12 +849,12 @@ class Block(object):
self.desc.remove_op(start, end + 1)
def slice_ops(self, start, end):
return list(self.ops)[start:end]
return self.ops[start:end]
def prepend_op(self, *args, **kwargs):
op_desc = self.desc.prepend_op()
op = Operator(self, op_desc, *args, **kwargs)
self.ops.appendleft(op)
self.ops.insert(0, op)
return op
def sync_with_cpp(self):
@ -892,7 +899,7 @@ class Block(object):
for index in range((start_index - 1 - 1), -1, -1):
op_desc = ops_in_cpp[index]
op = Operator(self, op_desc)
self.ops.appendleft(op)
self.ops.insert(0, op)
# sync ops append to the end of cpp_ops
for index in range((end_index + 1), len(ops_in_cpp)):

Loading…
Cancel
Save