|
|
|
@ -115,7 +115,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
|
|
|
|
|
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Out"));
|
|
|
|
|
return framework::OpKernelType(data_type, ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|