|
|
|
@ -63,6 +63,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer):
|
|
|
|
|
def use_context_name(self):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@wrap_name_default()
|
|
|
|
|
def beam_search(step,
|
|
|
|
|
input,
|
|
|
|
@ -75,9 +76,10 @@ def beam_search(step,
|
|
|
|
|
if num_results_per_sample is None:
|
|
|
|
|
num_results_per_sample = beam_size
|
|
|
|
|
assert num_results_per_sample <= beam_size
|
|
|
|
|
# logger.warning("num_results_per_sample should be less than beam_size")
|
|
|
|
|
# logger.warning("num_results_per_sample should be less than beam_size")
|
|
|
|
|
|
|
|
|
|
if isinstance(input, paddle.layer.StaticInputV2) or isinstance(input, BaseGeneratedInputV2):
|
|
|
|
|
if isinstance(input, paddle.layer.StaticInputV2) or isinstance(
|
|
|
|
|
input, BaseGeneratedInputV2):
|
|
|
|
|
input = [input]
|
|
|
|
|
|
|
|
|
|
generated_input_index = -1
|
|
|
|
@ -107,8 +109,8 @@ def beam_search(step,
|
|
|
|
|
|
|
|
|
|
args = list(args)
|
|
|
|
|
before_step_layer = gipt.before_real_step()
|
|
|
|
|
before_step_layer.append_child(layer=generator,
|
|
|
|
|
parent_names=[before_step_layer.name])
|
|
|
|
|
before_step_layer.append_child(
|
|
|
|
|
layer=generator, parent_names=[before_step_layer.name])
|
|
|
|
|
args.insert(generated_input_index, before_step_layer)
|
|
|
|
|
|
|
|
|
|
predict = gipt.after_real_step(step(*args))
|
|
|
|
@ -125,8 +127,6 @@ def beam_search(step,
|
|
|
|
|
# name=name,
|
|
|
|
|
# is_generating=True)
|
|
|
|
|
tmp = paddle.layer.recurrent_group(
|
|
|
|
|
step=__real_step__,
|
|
|
|
|
input=real_input,
|
|
|
|
|
name=name)
|
|
|
|
|
step=__real_step__, input=real_input, name=name)
|
|
|
|
|
|
|
|
|
|
return tmp
|
|
|
|
|