|
|
|
@ -132,7 +132,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
|
|
|
|
|
|
|
|
|
|
/* \brief Get existed element: memory, primitive or primitive desc */
|
|
|
|
|
template <typename T>
|
|
|
|
|
const T GetElement(const std::string& op_key) const;
|
|
|
|
|
const T& GetElement(const std::string& op_key) const;
|
|
|
|
|
|
|
|
|
|
/* \brief Get element pool: memory, primitive or primitive desc pool */
|
|
|
|
|
template <typename T>
|
|
|
|
@ -140,7 +140,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
|
|
|
|
|
GetElementPool() const;
|
|
|
|
|
|
|
|
|
|
/* \brief Get the active engine */
|
|
|
|
|
const MKLDNNEnginePtr GetEngine() const { return engine_; }
|
|
|
|
|
const MKLDNNEngine& engine() const { return *engine_; }
|
|
|
|
|
|
|
|
|
|
/* \brief Submit primitive to pipeline */
|
|
|
|
|
void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
|
|
|
|
@ -163,7 +163,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
|
|
|
|
|
std::hash<std::string>>
|
|
|
|
|
primitive_desc_pool_;
|
|
|
|
|
std::vector<MKLDNNPrimitive> pipeline_;
|
|
|
|
|
std::unique_ptr<MKLDNNStream> stream_;
|
|
|
|
|
MKLDNNStreamPtr stream_;
|
|
|
|
|
MKLDNNEnginePtr engine_;
|
|
|
|
|
bool ready_;
|
|
|
|
|
};
|
|
|
|
|