|
|
|
@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
|
|
|
|
|
return stream.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class BeamSearchOpMaker
|
|
|
|
|
: public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
@ -225,29 +224,15 @@ class BeamSearchOpMaker
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class BeamSearchOp : public framework::OperatorWithKernel {
|
|
|
|
|
/*
|
|
|
|
|
public:
|
|
|
|
|
BeamSearchOp(const std::string& type,
|
|
|
|
|
const framework::VariableNameMap& inputs,
|
|
|
|
|
const framework::VariableNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
BeamSearchOp(const BeamSearchOp& o)
|
|
|
|
|
: framework::OperatorWithKernel(
|
|
|
|
|
static_cast<const framework::OperatorBase&>(o)) {
|
|
|
|
|
PADDLE_THROW("Not Implemented");
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
for (const std::string &arg :
|
|
|
|
|
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(arg),
|
|
|
|
|
"BeamSearch need input argument '%s'", arg);
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'",
|
|
|
|
|
arg);
|
|
|
|
|
}
|
|
|
|
|
for (const std::string &arg :
|
|
|
|
|
std::vector<std::string>({"selected_ids", "selected_scores"})) {
|
|
|
|
@ -263,62 +248,13 @@ class BeamSearchOp : public framework::OperatorWithKernel {
|
|
|
|
|
framework::OpKernelType kt = framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(
|
|
|
|
|
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
|
|
|
|
|
platform::CPUPlace());
|
|
|
|
|
platform::CPUPlace());
|
|
|
|
|
std::cout << "Get Expected type 2\n";
|
|
|
|
|
// kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
|
|
|
|
|
// std::cout << "Get Expected type 3\n";
|
|
|
|
|
return kt;
|
|
|
|
|
}
|
|
|
|
|
/*
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const framework::Scope& scope,
|
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
|
auto ids_var = scope.FindVar(Input("ids"));
|
|
|
|
|
auto scores_var = scope.FindVar(Input("scores"));
|
|
|
|
|
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scores_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
|
|
|
|
|
|
|
|
|
|
auto& ids = ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto& scores = scores_var->Get<framework::LoDTensor>();
|
|
|
|
|
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
|
|
|
|
|
size_t level = Attr<int>("level");
|
|
|
|
|
size_t beam_size = Attr<int>("beam_size");
|
|
|
|
|
int end_id = Attr<int>("end_id");
|
|
|
|
|
BeamSearch alg(ids, scores, level, beam_size, end_id);
|
|
|
|
|
|
|
|
|
|
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
|
|
|
|
|
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
|
|
|
|
|
auto& selected_ids_tensor =
|
|
|
|
|
*selected_ids_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto& selected_scores_tensor =
|
|
|
|
|
*selected_scores_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
class BeamSearchInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
for (const std::string &arg :
|
|
|
|
|
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput(arg),
|
|
|
|
|
"BeamSearch need input argument '%s'", arg);
|
|
|
|
|
}
|
|
|
|
|
for (const std::string &arg :
|
|
|
|
|
std::vector<std::string>({"selected_ids", "selected_scores"})) {
|
|
|
|
|
PADDLE_ENFORCE(context->HasOutput(arg),
|
|
|
|
|
"BeamSearch need output argument '%s'", arg);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
*/
|
|
|
|
|
class BeamSearchInferVarType : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
@ -334,18 +270,15 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
/*
|
|
|
|
|
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
|
|
|
|
|
paddle::operators::BeamSearchProtoAndCheckerMaker,
|
|
|
|
|
paddle::operators::BeamSearchInferShape,
|
|
|
|
|
paddle::operators::BeamSearchInferVarType,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp,
|
|
|
|
|
ops::BeamSearchOpMaker,
|
|
|
|
|
ops::BeamSearchInferVarType);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker,
|
|
|
|
|
ops::BeamSearchInferVarType);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
beam_search,
|
|
|
|
|
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|