diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 89e74e35d8..62771d09f1 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -87,7 +87,7 @@ void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids, auto *pre_ids_data = pre_ids.data(); auto abs_lod = framework::ToAbsOffset(ids_->lod()); auto &high_level = abs_lod[lod_level_]; - for (size_t src_idx = 0; src_idx < high_level.size(); ++src_idx) { + for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) { size_t src_prefix_start = high_level[src_idx]; size_t src_prefix_end = high_level[src_idx + 1]; bool finish_flag = true; diff --git a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py index ccb7a4f9ab..f690a0d233 100644 --- a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py @@ -148,7 +148,11 @@ def decode(context, is_sparse): pd.array_write(selected_ids, array=ids_array, i=counter) pd.array_write(selected_scores, array=scores_array, i=counter) - pd.less_than(x=counter, y=array_len, cond=cond) + # update the break condition: up to the max length or all candidates of + # source sentences have ended. + length_cond = pd.less_than(x=counter, y=array_len) + finish_cond = pd.logical_not(pd.is_empty(x=selected_ids)) + pd.logical_and(x=length_cond, y=finish_cond, out=cond) translation_ids, translation_scores = pd.beam_search_decode( ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10) diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index d8499fa3f7..44e4c62643 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -147,7 +147,11 @@ def decoder_decode(context, is_sparse): pd.array_write(selected_ids, array=ids_array, i=counter) pd.array_write(selected_scores, array=scores_array, i=counter) - pd.less_than(x=counter, y=array_len, cond=cond) + # update the break condition: up to the max length or all candidates of + # source sentences have ended. + length_cond = pd.less_than(x=counter, y=array_len) + finish_cond = pd.logical_not(pd.is_empty(x=selected_ids)) + pd.logical_and(x=length_cond, y=finish_cond, out=cond) translation_ids, translation_scores = pd.beam_search_decode( ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)