|
|
@ -69,7 +69,11 @@ class GatherOp : public framework::OperatorWithKernel {
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
const std::string& var_name, const framework::Tensor& tensor,
|
|
|
|
const std::string& var_name, const framework::Tensor& tensor,
|
|
|
|
const framework::OpKernelType& expected_kernel_type) const override {
|
|
|
|
const framework::OpKernelType& expected_kernel_type) const override {
|
|
|
|
return expected_kernel_type;
|
|
|
|
if (var_name == "Axis") {
|
|
|
|
|
|
|
|
return expected_kernel_type;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
|
|
|
tensor.place(), tensor.layout());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|