|
|
|
@ -400,6 +400,93 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
std::vector<int> logical_axis_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ReorderMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
public:
|
|
|
|
|
ReorderMKLDNNHandler(std::vector<int>& dims, // NOLINT
|
|
|
|
|
framework::proto::VarType::Type vtype,
|
|
|
|
|
mkldnn::memory::data_type dtype,
|
|
|
|
|
const platform::MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
mkldnn::engine engine, const std::string& base_key)
|
|
|
|
|
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
|
|
|
|
|
dims_(dims),
|
|
|
|
|
vtype_(vtype),
|
|
|
|
|
dtype_(dtype) {}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
|
|
|
|
|
const mkldnn::memory::format& fmt, void* ptr) {
|
|
|
|
|
auto local_key = key_ + "@user_src_mem_p";
|
|
|
|
|
auto mem_p =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
|
|
|
|
|
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
" find mem primitive in device context");
|
|
|
|
|
if (mem_p == nullptr) {
|
|
|
|
|
auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
|
|
|
|
|
mem_p = std::make_shared<mkldnn::memory>(
|
|
|
|
|
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
|
|
|
|
|
dev_ctx_.SetBlob(local_key, mem_p);
|
|
|
|
|
} else {
|
|
|
|
|
mem_p->set_data_handle(ptr);
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
return mem_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
|
|
|
|
|
framework::Tensor* output, const mkldnn::memory::format& fmt,
|
|
|
|
|
platform::Place place) {
|
|
|
|
|
auto local_key = key_ + "@user_dst_mem_p";
|
|
|
|
|
auto mem_p =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
|
|
|
|
|
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
" find mem primitive in device context");
|
|
|
|
|
if (mem_p == nullptr) {
|
|
|
|
|
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
|
|
|
|
|
auto dst_mdp = mkldnn::memory::primitive_desc{dst_md, engine_};
|
|
|
|
|
|
|
|
|
|
auto dst_data = output->mutable_data(place, vtype_);
|
|
|
|
|
|
|
|
|
|
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
|
|
|
|
|
dev_ctx_.SetBlob(local_key, mem_p);
|
|
|
|
|
} else {
|
|
|
|
|
auto dst_data = output->mutable_data(place, vtype_);
|
|
|
|
|
mem_p->set_data_handle(dst_data);
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
return mem_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::reorder> AcquireReorder(
|
|
|
|
|
std::shared_ptr<mkldnn::memory> dst_memory_p,
|
|
|
|
|
std::shared_ptr<mkldnn::memory> src_memory_p) {
|
|
|
|
|
auto prim_key = key_ + "@reorder_p";
|
|
|
|
|
auto reorder_p =
|
|
|
|
|
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
PADDLE_ENFORCE((reorder_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find convolution primitive in device context");
|
|
|
|
|
if (reorder_p == nullptr) {
|
|
|
|
|
reorder_p =
|
|
|
|
|
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, reorder_p);
|
|
|
|
|
} else {
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
return reorder_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string GetHash(std::vector<int>& shape, // NOLINT
|
|
|
|
|
mkldnn::memory::format in_fmt,
|
|
|
|
|
mkldnn::memory::format out_fmt,
|
|
|
|
|
const std::string& suffix) {
|
|
|
|
|
return dims2str(shape) + std::to_string(in_fmt) + "->" +
|
|
|
|
|
std::to_string(out_fmt) + "#" + suffix;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<int> dims_;
|
|
|
|
|
framework::proto::VarType::Type vtype_;
|
|
|
|
|
mkldnn::memory::data_type dtype_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct convolutional_algorithm;
|
|
|
|
|
|
|
|
|
|