|
|
|
@ -290,13 +290,25 @@ class ConvMKLDNNHandlerT
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
|
|
|
|
|
const framework::Tensor* input) {
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
auto user_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
input->format());
|
|
|
|
|
const std::string user_key_suffix{"@src_mem_p_user"};
|
|
|
|
|
auto user_src_mem_p = this->AcquireMemory(user_key_suffix);
|
|
|
|
|
|
|
|
|
|
return this->AcquireMemoryWithReorder(
|
|
|
|
|
user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
|
|
|
|
|
"@src_mem_p");
|
|
|
|
|
if (!user_src_mem_p) {
|
|
|
|
|
auto user_src_md = platform::MKLDNNMemDesc(
|
|
|
|
|
framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
|
|
|
|
|
input->format());
|
|
|
|
|
return this->AcquireMemoryWithReorder(
|
|
|
|
|
user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
|
|
|
|
|
"@src_mem_p");
|
|
|
|
|
} else {
|
|
|
|
|
const std::string target_key_suffix{"@src_mem_p_target"};
|
|
|
|
|
const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
|
|
|
|
|
user_src_mem_p->set_data_handle(to_void_cast<T>(input_data));
|
|
|
|
|
if (user_src_mem_p != target_src_mem_p) {
|
|
|
|
|
this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
|
|
|
|
|
}
|
|
|
|
|
return target_src_mem_p;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
|
|
|
|
@ -324,14 +336,19 @@ class ConvMKLDNNHandlerT
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
|
|
|
|
|
const framework::Tensor* bias, const bool is_test) {
|
|
|
|
|
const K* bias_data = bias->data<K>();
|
|
|
|
|
auto user_bias_md = platform::MKLDNNMemDesc(
|
|
|
|
|
framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
|
|
|
|
|
MKLDNNMemoryFormat::x);
|
|
|
|
|
|
|
|
|
|
return this->AcquireMemoryWithReorder(
|
|
|
|
|
user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
|
|
|
|
|
"@bias_mem_p", is_test);
|
|
|
|
|
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
|
|
|
|
|
if (is_test && bias_mem_p) {
|
|
|
|
|
return bias_mem_p;
|
|
|
|
|
} else {
|
|
|
|
|
const K* bias_data = bias->data<K>();
|
|
|
|
|
auto user_bias_md = platform::MKLDNNMemDesc(
|
|
|
|
|
framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
|
|
|
|
|
MKLDNNMemoryFormat::x);
|
|
|
|
|
|
|
|
|
|
return this->AcquireMemoryWithReorder(
|
|
|
|
|
user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
|
|
|
|
|
"@bias_mem_p", is_test);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
|
|
|
|
@ -340,13 +357,19 @@ class ConvMKLDNNHandlerT
|
|
|
|
|
residual_param->type() == framework::DataTypeTrait<T_out>::DataType()
|
|
|
|
|
? to_void_cast<T_out>(residual_param->data<T_out>())
|
|
|
|
|
: to_void_cast<T>(residual_param->data<T>());
|
|
|
|
|
auto user_residual_md = platform::MKLDNNMemDesc(
|
|
|
|
|
framework::vectorize(residual_param->dims()),
|
|
|
|
|
framework::ToMKLDNNDataType(residual_param->type()),
|
|
|
|
|
residual_param->format());
|
|
|
|
|
auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
|
|
|
|
|
if (residual_mem_p) {
|
|
|
|
|
residual_mem_p->set_data_handle(residual_data);
|
|
|
|
|
return residual_mem_p;
|
|
|
|
|
} else {
|
|
|
|
|
auto user_residual_md = platform::MKLDNNMemDesc(
|
|
|
|
|
framework::vectorize(residual_param->dims()),
|
|
|
|
|
framework::ToMKLDNNDataType(residual_param->type()),
|
|
|
|
|
residual_param->format());
|
|
|
|
|
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
|
|
|
|
|
"@user_residual_data_mem_p");
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
|
|
|
|
|
"@user_residual_data_mem_p");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(
|
|
|
|
|