|
|
|
@ -6,12 +6,13 @@ import contextlib
|
|
|
|
|
from ..registry import autodoc
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
|
|
|
|
|
'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table',
|
|
|
|
|
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
|
|
|
|
|
'increment', 'array_write', 'create_array', 'less_than', 'array_read',
|
|
|
|
|
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
|
|
|
|
|
'StaticRNN', 'reorder_lod_tensor_by_rank'
|
|
|
|
|
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
|
|
|
|
|
'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
|
|
|
|
|
'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
|
|
|
|
|
'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
|
|
|
|
|
'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
|
|
|
|
|
'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank',
|
|
|
|
|
'ParallelDo'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -132,29 +133,129 @@ class BlockGuard(object):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StaticRNNGuard(BlockGuard):
|
|
|
|
|
class ParallelDo(object):
|
|
|
|
|
"""
|
|
|
|
|
StaticRNNGuard class.
|
|
|
|
|
ParallelDo class.
|
|
|
|
|
|
|
|
|
|
StaticRNNGuard class is used to create a StaticRNN block in a program.
|
|
|
|
|
ParallelDo class is used to create a ParallelDo.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, places, name=None):
|
|
|
|
|
self.helper = LayerHelper("parallel_do", name=name)
|
|
|
|
|
self.inputs = []
|
|
|
|
|
self.places = places
|
|
|
|
|
self.outputs = []
|
|
|
|
|
self.status = StaticRNN.BEFORE_RNN_BLOCK
|
|
|
|
|
|
|
|
|
|
def do(self):
|
|
|
|
|
return BlockGuardWithCompletion(self)
|
|
|
|
|
|
|
|
|
|
def parent_block(self):
|
|
|
|
|
prog = self.helper.main_program
|
|
|
|
|
parent_idx = prog.current_block().parent_idx
|
|
|
|
|
assert parent_idx >= 0
|
|
|
|
|
parent_block = prog.block(parent_idx)
|
|
|
|
|
return parent_block
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
|
|
if self.status != StaticRNN.AFTER_RNN_BLOCK:
|
|
|
|
|
raise ValueError("RNN output can only be retrieved after rnn block")
|
|
|
|
|
if len(self.outputs) == 0:
|
|
|
|
|
raise ValueError("RNN has no output")
|
|
|
|
|
elif len(self.outputs) == 1:
|
|
|
|
|
return self.outputs[0]
|
|
|
|
|
else:
|
|
|
|
|
return self.outputs
|
|
|
|
|
|
|
|
|
|
def read_input(self, var):
|
|
|
|
|
self.inputs.append(var)
|
|
|
|
|
return var
|
|
|
|
|
|
|
|
|
|
def write_output(self, var):
|
|
|
|
|
self.outputs.append(var)
|
|
|
|
|
|
|
|
|
|
def get_parameters(self):
|
|
|
|
|
main_program = self.helper.main_program
|
|
|
|
|
current_block = main_program.current_block()
|
|
|
|
|
parent_block = self.parent_block()
|
|
|
|
|
|
|
|
|
|
local_inputs = set()
|
|
|
|
|
|
|
|
|
|
for op in current_block.ops:
|
|
|
|
|
for oname in op.output_names:
|
|
|
|
|
for out_var_name in op.output(oname):
|
|
|
|
|
local_inputs.add(out_var_name)
|
|
|
|
|
|
|
|
|
|
for var in self.inputs:
|
|
|
|
|
local_inputs.add(var.name)
|
|
|
|
|
|
|
|
|
|
params = list()
|
|
|
|
|
for op in current_block.ops:
|
|
|
|
|
for iname in op.input_names:
|
|
|
|
|
for in_var_name in op.input(iname):
|
|
|
|
|
if in_var_name not in local_inputs:
|
|
|
|
|
params.append(in_var_name)
|
|
|
|
|
|
|
|
|
|
return [parent_block.var(name) for name in params]
|
|
|
|
|
|
|
|
|
|
def complete_op(self):
|
|
|
|
|
main_program = self.helper.main_program
|
|
|
|
|
current_block = main_program.current_block()
|
|
|
|
|
parent_block = self.parent_block()
|
|
|
|
|
|
|
|
|
|
step_scope = parent_block.create_var(
|
|
|
|
|
type=core.VarDesc.VarType.STEP_SCOPES)
|
|
|
|
|
|
|
|
|
|
self.outputs = [
|
|
|
|
|
parent_block.create_var(
|
|
|
|
|
name=o.name,
|
|
|
|
|
shape=o.shape,
|
|
|
|
|
dtype=o.dtype,
|
|
|
|
|
lod_level=o.lod_level,
|
|
|
|
|
persistable=o.persistable,
|
|
|
|
|
stop_gradient=o.stop_gradient) for o in self.outputs
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
inputs = [parent_block.var(i.name) for i in self.inputs]
|
|
|
|
|
outputs = [parent_block.var(o.name) for o in self.outputs]
|
|
|
|
|
|
|
|
|
|
parent_block.append_op(
|
|
|
|
|
type='parallel_do',
|
|
|
|
|
inputs={
|
|
|
|
|
'inputs': inputs,
|
|
|
|
|
'parameters': self.get_parameters(),
|
|
|
|
|
'places': self.places
|
|
|
|
|
},
|
|
|
|
|
outputs={'outputs': outputs,
|
|
|
|
|
'parallel_scopes': [step_scope]},
|
|
|
|
|
attrs={'sub_block': current_block})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BlockGuardWithCompletion(BlockGuard):
|
|
|
|
|
"""
|
|
|
|
|
BlockGuardWithCompletion class.
|
|
|
|
|
|
|
|
|
|
BlockGuardWithCompletion class is used to create an op with a block in a program.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, rnn):
|
|
|
|
|
if not isinstance(rnn, StaticRNN):
|
|
|
|
|
raise TypeError("StaticRNNGuard takes a StaticRNN")
|
|
|
|
|
super(StaticRNNGuard, self).__init__(rnn.helper.main_program)
|
|
|
|
|
if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"BlockGuardWithCompletion takes a StaticRNN or ParallelDo")
|
|
|
|
|
super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program)
|
|
|
|
|
self.rnn = rnn
|
|
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
self.rnn.status = StaticRNN.IN_RNN_BLOCK
|
|
|
|
|
return super(StaticRNNGuard, self).__enter__()
|
|
|
|
|
return super(BlockGuardWithCompletion, self).__enter__()
|
|
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
|
|
if exc_type is not None:
|
|
|
|
|
return False
|
|
|
|
|
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
|
|
|
|
|
self.rnn.complete_rnn_op()
|
|
|
|
|
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
self.rnn.complete_op()
|
|
|
|
|
return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val,
|
|
|
|
|
exc_tb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StaticRNNMemoryLink(object):
|
|
|
|
@ -200,7 +301,7 @@ class StaticRNN(object):
|
|
|
|
|
self.seq_len = None
|
|
|
|
|
|
|
|
|
|
def step(self):
|
|
|
|
|
return StaticRNNGuard(self)
|
|
|
|
|
return BlockGuardWithCompletion(self)
|
|
|
|
|
|
|
|
|
|
def _assert_in_rnn_block_(self, method):
|
|
|
|
|
if self.status != StaticRNN.IN_RNN_BLOCK:
|
|
|
|
@ -316,7 +417,7 @@ class StaticRNN(object):
|
|
|
|
|
else:
|
|
|
|
|
return self.outputs
|
|
|
|
|
|
|
|
|
|
def complete_rnn_op(self):
|
|
|
|
|
def complete_op(self):
|
|
|
|
|
main_program = self.helper.main_program
|
|
|
|
|
rnn_block = main_program.current_block()
|
|
|
|
|
parent_block = self.parent_block()
|
|
|
|
|