|
|
|
@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() {
|
|
|
|
|
if (context_.rt_gen_context != nullptr) {
|
|
|
|
|
(void) rtCtxDestroy(context_.rt_gen_context);
|
|
|
|
|
}
|
|
|
|
|
if (context_.global_step != nullptr) {
|
|
|
|
|
(void) rtFree(context_.global_step);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status HybridModelExecutor::Init() {
|
|
|
|
@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
|
|
|
|
|
auto root_graph_item = model_->GetRootGraphItem();
|
|
|
|
|
GE_CHECK_NOTNULL(root_graph_item);
|
|
|
|
|
|
|
|
|
|
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
|
|
|
|
|
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
|
|
|
|
|
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
|
|
|
|
|
auto ret = ExecuteGraphInternal(executor, args);
|
|
|
|
|
Cleanup();
|
|
|
|
@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() {
|
|
|
|
|
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
|
|
|
|
|
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
|
|
|
|
|
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));
|
|
|
|
|
GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM));
|
|
|
|
|
|
|
|
|
|
context_.stream = stream_;
|
|
|
|
|
context_.model = model_;
|
|
|
|
|