|
|
|
@ -168,6 +168,7 @@ __device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local,
|
|
|
|
|
return finish_flag;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <bool ReturnParentIdx = false>
|
|
|
|
|
__device__ __forceinline__ void WriteBack(
|
|
|
|
|
int64_t* selected_ids, float* selected_scores, int* parent_idx,
|
|
|
|
|
size_t* selected_offsets, Triple* top_beam_local,
|
|
|
|
@ -183,7 +184,9 @@ __device__ __forceinline__ void WriteBack(
|
|
|
|
|
selected_ids[global_index] =
|
|
|
|
|
static_cast<int64_t>(top_beam_local[local_index].id);
|
|
|
|
|
selected_scores[global_index] = top_beam_local[local_index].score;
|
|
|
|
|
parent_idx[global_index] = static_cast<int>(global_offset);
|
|
|
|
|
if (ReturnParentIdx) {
|
|
|
|
|
parent_idx[global_index] = static_cast<int>(global_offset);
|
|
|
|
|
}
|
|
|
|
|
global_index++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -241,9 +244,15 @@ __device__ void BeamSearchDetails(
|
|
|
|
|
selected_offsets[0] = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
WriteBack(selected_ids, selected_scores, parent_idx, selected_offsets,
|
|
|
|
|
top_beam_local, seq_offset_start, seq_offset_end,
|
|
|
|
|
selected_seq_start, selected_seq_length);
|
|
|
|
|
if (parent_idx) {
|
|
|
|
|
WriteBack<true>(selected_ids, selected_scores, parent_idx,
|
|
|
|
|
selected_offsets, top_beam_local, seq_offset_start,
|
|
|
|
|
seq_offset_end, selected_seq_start, selected_seq_length);
|
|
|
|
|
} else {
|
|
|
|
|
WriteBack<false>(selected_ids, selected_scores, parent_idx,
|
|
|
|
|
selected_offsets, top_beam_local, seq_offset_start,
|
|
|
|
|
seq_offset_end, selected_seq_start, selected_seq_length);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -337,8 +346,12 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
selected_ids->mutable_data<int64_t>(selected_dims, context.GetPlace());
|
|
|
|
|
float* selected_scores_data =
|
|
|
|
|
selected_scores->mutable_data<float>(selected_dims, context.GetPlace());
|
|
|
|
|
int* parent_idx_data = parent_idx->mutable_data<int>(
|
|
|
|
|
{static_cast<int64_t>(num_seqs * beam_size)}, context.GetPlace());
|
|
|
|
|
int* parent_idx_data =
|
|
|
|
|
parent_idx
|
|
|
|
|
? parent_idx->mutable_data<int>(
|
|
|
|
|
{static_cast<int64_t>(num_seqs * beam_size)},
|
|
|
|
|
context.GetPlace())
|
|
|
|
|
: nullptr;
|
|
|
|
|
|
|
|
|
|
framework::LoD selected_lod(2);
|
|
|
|
|
selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
|
|
|
|
@ -396,7 +409,9 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
{static_cast<int64_t>(selected_lod[1].back()), 1});
|
|
|
|
|
selected_ids->Resize(final_selected_dims);
|
|
|
|
|
selected_scores->Resize(final_selected_dims);
|
|
|
|
|
parent_idx->Resize({static_cast<int64_t>(selected_lod[1].back())});
|
|
|
|
|
if (parent_idx) {
|
|
|
|
|
parent_idx->Resize({static_cast<int64_t>(selected_lod[1].back())});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|