|
|
|
@ -81,6 +81,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
framework::OpKernelType::kDefaultCustomizedTypeValue;
|
|
|
|
|
framework::LibraryType library{framework::LibraryType::kPlain};
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
auto input_data_type = ctx.Input<Tensor>("Input")->type();
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
framework::DataLayout layout = framework::StringToDataLayout(data_format);
|
|
|
|
|
|
|
|
|
@ -94,11 +95,14 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
customized_type_value = kConvMKLDNNFP32;
|
|
|
|
|
customized_type_value =
|
|
|
|
|
(input_data_type == framework::DataTypeTrait<int8_t>::DataType ||
|
|
|
|
|
input_data_type == framework::DataTypeTrait<uint8_t>::DataType)
|
|
|
|
|
? kConvMKLDNNINT8
|
|
|
|
|
: kConvMKLDNNFP32;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
auto input_data_type = ctx.Input<Tensor>("Input")->type();
|
|
|
|
|
if (input_data_type != framework::proto::VarType::INT8 &&
|
|
|
|
|
input_data_type != framework::proto::VarType::UINT8) {
|
|
|
|
|
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
|
|
|
|
|