You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
211 lines
8.1 KiB
211 lines
8.1 KiB
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "hybrid_model_executor.h"
|
|
#include "graph/ge_context.h"
|
|
#include "graph/runtime_inference_context.h"
|
|
#include "common/dump/dump_manager.h"
|
|
#include "common/profiling/profiling_manager.h"
|
|
#include "mmpa/mmpa_api.h"
|
|
|
|
namespace ge {
|
|
namespace hybrid {
|
|
namespace {
// Base used when parsing HYBRID_PROFILING_LEVEL from the environment (decimal).
const int kIntBase = 10;
// Environment variable that controls the in-memory hybrid profiler verbosity.
const char *const kEnvProfilingLevel = "HYBRID_PROFILING_LEVEL";
// JSON field names reported to the profiling manager in ProfileStepInfo().
const char *const kIndexId = "index_id";
// NOTE(review): identifier has a typo ("Modele"); kept as-is because it is
// referenced elsewhere in this file — the JSON key string itself is correct.
const char *const kModeleId = "model_id";
const char *const kTimeStamp = "time_stamp";
const char *const kStreamId = "stream_id";
const char *const kTaskId = "task_id";
const char *const kTagId = "tag_id";
const char *const kThreadId = "thread_id";
// Indent width passed to Json::dump() when serializing step info.
// NOTE(review): identifier typo ("Inteval" -> "Interval"); kept for
// compatibility with its use sites in this file.
const uint32_t kInteval = 2;

// File-local stub of the runtime trace API: unconditionally reports success
// and performs no work. NOTE(review): presumably a placeholder until the real
// RTS implementation is available — confirm it does not shadow a genuine
// rtProfilerTraceEx symbol at link time.
RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stream) {
  return RT_ERROR_NONE;
}
}  // namespace
|
|
// Binds the executor to a model, a device id and the stream on which graph
// work is submitted. Heavy-weight setup is deferred to Init().
HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream)
    : model_(model), device_id_(device_id), stream_(stream) {
}
|
|
|
|
// Best-effort release of the runtime resources acquired in
// InitExecutionContext(); failures from the runtime calls are ignored since
// there is nothing useful to do with them during destruction.
HybridModelExecutor::~HybridModelExecutor() {
  auto *gen_ctx = context_.rt_gen_context;
  if (gen_ctx != nullptr) {
    (void) rtCtxDestroy(gen_ctx);
  }

  auto *step_buffer = context_.global_step;
  if (step_buffer != nullptr) {
    (void) rtFree(step_buffer);
  }
}
|
|
|
|
// One-time initialization: builds the execution context (runtime contexts,
// global-step device buffer, allocator, callback manager). Must be called
// before Execute().
// @return SUCCESS on success, otherwise the status from InitExecutionContext().
Status HybridModelExecutor::Init() {
  GELOGD("Start to init HybridGraphEngine.");
  GE_CHK_STATUS_RET_NOLOG(InitExecutionContext());
  GELOGD("HybridGraphEngine initialized successfully.");
  return SUCCESS;
}
|
|
|
|
// Runs one iteration of the model on the bound stream.
// @param args  in: inputs/input_desc; out: outputs/output_desc, is_eos.
// @return SUCCESS (including the END_OF_SEQUENCE case, signalled via
//         args.is_eos) or the failing status from graph execution.
Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
  GELOGD("Start to execute model.");
  auto root_graph_item = model_->GetRootGraphItem();
  GE_CHECK_NOTNULL(root_graph_item);

  // Publish the host-side iteration counter to device memory so kernels that
  // read the global step see an up-to-date value.
  GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
                              sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
  SubgraphExecutor executor(root_graph_item, &context_);
  auto ret = ExecuteGraphInternal(executor, args);
  // Cleanup must run regardless of the execution result. Fix: the returned
  // Status used to be silently discarded; a cleanup failure is now logged
  // (but does not mask the execution status).
  if (Cleanup() != SUCCESS) {
    GELOGW("Failed to cleanup execution context.");
  }
  RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End");
  GELOGD("Model executed successfully.");
  if (context_.profiler != nullptr) {
    context_.profiler->Dump(std::cout);
    context_.profiler->Reset();
  }

  // The iteration counter advances even on failure so retries get a fresh id.
  context_.iteration += 1;
  if (ret == END_OF_SEQUENCE) {
    // End-of-sequence is not an error; report it to the caller via args.
    args.is_eos = true;
  } else {
    GE_CHK_STATUS_RET(ret, "Failed to execute model");
  }
  return SUCCESS;
}
|
|
|
|
// Emits a step marker to the runtime profiler and reports step metadata
// (index/model/task/stream/thread ids plus a timestamp) as one JSON record.
// No-op when model-execution profiling is off.
// @param tag_id  0 marks step begin, 1 marks step end (see caller).
// @return SUCCESS, or a runtime-derived status when an rt* call fails.
Status HybridModelExecutor::ProfileStepInfo(uint16_t tag_id) {
  auto &prof_mgr = ProfilingManager::Instance();
  if (prof_mgr.ProfilingModelExecuteOn()) {
    uint64_t index_id = context_.iteration + 1;
    uint64_t model_id = static_cast<uint64_t>(model_->GetModelId());
    GELOGD("Profiling Step Info TraceTask execute async start, index_id = %lu, model_id = %lu, tag_id = %u",
           index_id, model_id, tag_id);
    rtError_t rt_ret = rtProfilerTraceEx(index_id, model_id, tag_id, stream_);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "[Call][rtProfilerTraceEx] failed, ret: 0x%X", rt_ret);
      return RT_ERROR_TO_GE_STATUS(rt_ret);
    }
    GELOGD("Profiling Step Info TraceTask execute async success, index_id = %lu, model_id = %lu, tag_id = %u",
           index_id, model_id, tag_id);

    mmTimespec timespec = mmGetTickCount();
    // 1000 ^ 3 converts second to nanosecond
    int64_t time = timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec;
    uint32_t task_id = 0;
    uint32_t stream_id = 0;
    rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "[Get][RtsInfo] task_id and stream_id failed, ret: 0x%X.", rt_ret);
      return RT_ERROR_TO_GE_STATUS(rt_ret);
    }
    GELOGD("Get profiling args, task_id[%u], stream_id[%u]", task_id, stream_id);

    Json step_info;
    step_info[kIndexId] = index_id;
    step_info[kModeleId] = model_id;
    step_info[kTimeStamp] = time;
    step_info[kTagId] = tag_id;
    step_info[kTaskId] = task_id;
    step_info[kStreamId] = stream_id;
    step_info[kThreadId] = mmGetTid();

    std::string reported_data;
    bool serialize_ok = true;
    try {
      reported_data = step_info.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
    } catch (const std::exception &e) {
      serialize_ok = false;
      GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what());
    } catch (...) {
      serialize_ok = false;
      GELOGE(FAILED, "Failed to convert JSON to string.");
    }
    // Fix: previously a garbage payload (",\n") was still reported when
    // serialization threw; the record is now skipped in that case. Profiling
    // stays best-effort, so execution is not aborted.
    if (serialize_ok) {
      reported_data.append(",")
                   .append("\n");
      prof_mgr.ReportData(device_id_, reported_data, "step_info");
    }
  }
  return SUCCESS;
}
|
|
|
|
// Runs one full iteration on the root graph: resets the per-iteration
// context, launches the subgraph asynchronously, waits for completion and
// collects the outputs. Profiling step markers are emitted around the run.
// @param executor  subgraph executor built over the root graph item.
// @param args      in: inputs/input_desc; out: outputs/output_desc.
Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
                                                 HybridModelExecutor::ExecuteArgs &args) {
  RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start");
  GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_));
  RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End");

  // tag_id 0 means step begin, 1 means step end.
  GE_CHK_STATUS_RET_NOLOG(ProfileStepInfo(0));
  HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs),
                        "Failed to execute partitioned call.");
  RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End");
  GE_CHK_STATUS_RET_NOLOG(ProfileStepInfo(1));

  HYBRID_CHK_STATUS_RET(executor.Synchronize(), "Failed to sync root graph.");
  RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End");

  // Outputs are re-fetched from the executor; drop anything ExecuteAsync
  // may have left in args.outputs first.
  args.outputs.clear();
  HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
  RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End");
  return SUCCESS;
}
|
|
|
|
// Per-iteration teardown: destroys the callback manager's worker state and
// the runtime inference context keyed by this execution's context id.
// @return SUCCESS (teardown is best-effort).
Status HybridModelExecutor::Cleanup() {
  GELOGD("Start to cleanup.");
  // Fix: callback_manager is only created in InitExecutionContext(); guard
  // against Cleanup() being reached before Init() succeeded.
  if (context_.callback_manager != nullptr) {
    context_.callback_manager->Destroy();
  }
  RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id));
  GELOGD("Cleanup successfully.");
  return SUCCESS;
}
|
|
|
|
// Builds the per-executor GraphExecutionContext:
//  - captures the caller's runtime context, creates an extra context in
//    generation mode (rt_gen_context), then restores the caller's context;
//  - allocates the device buffer that holds the global step counter;
//  - wires up session info, allocator, callback manager, dump properties;
//  - optionally enables the in-memory profiler via HYBRID_PROFILING_LEVEL.
Status HybridModelExecutor::InitExecutionContext() {
  GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
  GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
  // rtCtxCreate leaves the new context current — switch back to the caller's.
  GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));
  GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM));

  context_.stream = stream_;
  context_.model = model_;
  context_.is_eos_ = false;
  context_.session_id = ::ge::GetContext().SessionId();
  context_.ge_context = &GetThreadLocalContext();
  GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id);
  context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_);
  GE_CHECK_NOTNULL(context_.allocator);
  context_.callback_manager = std::unique_ptr<CallbackManager>(new(std::nothrow)CallbackManager());
  GE_CHECK_NOTNULL(context_.callback_manager);
  context_.dump_properties = DumpManager::GetInstance().GetDumpProperties(context_.session_id);
  const char *profiling_level = std::getenv(kEnvProfilingLevel);
  if (profiling_level != nullptr) {
    // Level is parsed as a decimal integer; any positive value enables the
    // HybridProfiler.
    context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase);
    GELOGD("Got profiling level = %ld", context_.profiling_level);
    if (context_.profiling_level > 0) {
      context_.profiler.reset(new(std::nothrow)HybridProfiler());
      GE_CHECK_NOTNULL(context_.profiler);
    }
  }

  // Debug logging implies node tracing.
  if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) {
    context_.trace_enabled = true;
  }
  return SUCCESS;
}
|
|
|
|
// Prepares the context for a new iteration: re-initializes the callback
// manager and recreates the RuntimeInferenceContext for this context id.
// @param context  execution context to reset.
// @return SUCCESS or the failing status of Init()/CreateContext().
Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) {
  GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init());
  string ctx_id = std::to_string(context.context_id);
  // Destroy-then-create ensures a clean inference context per iteration.
  RuntimeInferenceContext::DestroyContext(ctx_id);
  // Fix: the error message previously said "Failed to Destroy ..." although
  // the checked call is CreateContext.
  GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to create RuntimeInferenceContext");
  return SUCCESS;
}
|
|
} // namespace hybrid
|
|
} // namespace ge
|