fix style problem

release/0.10.0
qiaolongfei 8 years ago
parent d6e8d5cdfd
commit ba51e6ea68

@ -131,8 +131,9 @@ def gru_encoder_decoder(data_conf,
decoder_group_name = "decoder_group" decoder_group_name = "decoder_group"
group_inputs = [ group_inputs = [
StaticInput(input=encoded_vector, is_seq=True), StaticInput(
StaticInput(input=encoded_proj, is_seq=True) input=encoded_vector, is_seq=True), StaticInput(
input=encoded_proj, is_seq=True)
] ]
if not is_generating: if not is_generating:

@ -114,7 +114,8 @@ class Layer(object):
# 4. parse myself and add myself into context. # 4. parse myself and add myself into context.
ret_val = self.to_proto_impl(context=context, **kwargs) ret_val = self.to_proto_impl(context=context, **kwargs)
if self.context_name() is not None and self.context_name() not in context: if self.context_name() is not None and self.context_name(
) not in context:
context[self.context_name()] = ret_val context[self.context_name()] = ret_val
# 5. parse children that should be pased after this layer. # 5. parse children that should be pased after this layer.

@ -292,7 +292,8 @@ class RecurrentLayerInput(Layer):
else: else:
self.__parents__ = parent_layers.values()[0] self.__parents__ = parent_layers.values()[0]
self.__recurrent_name__ = recurrent_name self.__recurrent_name__ = recurrent_name
name = self.__parents__[index].name if index >= 0 else self.context_name() name = self.__parents__[
index].name if index >= 0 else self.context_name()
super(RecurrentLayerInput, self).__init__( super(RecurrentLayerInput, self).__init__(
name=name, parent_layers=parent_layers) name=name, parent_layers=parent_layers)
@ -402,9 +403,7 @@ def recurrent_group(step, input, name=None):
extra_input = None extra_input = None
if len(non_static_inputs) == 0: if len(non_static_inputs) == 0:
extra_input = RecurrentLayerInput( extra_input = RecurrentLayerInput(
recurrent_name=name, recurrent_name=name, index=-1, parent_layers={})
index=-1,
parent_layers={})
def __real_step__(*args): def __real_step__(*args):
rnn_input = list(args) rnn_input = list(args)

@ -1 +1 @@
import beam_search import beam_search

@ -63,6 +63,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer):
def use_context_name(self): def use_context_name(self):
return True return True
@wrap_name_default() @wrap_name_default()
def beam_search(step, def beam_search(step,
input, input,
@ -75,9 +76,10 @@ def beam_search(step,
if num_results_per_sample is None: if num_results_per_sample is None:
num_results_per_sample = beam_size num_results_per_sample = beam_size
assert num_results_per_sample <= beam_size assert num_results_per_sample <= beam_size
# logger.warning("num_results_per_sample should be less than beam_size") # logger.warning("num_results_per_sample should be less than beam_size")
if isinstance(input, paddle.layer.StaticInputV2) or isinstance(input, BaseGeneratedInputV2): if isinstance(input, paddle.layer.StaticInputV2) or isinstance(
input, BaseGeneratedInputV2):
input = [input] input = [input]
generated_input_index = -1 generated_input_index = -1
@ -107,8 +109,8 @@ def beam_search(step,
args = list(args) args = list(args)
before_step_layer = gipt.before_real_step() before_step_layer = gipt.before_real_step()
before_step_layer.append_child(layer=generator, before_step_layer.append_child(
parent_names=[before_step_layer.name]) layer=generator, parent_names=[before_step_layer.name])
args.insert(generated_input_index, before_step_layer) args.insert(generated_input_index, before_step_layer)
predict = gipt.after_real_step(step(*args)) predict = gipt.after_real_step(step(*args))
@ -125,8 +127,6 @@ def beam_search(step,
# name=name, # name=name,
# is_generating=True) # is_generating=True)
tmp = paddle.layer.recurrent_group( tmp = paddle.layer.recurrent_group(
step=__real_step__, step=__real_step__, input=real_input, name=name)
input=real_input,
name=name)
return tmp return tmp

Loading…
Cancel
Save