Merge pull request #458 from luotao1/group

add layer check for recurrent_group
LCY-Seso committed 8 years ago via GitHub
commit d2ff3e4930

@@ -494,8 +494,7 @@ def scaling_projection(input, param_attr=None):
     :return: A ScalingProjection object
     :rtype: ScalingProjection
     """
-    proj = ScalingProjection(input_layer_name=input.name,
-                             **param_attr.attr)
+    proj = ScalingProjection(input_layer_name=input.name, **param_attr.attr)
     proj.origin = input
     return proj
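
This hunk only reflows the ScalingProjection construction onto one line; behavior is unchanged. For context, a hedged usage sketch of `scaling_projection` (the layer names and sizes below are illustrative assumptions, not from this PR):

```python
# scaling_projection scales its whole input by a single learned parameter,
# so it is attached to a mixed_layer like any other projection.
emb = embedding_layer(input=data_layer(name="word", size=10000), size=256)
with mixed_layer(size=256) as scaled:
    scaled += scaling_projection(input=emb)  # illustrative names, assumed config
```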
@@ -2783,7 +2782,12 @@ class SubsequenceInput(object):
 @wrap_name_default("recurrent_group")
-def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
+def recurrent_group(step,
+                    input,
+                    reverse=False,
+                    name=None,
+                    targetInlink=None,
+                    is_generating=False):
     """
     Recurrent layer group is an extremely flexible recurrent unit in
     PaddlePaddle. As long as the user defines the calculation done within a
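
Because `is_generating` is appended as the last keyword argument with a default of False, existing call sites are unaffected. A minimal training-mode sketch of a `recurrent_group` (the step function, layer names, and sizes are assumptions for illustration):

```python
def step(y):
    # Read this group's previous output as recurrent state (linked by name).
    mem = memory(name="rnn_state", size=128)
    return fc_layer(input=[y, mem],
                    size=128,
                    act=TanhActivation(),
                    name="rnn_state")

# Existing calls keep working: is_generating defaults to False.
rnn = recurrent_group(step=step, input=emb)  # emb: a sequence LayerOutput
```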
@@ -2848,6 +2852,12 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
     :type targetInlink: LayerOutput|SubsequenceInput
 
+    :param is_generating: Whether the group runs in generating mode. If true,
+                          none of the inputs may be a LayerOutput; otherwise
+                          (training or testing), at least one input must be a
+                          LayerOutput.
+    :type is_generating: bool
+
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
@@ -2895,6 +2905,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
         seq_reversed=reverse,
         target_inlinkname=targetInlinkName)
     in_args = []
+    has_LayerOutput = True
     for each_input in input:
         assert is_single_input(each_input)
         if isinstance(each_input, LayerOutput):
@@ -2902,6 +2913,7 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
         elif isinstance(each_input, SubsequenceInput):
             in_args.append(each_input.input)
         else:
+            has_LayerOutput = False
             mem_name = "__%s_memory__" % each_input.input.name
             mem = memory(
                 name=mem_name,
@@ -2915,6 +2927,8 @@ def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
                 mix += identity_projection(mem)
             in_args.append(mem)
 
+    assert (is_generating != has_LayerOutput)
+
     layer_outs = step(*in_args)
     if isinstance(layer_outs, LayerOutput):
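
Note that the flag flips to False as soon as any memory-backed input (the `else` branch) is seen, so the assert requires a training/testing group to contain only LayerOutput/SubsequenceInput inputs, and a generating group to contain at least one memory-backed input. A standalone mirror of the check, for reasoning only (not library code):

```python
def check_inputs(inputs, is_generating):
    has_LayerOutput = True
    for each_input in inputs:
        if not isinstance(each_input, (LayerOutput, SubsequenceInput)):
            has_LayerOutput = False  # memory-backed input, e.g. StaticInput
    # False != True holds for training; True != False holds for generating.
    assert is_generating != has_LayerOutput
```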
@@ -3206,7 +3220,11 @@ def beam_search(step,
         return predict
 
     tmp = recurrent_group(
-        step=__real_step__, input=real_input, reverse=False, name=name)
+        step=__real_step__,
+        input=real_input,
+        reverse=False,
+        name=name,
+        is_generating=True)
 
     return tmp
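
With this change, `beam_search` always builds its inner group in generating mode, so user code needs no changes. A hedged usage sketch modeled on the seqToseq demo (the step function, layer names, vocabulary size, and bos/eos ids are assumptions, not from this PR):

```python
trg_embedding = GeneratedInput(size=30000,  # assumed target vocabulary size
                               embedding_name="_target_language_embedding",
                               embedding_size=512)
beam_gen = beam_search(step=decoder_step,   # assumed decoder step function
                       input=[StaticInput(input=encoder_ctx), trg_embedding],
                       bos_id=0,
                       eos_id=1,
                       beam_size=5,
                       max_length=100)
```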
