|
|
|
@ -318,7 +318,7 @@ class BeamSearchInferShape : public framework::InferShapeBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
class BeamSearchInferVarType : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
@ -331,7 +331,7 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
/*
|
|
|
|
@ -343,7 +343,8 @@ REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
|
|
|
|
|
*/
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp,
|
|
|
|
|
ops::BeamSearchOpMaker);
|
|
|
|
|
ops::BeamSearchOpMaker,
|
|
|
|
|
ops::BeamSearchInferVarType);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
beam_search,
|
|
|
|
|
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|