@@ -72,7 +72,7 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
   return dst_dt;
 }
 
-template <typename T>
+template <typename T, typename K, typename T_out>
 class ConvMKLDNNHandlerT
     : public platform::MKLDNNHandlerT<T, mkldnn::convolution_forward> {
  public:
@@ -227,7 +227,7 @@ class ConvMKLDNNHandlerT
           platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType<T>(),
                                   MKLDNNMemoryFormat::any);
       const auto dst_md = platform::MKLDNNMemDesc(
-          dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+          dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
 
       const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
                                          : mkldnn::prop_kind::forward_training;
@@ -313,29 +313,29 @@ class ConvMKLDNNHandlerT
     if (is_test && weights_mem_p) {
       return weights_mem_p;
     } else {
-      const T* filter_data = filter->data<T>();
+      const K* filter_data = filter->data<K>();
       auto weights_tz = framework::vectorize(filter->dims());
       GetWeightsTz(weights_tz, groups);
 
       auto user_src_md = platform::MKLDNNMemDesc(
-          weights_tz, platform::MKLDNNGetDataType<T>(),
+          weights_tz, platform::MKLDNNGetDataType<K>(),
           GetWeightsFormat(filter->format(), groups, is_conv3d));
 
       return this->AcquireMemoryWithReorder(
           user_src_md, this->fwd_pd_->weights_desc(),
-          to_void_cast<T>(filter_data), "@weights_mem_p", is_test);
+          to_void_cast<K>(filter_data), "@weights_mem_p", is_test);
     }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
       const framework::Tensor* bias, const bool is_test) {
-    const T* bias_data = bias->data<T>();
+    const K* bias_data = bias->data<K>();
     auto user_bias_md = platform::MKLDNNMemDesc(
-        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<T>(),
+        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
         MKLDNNMemoryFormat::x);
 
     return this->AcquireMemoryWithReorder(
-        user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<T>(bias_data),
+        user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
         "@bias_mem_p", is_test);
   }
 
@@ -358,14 +358,14 @@ class ConvMKLDNNHandlerT
     if (residual_param->format() !=
         platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
       auto residual_memory_p = this->AcquireResidualMemory(residual_param);
-      dst_memory_p = this->AcquireDstMemory(output);
+      dst_memory_p = this->template AcquireDstMemory<T_out>(output);
       this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
     } else {
       // Changing ShareDataWith to TensorCopy results in performance drop
       // on ResNet architectures
       // (https://github.com/PaddlePaddle/Paddle/issues/22964)
       output->ShareDataWith(*residual_param);
-      dst_memory_p = this->AcquireDstMemory(output);
+      dst_memory_p = this->template AcquireDstMemory<T_out>(output);
     }
     return dst_memory_p;
   }
@@ -381,7 +381,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     bool is_INT8 =
         std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
     if (!is_INT8) {
-      ComputeFP32(ctx);
+      ComputeFP32<float>(ctx);
     } else {
       std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
       bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
@@ -399,6 +399,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
   }
 
+  template <typename T_out>
   void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
     auto& dev_ctx =
         ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
@@ -414,7 +415,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
     auto* output = ctx.Output<Tensor>("Output");
 
-    ConvMKLDNNHandlerT<T> handler(
+    ConvMKLDNNHandlerT<T, K, T_out> handler(
         ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
         output, ctx.InputName("Input") + ctx.InputName("Filter"));
 
@@ -429,7 +430,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       dst_memory_p =
           handler.AcquireDstMemoryWithResidual(output, residual_param);
     } else {
-      dst_memory_p = handler.AcquireDstMemory(output);
+      dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
     }
 
     auto conv_p = handler.AcquireForwardPrimitive();