@@ -3563,9 +3563,15 @@ def beam_search(step,
                 simple_rnn += last_time_step_output
             return simple_rnn
 
+        generated_word_embedding = GeneratedInput(
+                               size=target_dictionary_dim,
+                               embedding_name="target_language_embedding",
+                               embedding_size=word_vector_dim)
+
         beam_gen = beam_search(name="decoder",
                                step=rnn_step,
-                               input=[StaticInput(encoder_last)],
+                               input=[StaticInput(encoder_last),
+                                      generated_word_embedding],
                                bos_id=0,
                                eos_id=1,
                                beam_size=5)
@@ -3584,7 +3590,8 @@ def beam_search(step,
                  You can refer to the first parameter of recurrent_group, or
                  demo/seqToseq/seqToseq_net.py for more details.
     :type step: callable
-    :param input: Input data for the recurrent unit
+    :param input: Input data for the recurrent unit, which should include the
+                  previously generated words as a GeneratedInput object.
     :type input: list
     :param bos_id: Index of the start symbol in the dictionary. The start symbol
                    is a special token for NLP task, which indicates the
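For reference, here is the updated docstring example assembled in one place. It is a minimal sketch: the GeneratedInput and beam_search calls come directly from the hunks above, while the surrounding rnn_step body (memory, mixed_layer, full_matrix_projection) is assumed from the usual recurrent_group pattern and is not shown in the diff; encoder_last, target_dictionary_dim and word_vector_dim are placeholders defined elsewhere in the config.

```python
# Sketch of the updated beam_search usage, assuming the legacy
# paddle.trainer_config_helpers API. encoder_last, target_dictionary_dim
# and word_vector_dim are placeholders, not defined here.
from paddle.trainer_config_helpers import *

def rnn_step(input):
    # rnn_step body assumed from the usual recurrent_group pattern;
    # only the last two lines appear as context in the diff above.
    last_time_step_output = memory(name='rnn', size=512)
    with mixed_layer(size=512, name='rnn') as simple_rnn:
        simple_rnn += full_matrix_projection(input)
        simple_rnn += last_time_step_output
    return simple_rnn

# Previously generated words are fed back into the decoder step
# as a GeneratedInput, as the revised docstring now requires.
generated_word_embedding = GeneratedInput(
    size=target_dictionary_dim,
    embedding_name="target_language_embedding",
    embedding_size=word_vector_dim)

beam_gen = beam_search(name="decoder",
                       step=rnn_step,
                       input=[StaticInput(encoder_last),
                              generated_word_embedding],
                       bos_id=0,
                       eos_id=1,
                       beam_size=5)
```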