|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor_array.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
@ -87,7 +88,7 @@ struct BeamSearchDecoder {
|
|
|
|
|
*/
|
|
|
|
|
std::vector<BeamNodeVector<T>> PackTwoSteps(
|
|
|
|
|
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
|
|
|
|
|
std::vector<BeamNodeVector<T>>& prefixes_list,
|
|
|
|
|
std::vector<BeamNodeVector<T>>* prefixes_list,
|
|
|
|
|
std::vector<SentenceVector<T>>* sentence_vector_list) const;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -140,7 +141,7 @@ Sentence<T> BeamSearchDecoder<T>::MakeSentence(const BeamNode<T>* node) const {
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
|
|
|
|
|
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
|
|
|
|
|
std::vector<BeamNodeVector<T>>& prefixes_list,
|
|
|
|
|
std::vector<BeamNodeVector<T>>* prefixes_list,
|
|
|
|
|
std::vector<SentenceVector<T>>* sentence_vector_list) const {
|
|
|
|
|
std::vector<BeamNodeVector<T>> result;
|
|
|
|
|
|
|
|
|
@ -153,7 +154,7 @@ std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
|
|
|
|
|
|
|
|
|
|
// if prefixes size is 0, it means this is the first step. In this step,
|
|
|
|
|
// all candidate id is the start of candidate sentences.
|
|
|
|
|
if (prefixes_list.empty()) {
|
|
|
|
|
if (prefixes_list->empty()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(),
|
|
|
|
|
cur_ids.lod().at(kSentenceLevel).back(),
|
|
|
|
|
"in the first step");
|
|
|
|
@ -162,7 +163,7 @@ std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
|
|
|
|
|
cur_ids.data<int64_t>()[id_idx], cur_scores.data<T>()[id_idx])));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
BeamNodeVector<T>& prefixes = prefixes_list[src_idx];
|
|
|
|
|
BeamNodeVector<T>& prefixes = prefixes_list->at(src_idx);
|
|
|
|
|
SentenceVector<T>& sentence_vector = (*sentence_vector_list)[src_idx];
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(),
|
|
|
|
@ -262,7 +263,7 @@ void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
|
|
|
|
|
for (size_t step_id = 0; step_id < step_num; ++step_id) {
|
|
|
|
|
beamnode_vector_list =
|
|
|
|
|
PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id),
|
|
|
|
|
beamnode_vector_list, &sentence_vector_list);
|
|
|
|
|
&beamnode_vector_list, &sentence_vector_list);
|
|
|
|
|
}
|
|
|
|
|
// append last beam_node to result
|
|
|
|
|
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
|
|
|
|
|