|
|
|
@ -23,6 +23,8 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -196,31 +198,47 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class BeamSearchOpKernel : public framework::OpKernel<T>{
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
std::cout << "Compute 1\n";
|
|
|
|
|
auto ids_var = context.Input<framework::LoDTensor>("ids");
|
|
|
|
|
std::cout << "Compute 2\n";
|
|
|
|
|
auto scores_var = context.Input<framework::LoDTensor>("scores");
|
|
|
|
|
std::cout << "Compute 3\n";
|
|
|
|
|
auto pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
|
|
|
|
|
std::cout << "Compute 4\n";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ids_var);
|
|
|
|
|
std::cout << "Compute 5\n";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scores_var);
|
|
|
|
|
std::cout << "Compute 6\n";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
|
|
|
|
|
|
|
|
|
|
//auto& ids = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
//auto& scores = scores_var->Get<framework::LoDTensor>();
|
|
|
|
|
//auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
std::cout << "Compute 7\n";
|
|
|
|
|
// auto& ids = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
// auto& scores = scores_var->Get<framework::LoDTensor>();
|
|
|
|
|
// auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
size_t level = context.Attr<int>("level");
|
|
|
|
|
std::cout << "Compute 8\n";
|
|
|
|
|
size_t beam_size = context.Attr<int>("beam_size");
|
|
|
|
|
std::cout << "Compute 9\n";
|
|
|
|
|
int end_id = context.Attr<int>("end_id");
|
|
|
|
|
std::cout << "Compute 10\n";
|
|
|
|
|
BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id);
|
|
|
|
|
|
|
|
|
|
auto selected_ids_var = context.Output<framework::LoDTensor>("selected_ids");
|
|
|
|
|
auto selected_scores_var = context.Output<framework::LoDTensor>("selected_scores");
|
|
|
|
|
std::cout << "Compute 11\n";
|
|
|
|
|
auto selected_ids_var =
|
|
|
|
|
context.Output<framework::LoDTensor>("selected_ids");
|
|
|
|
|
std::cout << "Compute 12\n";
|
|
|
|
|
auto selected_scores_var =
|
|
|
|
|
context.Output<framework::LoDTensor>("selected_scores");
|
|
|
|
|
std::cout << "Compute 13\n";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
|
|
|
|
|
std::cout << "Compute 14\n";
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
|
|
|
|
|
//auto& selected_ids_tensor =
|
|
|
|
|
std::cout << "Compute 15\n";
|
|
|
|
|
// auto& selected_ids_tensor =
|
|
|
|
|
// *selected_ids_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
//auto& selected_scores_tensor =
|
|
|
|
|
// auto& selected_scores_tensor =
|
|
|
|
|
// *selected_scores_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
alg(*pre_ids_var, selected_ids_var, selected_scores_var);
|
|
|
|
|
std::cout << "Compute 16\n";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|