|
|
|
@ -175,9 +175,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
input_data_type != framework::proto::VarType::UINT8 &&
|
|
|
|
|
input_data_type != framework::proto::VarType::BF16) {
|
|
|
|
|
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input and filter data type should be consistent"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_data_type, filter_data_type,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input and filter data type should be consistent, "
|
|
|
|
|
"but received input data type is %s and filter type "
|
|
|
|
|
"is %s",
|
|
|
|
|
paddle::framework::DataTypeToString(input_data_type),
|
|
|
|
|
paddle::framework::DataTypeToString(filter_data_type)));
|
|
|
|
|
}
|
|
|
|
|
if (input_data_type == framework::proto::VarType::FP16) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
|
|
|
|
|