|
|
@ -232,6 +232,15 @@ Status SubgraphExecutor::PrepareNodes() {
|
|
|
|
node_state->SetKernelTask(node_item.kernel_task);
|
|
|
|
node_state->SetKernelTask(node_item.kernel_task);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto unique_task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get());
|
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(unique_task_context);
|
|
|
|
|
|
|
|
const auto &task = node_state->GetKernelTask();
|
|
|
|
|
|
|
|
if (task == nullptr) {
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state->GetName().c_str());
|
|
|
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
|
|
|
|
|
|
|
|
node_state->SetTaskContex(shared_task_context);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (!ready_queue_.Push(p_node_state)) {
|
|
|
|
if (!ready_queue_.Push(p_node_state)) {
|
|
|
@ -267,6 +276,19 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
node_state.SetKernelTask(node_item.kernel_task);
|
|
|
|
node_state.SetKernelTask(node_item.kernel_task);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto unique_task_context = TaskContext::Create(*node_state.GetNodeItem(), context_, subgraph_context_.get());
|
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(unique_task_context);
|
|
|
|
|
|
|
|
const auto &task = node_state.GetKernelTask();
|
|
|
|
|
|
|
|
if (task == nullptr) {
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str());
|
|
|
|
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
|
|
|
|
|
|
|
|
node_state.SetTaskContex(shared_task_context);
|
|
|
|
|
|
|
|
GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
|
|
|
|
|
|
|
|
RECORD_COMPILE_EVENT(ctx, node_item.NodeItem().c_str(), "[UpdateTilingData] start");
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context)); // update op_desc before alloc ws
|
|
|
|
|
|
|
|
RECORD_COMPILE_EVENT(ctx, node_item.NodeItem().c_str(), "[UpdateTilingData] end");
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -295,10 +317,9 @@ Status SubgraphExecutor::LaunchTasks() {
|
|
|
|
GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone());
|
|
|
|
GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone());
|
|
|
|
|
|
|
|
|
|
|
|
GELOGD("[%s] Start to execute.", node_state->GetName().c_str());
|
|
|
|
GELOGD("[%s] Start to execute.", node_state->GetName().c_str());
|
|
|
|
auto task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get());
|
|
|
|
auto shared_task_context = node_state->GetTaskContext();
|
|
|
|
GE_CHECK_NOTNULL(task_context);
|
|
|
|
GE_CHECK_NOTNULL(shared_task_context);
|
|
|
|
task_context->SetForceInferShape(force_infer_shape_);
|
|
|
|
shared_task_context->SetForceInferShape(force_infer_shape_);
|
|
|
|
auto shared_task_context = std::shared_ptr<TaskContext>(task_context.release());
|
|
|
|
|
|
|
|
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_),
|
|
|
|
HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_),
|
|
|
|
"[%s] Execute node failed.",
|
|
|
|
"[%s] Execute node failed.",
|
|
|
|
node_state->GetName().c_str());
|
|
|
|
node_state->GetName().c_str());
|
|
|
|