|
|
|
@ -29,6 +29,7 @@
|
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
|
#include "backend/session/session_factory.h"
|
|
|
|
|
#include "backend/session/ascend_control_parser.h"
|
|
|
|
|
#include "runtime/device/ascend/ascend_memory_pool.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace session {
|
|
|
|
@ -37,7 +38,7 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2,
|
|
|
|
|
class AscendSession : public SessionBasic {
|
|
|
|
|
public:
|
|
|
|
|
AscendSession() { final_graph_id_ = kInvalidGraphId; }
|
|
|
|
|
~AscendSession() override = default;
|
|
|
|
|
~AscendSession() override { mindspore::device::ascend::AscendMemoryPool::GetInstance().ResetIdleMemBuf(); }
|
|
|
|
|
void Init(uint32_t device_id) override {
|
|
|
|
|
SessionBasic::Init(device_id);
|
|
|
|
|
context_ = std::make_shared<Context>(kAscendDevice, device_id);
|
|
|
|
|