|
|
|
@ -21,6 +21,10 @@ limitations under the License. */
|
|
|
|
|
#define EIGEN_USE_GPU
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include "mkldnn.hpp"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#include "paddle/platform/enforce.h"
|
|
|
|
|
#include "paddle/platform/place.h"
|
|
|
|
|
#include "unsupported/Eigen/CXX11/Tensor"
|
|
|
|
@ -117,6 +121,65 @@ class CUDNNDeviceContext : public CUDADeviceContext {
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
using MKLDNNStream = mkldnn::stream;
|
|
|
|
|
using MKLDNNEngine = mkldnn::engine;
|
|
|
|
|
using MKLDNNMemory = mkldnn::memory;
|
|
|
|
|
using MKLDNNPrimitive = mkldnn::primitive;
|
|
|
|
|
using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
|
|
|
|
|
|
|
|
|
|
typedef std::shared_ptr<MKLDNNEngine> MKLDNNEnginePtr;
|
|
|
|
|
typedef std::shared_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
|
|
|
|
|
typedef std::shared_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
|
|
|
|
|
typedef std::shared_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
|
|
|
|
|
class MKLDNNDeviceContext : public CPUDeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
explicit MKLDNNDeviceContext(CPUPlace place);
|
|
|
|
|
virtual ~MKLDNNDeviceContext();
|
|
|
|
|
|
|
|
|
|
/* \brief Add new element: memory, primitive or primitive desc */
|
|
|
|
|
template <typename T>
|
|
|
|
|
void AddElement(const std::string& op_key, const T& value);
|
|
|
|
|
|
|
|
|
|
/* \brief Get existed element: memory, primitive or primitive desc */
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T GetElement(const std::string& op_key) const;
|
|
|
|
|
|
|
|
|
|
/* \brief Get element pool: memory, primitive or primitive desc pool */
|
|
|
|
|
template <typename T>
|
|
|
|
|
const std::unordered_map<const std::string, const T, std::hash<std::string>>&
|
|
|
|
|
GetElementPool() const;
|
|
|
|
|
|
|
|
|
|
/* \brief Get the active engine */
|
|
|
|
|
const MKLDNNEnginePtr GetEngine() const { return engine_; }
|
|
|
|
|
|
|
|
|
|
/* \brief Submit primitive to pipeline */
|
|
|
|
|
void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
|
|
|
|
|
|
|
|
|
|
/*! \brief Execute all submitted primitives in pipeline */
|
|
|
|
|
void Execute(bool block = true);
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
/*! \brief Reset the stream to prepare next exectue */
|
|
|
|
|
void ResetStream();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unordered_map<const std::string, const MKLDNNMemoryPtr,
|
|
|
|
|
std::hash<std::string>>
|
|
|
|
|
memory_pool_;
|
|
|
|
|
std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
|
|
|
|
|
std::hash<std::string>>
|
|
|
|
|
primitive_pool_;
|
|
|
|
|
std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
|
|
|
|
|
std::hash<std::string>>
|
|
|
|
|
primitive_desc_pool_;
|
|
|
|
|
std::vector<MKLDNNPrimitive> pipeline_;
|
|
|
|
|
std::unique_ptr<MKLDNNStream> stream_;
|
|
|
|
|
MKLDNNEnginePtr engine_;
|
|
|
|
|
bool ready_;
|
|
|
|
|
};
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
/*! \brief device context pool singleton */
|
|
|
|
|
class DeviceContextPool {
|
|
|
|
|
public:
|
|
|
|
|