Add tracks_own_finished to Decoder to avoid mismanagement of the finished state in dynamic_decode. (#23664)

test=develop
revert-22778-infer_var_type
Guo Sheng 5 years ago committed by GitHub
parent 614eb942fc
commit 54a47cd271

@@ -601,6 +601,28 @@ class Decoder(object):
"""
raise NotImplementedError
@property
def tracks_own_finished(self):
"""
Describes whether the Decoder keeps track of finished states by itself.
`decoder.step()` emits a bool `finished` value at each decoding step.
The emitted `finished` can either be used directly to decide whether
every batch entry is finished, or be combined with the finished tracker
kept in `dynamic_decode` through a logical OR, so that an entry stays
finished once it has finished.
If `False` (the default), `dynamic_decode` performs the logical OR.
Otherwise, the `finished` value emitted by the decoder is taken directly
as the finished status of all batch entries, which is needed when batch
entries might be reordered, such as beams in BeamSearchDecoder.
Returns:
bool: A python bool `False`.
"""
return False
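
As a sketch of how a custom decoder would use this hook, the class below overrides the property to opt out of the OR-based tracking. The class name, its stubbed methods, and the import path are assumptions for illustration, not part of this change.

from paddle.fluid.layers import Decoder

class ReorderingDecoder(Decoder):
    # Hypothetical decoder that reorders batch entries every step, so the
    # `finished` it emits is already the authoritative per-entry status.

    def initialize(self, initial_cell_states):
        # Would return (initial_inputs, initial_states, initial_finished).
        raise NotImplementedError

    def step(self, time, inputs, states, **kwargs):
        # Would return (outputs, next_states, next_inputs, finished), with
        # `finished` aligned to the possibly reordered batch entries.
        raise NotImplementedError

    @property
    def tracks_own_finished(self):
        # Ask dynamic_decode to take the emitted `finished` as-is instead
        # of OR-ing it with its own tracker.
        return True
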
class BeamSearchDecoder(Decoder):
"""
@@ -1048,6 +1070,19 @@ class BeamSearchDecoder(Decoder):
# TODO: use FinalBeamSearchDecoderOutput as output
return predicted_ids, final_states
@property
def tracks_own_finished(self):
"""
BeamSearchDecoder reorders its beams and their finished states. Thus it
conflicts with the tracking of finished states in `dynamic_decode`.
Setting this property to true avoids early stopping of decoding due to
mismanagement of the finished state.
Returns:
bool: A python bool `True`.
"""
return True
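
For context, a usage sketch assuming the fluid-era API of this module (`GRUCell`, `BeamSearchDecoder(cell, start_token, end_token, beam_size, ...)` and `dynamic_decode(decoder, inits=...)`); the layer sizes and parameter names are illustrative only.

import paddle.fluid as fluid
from paddle.fluid.layers import GRUCell, BeamSearchDecoder, dynamic_decode

# Encoder output used only to infer the batch size for the initial states.
encoder_output = fluid.data(
    name="encoder_output", shape=[-1, 32, 128], dtype="float32")
trg_embedding = lambda x: fluid.embedding(
    x, size=[10000, 128], param_attr=fluid.ParamAttr(name="trg_embedding"))
output_layer = lambda x: fluid.layers.fc(
    x, size=10000, num_flatten_dims=len(x.shape) - 1, bias_attr=False)

cell = GRUCell(hidden_size=128)
decoder = BeamSearchDecoder(
    cell,
    start_token=0,
    end_token=1,
    beam_size=4,
    embedding_fn=trg_embedding,
    output_fn=output_layer)

# dynamic_decode sees decoder.tracks_own_finished == True and therefore takes
# the finished flags emitted by BeamSearchDecoder.step() directly.
outputs = dynamic_decode(
    decoder=decoder,
    inits=cell.get_initial_states(encoder_output),
    max_step_num=10)
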
def dynamic_decode(decoder,
inits=None,
@@ -1205,7 +1240,13 @@ def dynamic_decode(decoder,
states_arrays)
(outputs, next_states, next_inputs,
next_finished) = decoder.step(step_idx, inputs, states, **kwargs)
next_finished = control_flow.logical_or(next_finished, global_finished)
if not decoder.tracks_own_finished:
# BeamSearchDecoder tracks its own finished states, since beams are
# reordered and the finished status of each entry might change.
# Otherwise, perform a logical OR, which never turns an already
# finished entry back into an unfinished one.
next_finished = control_flow.logical_or(next_finished,
global_finished)
next_sequence_lengths = nn.elementwise_add(
sequence_lengths,
tensor.cast(
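
A plain-Python toy (not Paddle ops) of why the OR is skipped when the decoder tracks its own finished states:

# Tracker kept by dynamic_decode vs. value emitted by decoder.step().
already_finished = [True, False, False]
emitted_finished = [False, True, False]

# tracks_own_finished == False: OR keeps an entry finished once it finished.
combined = [a or b for a, b in zip(already_finished, emitted_finished)]
assert combined == [True, True, False]

# tracks_own_finished == True (e.g. BeamSearchDecoder): entries may have been
# reordered, so entry 0 is no longer the beam that finished earlier and must
# not inherit the stale True; the emitted value is used as-is.
assert emitted_finished == [False, True, False]
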
@@ -1226,6 +1267,10 @@ def dynamic_decode(decoder,
lambda x, x_array: control_flow.array_write(
x, i=step_idx, array=x_array), outputs, outputs_arrays)
control_flow.increment(x=step_idx, value=1.0, in_place=True)
# Update global_finished first, since it might also be included in the
# decoder states; otherwise a stale finished status would be written to
# the array.
tensor.assign(next_finished, global_finished)
tensor.assign(next_sequence_lengths, sequence_lengths)
if is_test:
map_structure(tensor.assign, next_inputs, global_inputs)
map_structure(tensor.assign, next_states, global_states)
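
A plain-Python analogy (not the static-graph semantics) for the reordering explained in the comment above: when the finished flag is aliased inside the states that get archived, it has to be refreshed before archiving.

# `finished` stands in for global_finished, which may also appear among the
# decoder states written to the states array.
states = {"cell": [0.1, 0.2], "finished": [False, False]}
next_finished = [True, False]

states["finished"] = next_finished   # assign first, as the diff now does,
archived = dict(states)              # then archive the states for this step
assert archived["finished"] == [True, False]   # no stale value is archived
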
@@ -1236,8 +1281,6 @@ def dynamic_decode(decoder,
map_structure(
lambda x, x_array: control_flow.array_write(
x, i=step_idx, array=x_array), next_states, states_arrays)
tensor.assign(next_finished, global_finished)
tensor.assign(next_sequence_lengths, sequence_lengths)
if max_step_num is not None:
control_flow.logical_and(
control_flow.logical_not(nn.reduce_all(global_finished)),
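
The truncated hunk above feeds the loop-continuation condition; below is a plain-Python sketch of the intended semantics with illustrative names (whether the step bound is inclusive follows the comparison used in the actual code).

def should_continue(global_finished, step_idx, max_step_num=None):
    # Keep decoding while at least one batch entry is unfinished and, when a
    # step limit is given, while that limit has not been reached.
    not_all_finished = not all(global_finished)
    if max_step_num is None:
        return not_all_finished
    return not_all_finished and step_idx < max_step_num
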
