|
|
|
@ -210,6 +210,73 @@ class MKLDNNHandlerT {
|
|
|
|
|
return mem_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Executes a reorder that copies *user_memory_p into *target_memory_p.
//
// The reorder primitive is looked up in the device context under
// key_ + suffix + "reorder_p". If no primitive is stored there yet, one is
// built from the two memory objects and stored under that key. The primitive
// is then executed on a newly constructed stream and the call blocks until
// the stream has finished.
//
// Both pointers are dereferenced unconditionally, so they must be non-null.
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
                    const std::shared_ptr<mkldnn::memory>& target_memory_p,
                    const std::string& suffix) {
  const auto reorder_key = key_ + suffix + "reorder_p";

  // Try the stored primitive first; build and store it on a miss.
  auto stored_blob = dev_ctx_.GetBlob(reorder_key);
  auto reorder = std::static_pointer_cast<mkldnn::reorder>(stored_blob);
  if (!reorder) {
    reorder = std::make_shared<mkldnn::reorder>(*user_memory_p,
                                                *target_memory_p);
    dev_ctx_.SetBlob(reorder_key, reorder);
  }

  // Run the reorder and wait for it to complete before returning.
  mkldnn::stream astream(engine_);
  reorder->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
                             {MKLDNN_ARG_TO, *target_memory_p}});
  astream.wait();
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
|
|
|
|
|
const mkldnn::memory::desc& user_md,
|
|
|
|
|
const mkldnn::memory::desc& target_md, void* ptr,
|
|
|
|
|
const std::string& suffix, bool is_persistent = false) {
|
|
|
|
|
const auto target_key = key_ + suffix + "_target";
|
|
|
|
|
const auto key_reorder_p = key_ + suffix + "reorder_p";
|
|
|
|
|
const auto user_key = key_ + suffix + "_user";
|
|
|
|
|
|
|
|
|
|
auto target_memory_p =
|
|
|
|
|
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
|
|
|
|
|
|
|
|
|
|
if (target_memory_p == nullptr) {
|
|
|
|
|
auto user_memory_p =
|
|
|
|
|
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
|
|
|
|
|
if (user_md != target_md) {
|
|
|
|
|
target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
|
|
|
|
|
auto reorder_p =
|
|
|
|
|
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
|
|
|
|
|
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
|
|
|
|
|
|
|
|
|
|
mkldnn::stream astream(engine_);
|
|
|
|
|
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
|
|
|
|
|
{MKLDNN_ARG_TO, *target_memory_p}});
|
|
|
|
|
astream.wait();
|
|
|
|
|
} else {
|
|
|
|
|
target_memory_p = user_memory_p;
|
|
|
|
|
}
|
|
|
|
|
dev_ctx_.SetBlob(user_key, user_memory_p);
|
|
|
|
|
dev_ctx_.SetBlob(target_key, target_memory_p);
|
|
|
|
|
} else if (!is_persistent) {
|
|
|
|
|
mkldnn::stream astream(engine_);
|
|
|
|
|
|
|
|
|
|
auto user_memory_p =
|
|
|
|
|
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
|
|
|
|
|
user_memory_p->set_data_handle(ptr);
|
|
|
|
|
|
|
|
|
|
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
|
|
|
|
|
dev_ctx_.GetBlob(key_reorder_p));
|
|
|
|
|
if (reorder_p != nullptr) {
|
|
|
|
|
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
|
|
|
|
|
{MKLDNN_ARG_TO, *target_memory_p}});
|
|
|
|
|
astream.wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return target_memory_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
|
|
|
|
|
const auto local_key = key_ + suffix;
|
|
|
|
|
return std::static_pointer_cast<mkldnn::memory>(
|
|
|
|
|