@@ -319,6 +319,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
+    bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
     bool is_conv3d = strides.size() == 3U;
@@ -329,6 +331,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                       dilations[2] == 1
             : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
         "dilation in convolution is not implemented yet");
+    PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
     const T* input_data = input->data<T>();
@@ -340,15 +343,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    mkldnn::memory::data_type src_dt =
+        paddle::framework::ToMKLDNNDataType(input->type());
+    auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType(
+                                  framework::DataTypeTrait<uint8_t>::DataType)
+                            : paddle::framework::ToMKLDNNDataType(
+                                  framework::DataTypeTrait<int8_t>::DataType);
+    if (force_fp32_output) {
+      dst_dt = paddle::framework::ToMKLDNNDataType(
+          framework::DataTypeTrait<float>::DataType);
+    }
     // Get unique name for storing MKLDNN primitives
     std::string key;
     key.reserve(MaxKeyLength);
-    mkldnn::memory::data_type src_dt =
-        paddle::framework::ToMKLDNNDataType(input->type());
     platform::ConvMKLDNNHandler::AppendKey(
         &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
-        input->format(), ctx.op().Output("Output"));
+        input->format(), dst_dt, ctx.op().Output("Output"));
     const std::string key_conv_pd = key + "@conv_pd";
     std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
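Note: the lines added above pick the destination data type once, up front, so that it can also be folded into the primitive cache key. Below is a minimal standalone sketch of that decision, assuming the MKL-DNN 0.x C++ API from mkldnn.hpp; ChooseDstDataType is a hypothetical helper name, not part of this patch.

#include <mkldnn.hpp>

// Hypothetical helper mirroring the dst_dt selection in the hunk above.
mkldnn::memory::data_type ChooseDstDataType(bool fuse_relu,
                                            bool force_fp32_output) {
  using dt = mkldnn::memory::data_type;
  if (force_fp32_output) return dt::f32;  // dequantized float output
  return fuse_relu ? dt::u8 : dt::s8;     // u8 after fused ReLU, s8 otherwise
}

Since ReLU clamps negative values to zero, the fused output fits an unsigned 8-bit range, which is why fuse_relu switches the quantized output from s8 to u8.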
@@ -413,13 +425,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
         weights_tz, memory::data_type::s8, chosen_memory_format);
-    auto dst_dt = force_fp32_output
-                      ? paddle::framework::ToMKLDNNDataType(
-                            framework::DataTypeTrait<float>::DataType)
-                      : paddle::framework::ToMKLDNNDataType(
-                            framework::DataTypeTrait<int8_t>::DataType);
     auto dst_md =
         platform::MKLDNNMemDesc(dst_tz, dst_dt, chosen_memory_format);
     // create a conv primitive descriptor and save it for usage in backward
@@ -429,11 +434,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                                memory::format::x);
         conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
                                        strides, paddings, mkldnn_engine,
-                                       output_shift_scale, is_test);
+                                       fuse_relu, output_shift_scale, is_test);
       } else {
-        conv_pd =
-            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                                 mkldnn_engine, output_shift_scale, is_test);
+        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
+                                       paddings, mkldnn_engine, fuse_relu,
+                                       output_shift_scale, is_test);
       }
       // Save conv_pd/src_memory/weights_memory for backward pass
       dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -459,7 +464,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           mask_reorder);
       if (!force_fp32_output) {
-        dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
+        if (fuse_relu) {
+          dst_memory_p = platform::SetDstMemory<uint8_t>(ctx, output, handler);
+        } else {
+          dst_memory_p = platform::SetDstMemory<int8_t>(ctx, output, handler);
+        }
       } else {
         dst_memory_p = platform::SetDstMemory<float>(ctx, output, handler);
       }
@@ -518,8 +527,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                                       mkldnn_engine, key));
       }
       if (!force_fp32_output) {
-        dst_memory_p =
-            platform::SetDstMemoryHandler<int8_t>(ctx, output, handler);
+        if (fuse_relu) {
+          dst_memory_p =
+              platform::SetDstMemoryHandler<uint8_t>(ctx, output, handler);
+        } else {
+          dst_memory_p =
+              platform::SetDstMemoryHandler<int8_t>(ctx, output, handler);
+        }
       } else {
         dst_memory_p =
             platform::SetDstMemoryHandler<float>(ctx, output, handler);
@@ -563,11 +577,18 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
   mkldnn::primitive_attr CreatePostOps(
-      const std::vector<float> output_shift_scale) const {
+      bool fuse_relu, const std::vector<float> output_shift_scale) const {
     mkldnn::primitive_attr conv_attr;
     mkldnn::post_ops post_operations;
     int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
     conv_attr.set_output_scales(mask, output_shift_scale);
+    if (fuse_relu) {
+      constexpr float scale = 1.0f;
+      constexpr float negative_slope = 0.0f;
+      constexpr float placeholder = 1.0f;  // beta
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     negative_slope, placeholder);
+    }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
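Note: CreatePostOps above attaches both the output scales and the optional ReLU to the convolution's primitive_attr. Mask 0 applies one common output scale, while (1 << 1) marks dimension 1 of the destination (the output-channel axis of an NCHW tensor), i.e. per-output-channel scales. A minimal standalone sketch of the same construction, assuming the MKL-DNN 0.x C++ API; MakeConvAttr is a hypothetical name, not part of this patch.

#include <vector>

#include <mkldnn.hpp>

// Hypothetical helper building the same kind of attribute as CreatePostOps.
mkldnn::primitive_attr MakeConvAttr(const std::vector<float>& output_scales,
                                    bool fuse_relu) {
  mkldnn::primitive_attr attr;
  const int mask = output_scales.size() > 1 ? 1 << 1 : 0;  // per-channel vs. common
  attr.set_output_scales(mask, output_scales);
  mkldnn::post_ops ops;
  if (fuse_relu) {
    // eltwise ReLU post-op: scale, algorithm, alpha (negative slope), beta
    ops.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu, 0.0f, 0.0f);
  }
  attr.set_post_ops(ops);
  return attr;
}

With this attribute, the ReLU runs as an eltwise post-op inside the convolution primitive itself, so no separate activation primitive or extra pass over the output memory is needed.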
@@ -600,7 +621,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& dst, const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
+                       const mkldnn::engine& engine, const bool fuse_relu,
                        const std::vector<float> output_shift_scale,
                        bool is_test) const {
     memory::dims stride_dims = {strides[0], strides[1]};
@@ -613,7 +634,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
         padding_dims, padding_dims, mkldnn::padding_kind::zero);
-    mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale);
+    mkldnn::primitive_attr conv_attr =
+        CreatePostOps(fuse_relu, output_shift_scale);
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -652,7 +674,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const memory::desc& bias, const memory::desc& dst,
                        const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
+                       const mkldnn::engine& engine, const bool fuse_relu,
                        const std::vector<float> output_shift_scale,
                        bool is_test) const {
     memory::dims stride_dims = {strides[0], strides[1]};
@@ -665,7 +687,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         propagation, mkldnn::convolution_direct, src, weights, bias, dst,
         stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
-    mkldnn::primitive_attr conv_attr = CreatePostOps(output_shift_scale);
+    mkldnn::primitive_attr conv_attr =
+        CreatePostOps(fuse_relu, output_shift_scale);
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);