|
|
|
@ -31,7 +31,6 @@
|
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
|
#include "backend/session/session_factory.h"
|
|
|
|
|
#include "backend/session/ascend_control_parser.h"
|
|
|
|
|
#include "runtime/context.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace session {
|
|
|
|
@ -40,28 +39,8 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2,
|
|
|
|
|
class AscendSession : public SessionBasic {
|
|
|
|
|
public:
|
|
|
|
|
AscendSession() { final_graph_id_ = kInvalidGraphId; }
|
|
|
|
|
~AscendSession() {
|
|
|
|
|
if (rt_context_ != nullptr) {
|
|
|
|
|
auto ret = rtCtxDestroy(rt_context_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]";
|
|
|
|
|
}
|
|
|
|
|
rt_context_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Init(uint32_t device_id) override {
|
|
|
|
|
InitDevice(kAscendDevice, device_id);
|
|
|
|
|
auto ret = rtCtxCreate(&rt_context_, 0, device_id);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
|
|
|
|
|
}
|
|
|
|
|
ret = rtCtxSetCurrent(rt_context_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~AscendSession() = default;
|
|
|
|
|
void Init(uint32_t device_id) override;
|
|
|
|
|
// get graph id of final graph
|
|
|
|
|
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
|
|
|
|
|
|
|
|
|
@ -136,8 +115,6 @@ class AscendSession : public SessionBasic {
|
|
|
|
|
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
|
|
|
|
|
// final_graph_id is used in every root graph has it's own session situation
|
|
|
|
|
GraphId final_graph_id_;
|
|
|
|
|
// ascend runtime context
|
|
|
|
|
rtContext_t rt_context_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
MS_REG_SESSION(kAscendDevice, AscendSession);
|
|
|
|
|
} // namespace session
|
|
|
|
|