|
|
|
@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
|
|
|
|
|
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
|
|
|
|
|
std::vector<mkldnn::primitive>& pipeline, // NOLINT
|
|
|
|
|
bool is_persistent = false) {
|
|
|
|
|
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
|
|
|
|
|
auto weights_pd = conv_pd_->weights_primitive_desc();
|
|
|
|
|
return this->AcquireMemory(weights_pd, user_weights_pd,
|
|
|
|
|
user_weights_memory_p, "@weights_mem_p",
|
|
|
|
|
pipeline);
|
|
|
|
|
pipeline, is_persistent);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
|
|
|
|
@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use CPUPlace.");
|
|
|
|
|
|
|
|
|
|
const bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
|
|
|
|
|
const auto& mkldnn_engine = dev_ctx.GetEngine();
|
|
|
|
@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
auto src_memory_p =
|
|
|
|
|
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
|
|
|
|
|
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
|
|
|
|
|
user_weights_memory_p, pipeline);
|
|
|
|
|
user_weights_memory_p, pipeline, is_test);
|
|
|
|
|
auto dst_memory_p =
|
|
|
|
|
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
|
|
|
|
|
|
|
|
|
|