|
|
|
@ -548,9 +548,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
|
|
|
|
|
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find convolution primitive in device context");
|
|
|
|
|
if (conv_p == nullptr) {
|
|
|
|
|
conv_p = std::make_shared<forward_t>(*conv_pd_, *(src_memory_p),
|
|
|
|
|
*(weights_memory_p.get()),
|
|
|
|
|
*(dst_memory_p.get()));
|
|
|
|
|
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
|
|
|
|
|
*weights_memory_p, *dst_memory_p);
|
|
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, conv_p);
|
|
|
|
|
} else {
|
|
|
|
@ -570,9 +569,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
|
|
|
|
|
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find convolution primitive in device context");
|
|
|
|
|
if (conv_p == nullptr) {
|
|
|
|
|
conv_p = std::make_shared<forward_t>(
|
|
|
|
|
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
|
|
|
|
|
*(bias_memory_p.get()), *(dst_memory_p.get()));
|
|
|
|
|
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
|
|
|
|
|
*weights_memory_p, *bias_memory_p,
|
|
|
|
|
*dst_memory_p);
|
|
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, conv_p);
|
|
|
|
|
} else {
|
|
|
|
|