|
|
|
@ -43,16 +43,9 @@ class MKLDNNHandlerT {
|
|
|
|
|
engine_(engine),
|
|
|
|
|
place_(cpu_place),
|
|
|
|
|
key_common_(base_key),
|
|
|
|
|
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)),
|
|
|
|
|
fwd_pd_(nullptr),
|
|
|
|
|
bwd_pd_(nullptr) {
|
|
|
|
|
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
|
|
|
|
|
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
|
|
|
|
|
key_ = key_common_;
|
|
|
|
|
} else {
|
|
|
|
|
key_ = key_common_ + "-t:" + ThreadIDasStr();
|
|
|
|
|
}
|
|
|
|
|
key_ += dev_ctx.GetKeySuffix();
|
|
|
|
|
}
|
|
|
|
|
bwd_pd_(nullptr) {}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<TForward> AcquireForwardPrimitive() {
|
|
|
|
|
const std::string key_p = key_ + "@fwd_p";
|
|
|
|
@ -306,8 +299,8 @@ class MKLDNNHandlerT {
|
|
|
|
|
const MKLDNNDeviceContext& dev_ctx_;
|
|
|
|
|
mkldnn::engine engine_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
std::string key_;
|
|
|
|
|
std::string key_common_;
|
|
|
|
|
std::string key_;
|
|
|
|
|
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
|
|
|
|
|
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
|
|
|
|
|
};
|
|
|
|
@ -317,15 +310,10 @@ class MKLDNNHandler {
|
|
|
|
|
public:
|
|
|
|
|
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
|
|
|
|
|
const std::string& base_key)
|
|
|
|
|
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
|
|
|
|
|
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
|
|
|
|
|
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
|
|
|
|
|
key_ = key_common_;
|
|
|
|
|
} else {
|
|
|
|
|
key_ = key_common_ + "-t:" + ThreadIDasStr();
|
|
|
|
|
}
|
|
|
|
|
key_ += dev_ctx.GetKeySuffix();
|
|
|
|
|
}
|
|
|
|
|
: dev_ctx_(dev_ctx),
|
|
|
|
|
engine_(engine),
|
|
|
|
|
key_common_(base_key),
|
|
|
|
|
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
|
|
|
|
|
const mkldnn::memory::desc& md, void* ptr) {
|
|
|
|
@ -508,8 +496,8 @@ class MKLDNNHandler {
|
|
|
|
|
protected:
|
|
|
|
|
const MKLDNNDeviceContext& dev_ctx_;
|
|
|
|
|
mkldnn::engine engine_;
|
|
|
|
|
std::string key_;
|
|
|
|
|
std::string key_common_;
|
|
|
|
|
std::string key_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
|
|
|
|
|
: platform::MKLDNNHandlerT<T, dnnl::binary>(
|
|
|
|
|
dev_ctx, engine, cpu_place,
|
|
|
|
|
platform::CreateKey(
|
|
|
|
|
framework::vectorize(x->dims()),
|
|
|
|
|
dev_ctx, framework::vectorize(x->dims()),
|
|
|
|
|
uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) {
|
|
|
|
|
// bradcasting combined with in-place may require
|
|
|
|
|
auto rankdiff = x->dims().size() - y->dims().size();
|
|
|
|
@ -627,7 +615,7 @@ class ActivationMKLDNNHandler
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
|
|
|
|
|
mkldnn::eltwise_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(dims, "a", algorithm, unique_name)) {
|
|
|
|
|
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
|
|
|
|
|
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
|
|
|
|
|
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
|
|
|
|
@ -645,7 +633,7 @@ class ActivationMKLDNNHandler
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
|
|
|
|
|
mkldnn::eltwise_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(dims, "a", algorithm, unique_name)) {
|
|
|
|
|
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
|
|
|
|
|
auto src_md =
|
|
|
|
@ -676,7 +664,7 @@ class LRNMKLDNNHandler
|
|
|
|
|
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
|
|
|
|
|
dev_ctx, mkldnn_engine, cpu_place,
|
|
|
|
|
platform::CreateKey(framework::vectorize(input->dims()),
|
|
|
|
|
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
|
|
|
|
|
unique_name)) {
|
|
|
|
|
if (!this->isCached()) {
|
|
|
|
|
const int n = ctx.Attr<int>("n");
|
|
|
|
@ -712,7 +700,7 @@ class LRNMKLDNNHandler
|
|
|
|
|
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(dims, unique_name)) {
|
|
|
|
|
platform::CreateKey(dev_ctx, dims, unique_name)) {
|
|
|
|
|
auto src_md =
|
|
|
|
|
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
auto diff_md =
|
|
|
|
@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
|
|
|
|
|
mkldnn::pooling_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(framework::vectorize(input->dims()),
|
|
|
|
|
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
|
|
|
|
|
framework::ToMKLDNNDataType(input->type()),
|
|
|
|
|
unique_name)) {
|
|
|
|
|
if (!this->isCached()) {
|
|
|
|
@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
|
|
|
|
|
mkldnn::pooling_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(diff_src_dims, dt, unique_name)) {
|
|
|
|
|
platform::CreateKey(dev_ctx, diff_src_dims, dt, unique_name)) {
|
|
|
|
|
auto diff_dst_md = mkldnn::memory::desc(
|
|
|
|
|
diff_dst_dims, platform::MKLDNNGetDataType<T>(), diff_dst_fmt);
|
|
|
|
|
auto diff_src_md =
|
|
|
|
|