|
|
|
@ -484,9 +484,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
handler.reset(
|
|
|
|
|
new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
|
|
|
|
|
// create a conv primitive descriptor and save it for usage in backward
|
|
|
|
|
// TODO(lidanqing): We use relu post-op instead of brelu post-op cause
|
|
|
|
|
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
|
|
|
|
|
// v0.20 is enabled
|
|
|
|
|
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
|
|
|
|
|
: mkldnn::prop_kind::forward_training;
|
|
|
|
|
|
|
|
|
@ -496,15 +493,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
mkldnn::memory::format::x);
|
|
|
|
|
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
|
|
|
|
|
src_md, weights_md, bias_md, dst_md, strides, paddings,
|
|
|
|
|
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
|
|
|
|
|
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
|
|
|
|
|
propagation, output_shift_scale, sum_scale);
|
|
|
|
|
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
|
|
|
|
|
fuse_brelu_threshold, propagation, output_shift_scale, sum_scale);
|
|
|
|
|
} else {
|
|
|
|
|
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
|
|
|
|
|
src_md, weights_md, boost::none, dst_md, strides, paddings,
|
|
|
|
|
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
|
|
|
|
|
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
|
|
|
|
|
propagation, output_shift_scale, sum_scale);
|
|
|
|
|
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
|
|
|
|
|
fuse_brelu_threshold, propagation, output_shift_scale, sum_scale);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// create mkldnn memory from input tensors (data/weights)
|
|
|
|
|