|
|
|
@ -18,22 +18,6 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
static inline framework::OpKernelType ExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto* table_var = ctx.InputVar("W");
|
|
|
|
|
if (table_var->IsType<LoDTensor>()) {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(table_var->Get<LoDTensor>().type()),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
} else if (table_var->IsType<SelectedRows>()) {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("W should be LoDTensor or SelectedRows");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class LookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return ExpectedKernelType(ctx);
|
|
|
|
|
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
|
|
|
|
|
return framework::OpKernelType(data_type, ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return ExpectedKernelType(ctx);
|
|
|
|
|
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
|
|
|
|
|
return framework::OpKernelType(data_type, ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|