|
|
|
@ -601,6 +601,28 @@ class Decoder(object):
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
@property
def tracks_own_finished(self):
    """
    Describes whether the Decoder keeps track of finished states by itself.

    `decoder.step()` emits a bool `finished` value at each decoding step.
    The emitted `finished` can either be used directly as the finished
    status of every batch entry, or be combined with the finished tracker
    kept in `dynamic_decode` by a logical OR, so that entries which have
    already finished stay finished.

    If `False`, which is the default, `dynamic_decode` takes the latter
    approach. Otherwise it takes the former and uses the `finished` value
    emitted by the decoder as the finished status of all batch entries
    directly. That is needed when batch entries might be reordered, such
    as the beams in BeamSearchDecoder.

    Returns:
        bool: A python bool `False`.
    """
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BeamSearchDecoder(Decoder):
|
|
|
|
|
"""
|
|
|
|
@ -1048,6 +1070,19 @@ class BeamSearchDecoder(Decoder):
|
|
|
|
|
# TODO: use FinalBeamSearchDecoderOutput as output
|
|
|
|
|
return predicted_ids, final_states
|
|
|
|
|
|
|
|
|
|
@property
def tracks_own_finished(self):
    """
    BeamSearchDecoder reorders its beams together with their finished
    states, which conflicts with the finished-state tracking done by the
    `dynamic_decode` function. This property is therefore set to `True`
    to avoid decoding being stopped early due to mismanagement of the
    finished state.

    Returns:
        bool: A python bool `True`.
    """
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dynamic_decode(decoder,
|
|
|
|
|
inits=None,
|
|
|
|
@ -1205,7 +1240,13 @@ def dynamic_decode(decoder,
|
|
|
|
|
states_arrays)
|
|
|
|
|
(outputs, next_states, next_inputs,
|
|
|
|
|
next_finished) = decoder.step(step_idx, inputs, states, **kwargs)
|
|
|
|
|
next_finished = control_flow.logical_or(next_finished, global_finished)
|
|
|
|
|
if not decoder.tracks_own_finished:
|
|
|
|
|
# BeamSearchDecoder tracks its own finished status, since beams would
|
|
|
|
|
# be reordered and the finished status of each entry might change.
|
|
|
|
|
# Otherwise, perform logical OR which would not change the already
|
|
|
|
|
# finished.
|
|
|
|
|
next_finished = control_flow.logical_or(next_finished,
|
|
|
|
|
global_finished)
|
|
|
|
|
next_sequence_lengths = nn.elementwise_add(
|
|
|
|
|
sequence_lengths,
|
|
|
|
|
tensor.cast(
|
|
|
|
@ -1226,6 +1267,10 @@ def dynamic_decode(decoder,
|
|
|
|
|
lambda x, x_array: control_flow.array_write(
|
|
|
|
|
x, i=step_idx, array=x_array), outputs, outputs_arrays)
|
|
|
|
|
control_flow.increment(x=step_idx, value=1.0, in_place=True)
|
|
|
|
|
# update the global_finished first, since it might be also in states of
|
|
|
|
|
# decoder, which otherwise would write a stale finished status to array
|
|
|
|
|
tensor.assign(next_finished, global_finished)
|
|
|
|
|
tensor.assign(next_sequence_lengths, sequence_lengths)
|
|
|
|
|
if is_test:
|
|
|
|
|
map_structure(tensor.assign, next_inputs, global_inputs)
|
|
|
|
|
map_structure(tensor.assign, next_states, global_states)
|
|
|
|
@ -1236,8 +1281,6 @@ def dynamic_decode(decoder,
|
|
|
|
|
map_structure(
|
|
|
|
|
lambda x, x_array: control_flow.array_write(
|
|
|
|
|
x, i=step_idx, array=x_array), next_states, states_arrays)
|
|
|
|
|
tensor.assign(next_finished, global_finished)
|
|
|
|
|
tensor.assign(next_sequence_lengths, sequence_lengths)
|
|
|
|
|
if max_step_num is not None:
|
|
|
|
|
control_flow.logical_and(
|
|
|
|
|
control_flow.logical_not(nn.reduce_all(global_finished)),
|
|
|
|
|