fix the GetKernelTypeForVar of input for fluid.gather (#28534)

TCChenlong-patch-1
wangchaochaohu 4 years ago committed by GitHub
parent 621b31c526
commit c52fe48f6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -69,7 +69,11 @@ class GatherOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
if (var_name == "Axis") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};

Loading…
Cancel
Save