|
|
|
@ -318,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
|
|
|
|
|
#endif
|
|
|
|
|
{
|
|
|
|
|
// run task on device
|
|
|
|
|
Execute(kernel_graph);
|
|
|
|
|
Execute(kernel_graph, true);
|
|
|
|
|
}
|
|
|
|
|
// summary
|
|
|
|
|
Summary(kernel_graph.get());
|
|
|
|
@ -348,17 +348,6 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
|
|
|
|
|
MS_LOG(INFO) << "Finish";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
|
|
|
|
MS_LOG(INFO) << "Start!";
|
|
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance);
|
|
|
|
|
bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
|
|
|
|
|
if (!ret_ok) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Run task error!";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Finish!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
|
|
|
|
|
return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
|
|
|
|
|
}
|
|
|
|
@ -398,7 +387,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
|
|
|
|
|
// load input data to device
|
|
|
|
|
LoadInputData(graph, input_tensors);
|
|
|
|
|
// run op
|
|
|
|
|
RunOpExecTask(graph);
|
|
|
|
|
Execute(graph, false);
|
|
|
|
|
// get output
|
|
|
|
|
if (op_run_info.value != nullptr) {
|
|
|
|
|
std::vector<tensor::TensorPtr> pre_output_tensors;
|
|
|
|
@ -552,21 +541,30 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
|
|
|
|
|
|
|
|
|
|
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
|
|
|
|
MS_LOG(INFO) << "Start!";
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
|
|
|
|
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
|
|
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance);
|
|
|
|
|
bool ret_ok = runtime_instance->Load(kernel_graph.get());
|
|
|
|
|
bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
|
|
|
|
|
if (!ret_ok) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Load task error!";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Finish!";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
|
|
|
|
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
|
|
|
|
|
MS_LOG(INFO) << "Start!";
|
|
|
|
|
bool is_task_sink = false;
|
|
|
|
|
if (is_task) {
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
|
|
|
|
|
}
|
|
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance);
|
|
|
|
|
bool ret_ok = runtime_instance->Run(kernel_graph.get());
|
|
|
|
|
bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
|
|
|
|
|
if (!ret_ok) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "run task error!";
|
|
|
|
|
}
|
|
|
|
|