|
|
|
@ -196,9 +196,9 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class BeamSearchOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto ids_var = context.Input<framework::LoDTensor>("ids");
|
|
|
|
|
auto scores_var = context.Input<framework::LoDTensor>("scores");
|
|
|
|
|
auto pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
|
|
|
|
|
auto* ids_var = context.Input<framework::LoDTensor>("ids");
|
|
|
|
|
auto* scores_var = context.Input<framework::LoDTensor>("scores");
|
|
|
|
|
auto* pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ids_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scores_var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
|
|
|
|
|