|
|
|
@ -194,8 +194,65 @@ std::string ItemToString(const BeamSearch::Item& item);
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class BeamSearchKernel : public framework::OpKernel<T>{
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* ids_var = context.Input<framework::Tensor>("ids");
|
|
|
|
|
auto* scores_var = context.Input<framework::Tensor>("scores");
|
|
|
|
|
auto* pre_ids_var = context.Input<framework::Tensor>("pre_ids");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scores_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
|
|
|
|
|
|
|
|
|
|
auto& ids = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto& scores = scores_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
size_t level = Attr<int>("level");
|
|
|
|
|
size_t beam_size = Attr<int>("beam_size");
|
|
|
|
|
int end_id = Attr<int>("end_id");
|
|
|
|
|
BeamSearch alg(ids, scores, level, beam_size, end_id);
|
|
|
|
|
|
|
|
|
|
auto* selected_ids_var = context.Output<framework::Tensor>("selected_ids");
|
|
|
|
|
auto* selected_scores_var = context.Output<framework::Tensor>("selected_scores");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
|
|
|
|
|
auto& selected_ids_tensor =
|
|
|
|
|
*selected_ids_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto& selected_scores_tensor =
|
|
|
|
|
*selected_scores_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
void RunImpl(const framework::Scope& scope,
|
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
|
auto ids_var = scope.FindVar(Input("ids"));
|
|
|
|
|
auto scores_var = scope.FindVar(Input("scores"));
|
|
|
|
|
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scores_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
|
|
|
|
|
|
|
|
|
|
auto& ids = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto& scores = scores_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
size_t level = Attr<int>("level");
|
|
|
|
|
size_t beam_size = Attr<int>("beam_size");
|
|
|
|
|
int end_id = Attr<int>("end_id");
|
|
|
|
|
BeamSearch alg(ids, scores, level, beam_size, end_id);
|
|
|
|
|
|
|
|
|
|
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
|
|
|
|
|
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
|
|
|
|
|
auto& selected_ids_tensor =
|
|
|
|
|
*selected_ids_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto& selected_scores_tensor =
|
|
|
|
|
*selected_scores_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|