You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							524 lines
						
					
					
						
							18 KiB
						
					
					
				
			
		
		
	
	
							524 lines
						
					
					
						
							18 KiB
						
					
					
				/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License. */
 | 
						|
#pragma once
 | 
						|
 | 
						|
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "mkldnn.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
 | 
						|
namespace paddle {
 | 
						|
// Short name for the MKL-DNN memory format tag. It is declared only when the
// build defines PADDLE_WITH_MKLDNN.
#if defined(PADDLE_WITH_MKLDNN)
using MKLDNNMemoryFormat = mkldnn::memory::format_tag;
#endif  // PADDLE_WITH_MKLDNN
 | 
						|
namespace platform {
 | 
						|
 | 
						|
using MKLDNNStream = mkldnn::stream;
 | 
						|
using MKLDNNEngine = mkldnn::engine;
 | 
						|
using MKLDNNMemory = mkldnn::memory;
 | 
						|
using MKLDNNMemoryDescriptor = mkldnn::memory::desc;
 | 
						|
using MKLDNNPrimitive = mkldnn::primitive;
 | 
						|
using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
 | 
						|
 | 
						|
typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
 | 
						|
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
 | 
						|
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
 | 
						|
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
 | 
						|
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
 | 
						|
 | 
						|
// Returns `t` as an untyped, non-const pointer. Constness is removed with
// const_cast and the result is converted to void* with static_cast.
template <typename Type>
void* to_void_cast(const Type* t) {
  Type* mutable_ptr = const_cast<Type*>(t);
  return static_cast<void*>(mutable_ptr);
}
 | 
						|
 | 
						|
// Same as to_void_cast, except that the conversion to void* is done with
// reinterpret_cast instead of static_cast.
template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  Type* mutable_ptr = const_cast<Type*>(t);
  return reinterpret_cast<void*>(mutable_ptr);
}
 | 
						|
 | 
						|
// Shorthand for the nested `desc` type of a primitive type.
template <typename Primitive>
using tf_desc = typename Primitive::desc;

// Shorthand for the nested `primitive_desc` type of a primitive type.
template <typename Primitive>
using tf_pd = typename Primitive::primitive_desc;
 | 
						|
 | 
						|
template <typename Type, typename Engine, typename... Args>
 | 
						|
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
 | 
						|
                                                    Args&&... args) {
 | 
						|
  auto desc = tf_desc<Type>(mkldnn::prop_kind::forward, (args)...);
 | 
						|
  auto pd = new tf_pd<Type>(desc, e);
 | 
						|
  return std::shared_ptr<tf_pd<Type>>(pd);
 | 
						|
}
 | 
						|
 | 
						|
template <typename Type, typename Engine, typename Primitive, typename... Args>
 | 
						|
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
 | 
						|
                                   Args&&... args) {
 | 
						|
  auto desc = tf_desc<Type>(args...);
 | 
						|
  return tf_pd<Type>(desc, e, p);
 | 
						|
}
 | 
						|
 | 
						|
// Rotates the shape of `tensor_in` in place when switching between the
// kMKLDNN and kNHWC data layouts.
//
//   kMKLDNN -> kNHWC : the 2nd dimension is moved to the last position.
//   kNHWC -> kMKLDNN : the last dimension is moved to the 2nd position.
//
// Every other (from, to) combination leaves the tensor untouched.
inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // In these layouts the channel dimension sits either in 2nd position (nChw)
  // or last (nhwC). With two dimensions both layouts coincide, and with one
  // dimension there is only one possible arrangement, so nothing is done.
  if (tensor_in->dims().size() < 3) {
    return;
  }

  // Formats a shape as "[d0,d1,...,dn]"; an empty shape gives "".
  auto print_dims = [](const std::vector<int>& dims) -> std::string {
    std::ostringstream oss;
    if (dims.empty()) {
      return oss.str();
    }
    oss << "[";
    // Every element but the last is followed by a comma.
    for (size_t i = 0; i + 1 < dims.size(); ++i) {
      oss << dims[i] << ",";
    }
    oss << dims.back() << "]";
    return oss.str();
  };

  const bool mkldnn_to_nhwc = from == framework::DataLayout::kMKLDNN &&
                              to == framework::DataLayout::kNHWC;
  const bool nhwc_to_mkldnn = from == framework::DataLayout::kNHWC &&
                              to == framework::DataLayout::kMKLDNN;

  if (mkldnn_to_nhwc) {
    auto shape = framework::vectorize<int>(tensor_in->dims());
    std::rotate(shape.begin() + 1, shape.begin() + 2, shape.end());
    tensor_in->Resize(framework::make_ddim(shape));
    VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
            << print_dims(shape);
  } else if (nhwc_to_mkldnn) {
    auto shape = framework::vectorize<int>(tensor_in->dims());
    std::rotate(shape.begin() + 1, shape.end() - 1, shape.end());
    tensor_in->Resize(framework::make_ddim(shape));
    VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
            << print_dims(shape);
  }
}
 | 
						|
 | 
						|
// Placeholder primitive type that only provides empty nested `desc` and
// `primitive_desc` types.
struct mkldnn_dummy_primitive {
  struct desc {};
  struct primitive_desc {};
};
 | 
						|
 | 
						|
// Builds an MKL-DNN memory descriptor from a shape, an element data type and a
// memory format tag.
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                          mkldnn::memory::data_type data_type,
                                          MKLDNNMemoryFormat format) {
  mkldnn::memory::desc mem_desc({dims}, data_type, format);
  return mem_desc;
}
 | 
						|
 | 
						|
// Resets the blob map of the MKL-DNN device context that belongs to `place`
// and sets the thread-local current Paddle data layout back to kNCHW.
// Does nothing when `place` is not a CPU place.
inline void ClearMKLDNNCache(const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  auto& ctx_pool = platform::DeviceContextPool::Instance();
  // NOTE(review): the pool returns a generic device context; for a CPU place
  // it is treated as an MKLDNNDeviceContext here without a runtime check.
  auto* mkldnn_ctx = (platform::MKLDNNDeviceContext*)ctx_pool.Get(place);
  mkldnn_ctx->ResetBlobMap();
  platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
      paddle::framework::DataLayout::kNCHW);
}
 | 
						|
 | 
						|
// Calls BlockNextCacheClearing() on the MKL-DNN device context that belongs to
// `place`. Does nothing when `place` is not a CPU place.
// NOTE(review): presumably this makes the next cache-clearing request a
// no-op -- confirm against MKLDNNDeviceContext.
inline void DontClearMKLDNNCache(const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  auto& ctx_pool = platform::DeviceContextPool::Instance();
  auto* mkldnn_ctx = (platform::MKLDNNDeviceContext*)ctx_pool.Get(place);
  mkldnn_ctx->BlockNextCacheClearing();
}
 | 
						|
 | 
						|
template <typename Type>
 | 
						|
mkldnn::memory::data_type MKLDNNGetDataType() {
 | 
						|
  return mkldnn::memory::data_type::undef;
 | 
						|
}
 | 
						|
 | 
						|
template <>
 | 
						|
inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
 | 
						|
  return mkldnn::memory::data_type::f32;
 | 
						|
}
 | 
						|
template <>
 | 
						|
inline mkldnn::memory::data_type MKLDNNGetDataType<int32_t>() {
 | 
						|
  return mkldnn::memory::data_type::s32;
 | 
						|
}
 | 
						|
template <>
 | 
						|
inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
 | 
						|
  return mkldnn::memory::data_type::s8;
 | 
						|
}
 | 
						|
template <>
 | 
						|
inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
 | 
						|
  return mkldnn::memory::data_type::u8;
 | 
						|
}
 | 
						|
 | 
						|
template <>
 | 
						|
inline mkldnn::memory::data_type
 | 
						|
MKLDNNGetDataType<paddle::platform::bfloat16>() {
 | 
						|
  return mkldnn::memory::data_type::bf16;
 | 
						|
}
 | 
						|
 | 
						|
// Runs an MKL-DNN reorder from `src` into `dst` on the thread-local stream and
// waits for the stream to finish. The execution is recorded as a profiler
// event named "int_reorder".
// NOTE: `engine` is not used by the body; it is part of the existing
// signature.
inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
                    const mkldnn::engine& engine) {
  mkldnn::reorder reorder_op(src, dst);
  auto& stream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent reorder_event("int_reorder",
                                      platform::EventRole::kUniqueOp);
  reorder_op.execute(stream, src, dst);
  stream.wait();
}
 | 
						|
 | 
						|
// Deduces a format tag from the blocking information of a memory descriptor.
//
// The decision uses the number of dimensions, the relative order of the
// strides, and the sizes and dimension indices of the inner blocks.
// For plain (non-blocked) 3D/4D/5D descriptors, any stride order that is not
// explicitly recognised falls back to nwc / nhwc / ndhwc respectively.
// Returns format_tag::undef when no known pattern matches.
inline mkldnn::memory::format_tag GetMKLDNNFormat(
    mkldnn::memory::desc mem_desc) {
  using Tag = mkldnn::memory::format_tag;

  const auto& blocking = mem_desc.data.format_desc.blocking;
  const auto ndims = mem_desc.data.ndims;
  const auto strides = blocking.strides;
  const auto inner_nblks = blocking.inner_nblks;
  const auto inner_blks = blocking.inner_blks;
  const auto inner_idxs = blocking.inner_idxs;

  // True when the strides of the listed dimensions are non-increasing in the
  // given order, i.e. strides[order[0]] >= strides[order[1]] >= ...
  auto strides_descend = [strides](std::initializer_list<int> order) -> bool {
    for (auto it = order.begin(); it + 1 < order.end(); ++it) {
      if (strides[*it] < strides[*(it + 1)]) {
        return false;
      }
    }
    return true;
  };

  switch (ndims) {
    case 1:
      return Tag::x;

    case 2:
      if (inner_nblks == 0) {
        return strides_descend({0, 1}) ? Tag::nc : Tag::cn;
      }
      break;

    case 3:
      if (inner_nblks == 0) {
        if (strides_descend({0, 1, 2})) return Tag::ncw;
        if (strides_descend({1, 0, 2})) return Tag::ntc;
        return Tag::nwc;
      }
      break;

    case 4:
      if (inner_nblks == 0) {
        if (strides_descend({0, 1, 2, 3})) return Tag::nchw;
        if (strides_descend({2, 3, 1, 0})) return Tag::cdba;
        return Tag::nhwc;
      } else if (inner_nblks == 1) {
        const auto blk = inner_blks[0];
        const auto idx = inner_idxs[0];
        // The (blk, idx) pairs below are mutually exclusive, so at most one
        // of these tests can apply.
        if (blk == 16 && idx == 1) return Tag::nChw16c;
        if (blk == 8 && idx == 1) return Tag::nChw8c;
        if (blk == 4 && idx == 1) return Tag::nChw4c;
        if (blk == 8 && idx == 0 && strides_descend({0, 2, 3, 1})) {
          return Tag::Acdb8a;
        }
        if (blk == 16 && idx == 0 && strides_descend({0, 2, 3, 1})) {
          return Tag::Acdb16a;
        }
      } else if (inner_nblks == 2) {
        const bool idx_1_then_0 = inner_idxs[0] == 1 && inner_idxs[1] == 0;
        if (inner_blks[0] == 16 && inner_blks[1] == 16 && idx_1_then_0) {
          return Tag::OIhw16i16o;
        }
        if (inner_blks[0] == 8 && inner_blks[1] == 8 && idx_1_then_0) {
          return Tag::OIhw8i8o;
        }
      }
      break;

    case 5:
      if (inner_nblks == 0) {
        return strides_descend({0, 1, 2, 3, 4}) ? Tag::ncdhw : Tag::ndhwc;
      } else if (inner_nblks == 1) {
        const auto blk = inner_blks[0];
        const auto idx = inner_idxs[0];
        if (blk == 8 && idx == 0) {
          // The "dimension 1 innermost" order is tested before the plain one.
          if (strides_descend({0, 2, 3, 4, 1})) return Tag::Acdeb8a;
          if (strides_descend({0, 1, 2, 3, 4})) return Tag::Abcde8a;
        } else if (blk == 8 && idx == 1) {
          if (strides_descend({0, 1, 2, 3, 4})) return Tag::aBcde8b;
        } else if (blk == 16 && idx == 0) {
          if (strides_descend({0, 2, 3, 4, 1})) return Tag::Acdeb16a;
          if (strides_descend({0, 1, 2, 3, 4})) return Tag::Abcde16a;
        } else if (blk == 16 && idx == 1) {
          if (strides_descend({0, 1, 2, 3, 4})) return Tag::aBcde16b;
        }
      }
      break;

    case 6:
      if (inner_nblks == 0 && strides_descend({0, 1, 2, 3, 4, 5})) {
        return Tag::abcdef;
      }
      break;

    default:
      break;
  }

  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return Tag::undef;
}
 | 
						|
 | 
						|
// Convenience overload: deduces the format tag from the descriptor of an
// existing memory object.
inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) {
  return GetMKLDNNFormat(memory.get_desc());
}
 | 
						|
 | 
						|
// Adapts `data_format` to a tensor with `dims_size` dimensions.
//
//   1 dim  : always x
//   2 dims : always nc
//   3 dims : nchw -> ncw, nhwc -> nwc
//   4 dims : goihw -> oihw
//   5 dims : goidhw -> oidhw, nchw -> ncdhw, nhwc -> ndhwc
//
// In every other case `data_format` is returned unchanged.
inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  switch (dims_size) {
    case 1:
      return MKLDNNMemoryFormat::x;
    case 2:
      return MKLDNNMemoryFormat::nc;
    case 3:
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::ncw;
      }
      if (data_format == MKLDNNMemoryFormat::nhwc) {
        return MKLDNNMemoryFormat::nwc;
      }
      break;
    case 4:
      if (data_format == MKLDNNMemoryFormat::goihw) {
        return MKLDNNMemoryFormat::oihw;
      }
      break;
    case 5:
      if (data_format == MKLDNNMemoryFormat::goidhw) {
        return MKLDNNMemoryFormat::oidhw;
      }
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::ncdhw;
      }
      if (data_format == MKLDNNMemoryFormat::nhwc) {
        return MKLDNNMemoryFormat::ndhwc;
      }
      break;
    default:
      break;
  }
  return data_format;
}
 | 
						|
 | 
						|
// Converts a data layout name into an MKL-DNN memory format: the kNHWC layout
// gives nhwc, the kNCHW layout gives nchw, and every other layout gives any.
// The string is parsed by framework::StringToDataLayout.
inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  const auto layout = framework::StringToDataLayout(data_format);
  if (layout == framework::DataLayout::kNHWC) {
    return MKLDNNMemoryFormat::nhwc;
  }
  if (layout == framework::DataLayout::kNCHW) {
    return MKLDNNMemoryFormat::nchw;
  }
  return MKLDNNMemoryFormat::any;
}
 | 
						|
 | 
						|
// Parses a format name (case-insensitively) into an MKL-DNN memory format.
//
// Recognised names are "nchw", "nchw16c", "nchw8c" and "nhwc"; anything else
// gives MKLDNNMemoryFormat::any.
//
// Side effect: `*format` is converted to lower case in place.
// `format` must not be null.
inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  // std::tolower has undefined behaviour for negative `char` values, so each
  // character is passed through `unsigned char` first.
  std::transform(
      format->begin(), format->end(), format->begin(),
      [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

  if (*format == "nchw") {
    return MKLDNNMemoryFormat::nchw;
  }
  if (*format == "nchw16c") {
    return MKLDNNMemoryFormat::nChw16c;
  }
  if (*format == "nchw8c") {
    return MKLDNNMemoryFormat::nChw8c;
  }
  if (*format == "nhwc") {
    return MKLDNNMemoryFormat::nhwc;
  }
  return MKLDNNMemoryFormat::any;
}
 | 
						|
 | 
						|
// Returns the std::hash of the calling thread's id, rendered as a decimal
// string.
inline std::string ThreadIDasStr(void) {
  std::hash<std::thread::id> hasher{};
  const auto tid_hash = hasher(std::this_thread::get_id());
  return std::to_string(tid_hash);
}
 | 
						|
 | 
						|
// Appends the std::to_string representation of `num` to `*key`.
template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  *key += std::to_string(num);
}
 | 
						|
 | 
						|
template <>
 | 
						|
inline void AppendKey(std::string* key,
 | 
						|
                      const mkldnn::memory::format_tag& format) {
 | 
						|
  key->append(std::to_string(static_cast<int>(format)));
 | 
						|
}
 | 
						|
 | 
						|
template <>
 | 
						|
inline void AppendKey(std::string* key,
 | 
						|
                      const mkldnn::memory::data_type& data_type) {
 | 
						|
  key->append(std::to_string(static_cast<int>(data_type)));
 | 
						|
}
 | 
						|
 | 
						|
template <>
 | 
						|
inline void AppendKey(std::string* key, const mkldnn::algorithm& algorithm) {
 | 
						|
  key->append(std::to_string(static_cast<int>(algorithm)));
 | 
						|
}
 | 
						|
 | 
						|
template <>
 | 
						|
inline void AppendKey(std::string* key,
 | 
						|
                      const mkldnn::normalization_flags& flags) {
 | 
						|
  key->append(std::to_string(static_cast<int>(flags)));
 | 
						|
}
 | 
						|
 | 
						|
// Appends `str` to `*key` verbatim.
inline void AppendKey(std::string* key, const std::string& str) {
  *key += str;
}

// Appends the null-terminated string `str` to `*key` verbatim.
inline void AppendKey(std::string* key, const char* str) { *key += str; }
 | 
						|
 | 
						|
// Appends every element of `dims`, in order, as decimal text.
// The elements are written back to back with no separator between them.
template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (const auto& dim : dims) {
    AppendKey(key, std::to_string(dim));
  }
}
 | 
						|
 | 
						|
// If MKLDNN build and CPU place then register suffix in DeviceContext
 | 
						|
inline void AttachPointerHashToMKLDNNKey(void* ptr,
 | 
						|
                                         const platform::Place& place) {
 | 
						|
  if (platform::is_cpu_place(place)) {
 | 
						|
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
 | 
						|
    platform::MKLDNNDeviceContext* dev_ctx =
 | 
						|
        (platform::MKLDNNDeviceContext*)pool.Get(place);
 | 
						|
    dev_ctx->SetKeySuffix("E" +
 | 
						|
                          std::to_string(reinterpret_cast<uintptr_t>(ptr)));
 | 
						|
    // When NaiveExecutor/Executor is used no info on thread id is needed in a
 | 
						|
    // key
 | 
						|
    dev_ctx->DisableThreadInfoInKey();
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
template <typename... ArgTypes>
 | 
						|
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
 | 
						|
                             ArgTypes&&... args) {
 | 
						|
  std::string key;
 | 
						|
  key.reserve(64);
 | 
						|
  using expand_type = int[];
 | 
						|
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
 | 
						|
  key += dev_ctx.GetKeySuffix();
 | 
						|
  return key;
 | 
						|
}
 | 
						|
 | 
						|
inline std::string ExtendKeyWithThreadInfoIfNeeded(
 | 
						|
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
 | 
						|
  return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
 | 
						|
          (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
 | 
						|
           platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
 | 
						|
             ? key + "-t:" + ThreadIDasStr()
 | 
						|
             : key;
 | 
						|
}
 | 
						|
 | 
						|
// Splits a flat padding list into the two vectors {begin paddings, end
// paddings}.
//
//   6 elements {front, back, top, bottom, left, right}
//       -> {{front, top, left}, {back, bottom, right}}
//   otherwise  {top, bottom, left, right, ...}
//       -> {{top, left}, {bottom, right}}
//
// Values are kept as int64_t throughout; they were previously narrowed through
// `int` locals, which truncated large paddings. In the non-6-element case the
// first four elements are read with bounds checking, so an input with fewer
// than four elements throws std::out_of_range instead of reading out of
// bounds.
inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  using PadVec = std::vector<int64_t>;

  if (paddings.size() == 6) {
    const int64_t padding_front = paddings[0];
    const int64_t padding_back = paddings[1];
    const int64_t padding_top = paddings[2];
    const int64_t padding_bottom = paddings[3];
    const int64_t padding_left = paddings[4];
    const int64_t padding_right = paddings[5];

    return {PadVec{padding_front, padding_top, padding_left},
            PadVec{padding_back, padding_bottom, padding_right}};
  }

  const int64_t padding_top = paddings.at(0);
  const int64_t padding_bottom = paddings.at(1);
  const int64_t padding_left = paddings.at(2);
  const int64_t padding_right = paddings.at(3);

  return {PadVec{padding_top, padding_left},
          PadVec{padding_bottom, padding_right}};
}
 | 
						|
 | 
						|
// Adjusts the weight dimensions in place for group convolutions.
// When `groups` > 1, `groups` is inserted as a new leading dimension and the
// former first dimension is divided by `groups`:
//   conv3d: [o, i, d, h, w] -> [g, o/g, i, d, h, w]
//   conv2d: [o, i, h, w]    -> [g, o/g, i, h, w]
// When `groups` <= 1 the vector is left untouched.
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups <= 1) {
    return;
  }
  weights_tz.insert(weights_tz.begin(), static_cast<int64_t>(groups));
  weights_tz[1] /= groups;
}
 | 
						|
 | 
						|
// True when the op's "mkldnn_data_type" attribute equals "int8", or otherwise
// when its boolean "use_quantizer" attribute is true.
inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  if (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8") {
    return true;
  }
  return op->GetAttrIfExists<bool>("use_quantizer");
}
 | 
						|
 | 
						|
// True when the op's "mkldnn_data_type" attribute equals "bfloat16".
inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  const auto dtype = op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return dtype == "bfloat16";
}
 | 
						|
 | 
						|
// True when the op's "mkldnn_data_type" attribute equals "float32".
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  const auto dtype = op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return dtype == "float32";
}
 | 
						|
// Kinds of reorder used for RNN data.
// NOTE(review): the enumerator names suggest conversions between a "PP" layout
// and the NTC / TNC layouts in either direction -- confirm at the use sites.
enum class RNNReorderType {
  PP_NTC,
  PP_TNC,
  NTC_PP,
  TNC_PP
};
 | 
						|
 | 
						|
}  // namespace platform
 | 
						|
}  // namespace paddle
 |