Avoid data transforming ShapeTensor from CPU to GPU in fill_constant op. (#25267)

fix_copy_if_different
Yiqun Liu 5 years ago committed by GitHub
parent 5e8e6dad72
commit c00f827843
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -51,6 +51,17 @@ class FillConstantOp : public framework::OperatorWithKernel {
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") {
return expected_kernel_type;
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(

Loading…
Cancel
Save