|
|
|
@ -260,10 +260,13 @@ class BeamSearchOp : public framework::OperatorWithKernel {
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
std::cout << "Get Expected type 1\n";
|
|
|
|
|
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
|
|
|
|
|
framework::OpKernelType kt = framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(
|
|
|
|
|
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
|
|
|
|
|
platform::CPUPlace());
|
|
|
|
|
std::cout << "Get Expected type 2\n";
|
|
|
|
|
kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
|
|
|
|
|
std::cout << "Get Expected type 3\n";
|
|
|
|
|
// kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
|
|
|
|
|
// std::cout << "Get Expected type 3\n";
|
|
|
|
|
return kt;
|
|
|
|
|
}
|
|
|
|
|
/*
|
|
|
|
|