// graphengine/ge/client/ge_api.cc
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge/ge_api.h"

#include <malloc.h>

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "common/ge/datatype_util.h"
#include "proto/ge_api.pb.h"
#include "graph/model_serialize.h"
#include "graph/detail/model_serialize_imp.h"
#include "graph/utils/tensor_adapter.h"
#include "init/gelib.h"
#include "session/session_manager.h"
#include "graph/opsproto_manager.h"
#include "graph/utils/type_utils.h"
#include "graph/manager/util/rt_context_util.h"
#include "graph/common/ge_call_wrapper.h"
#include "register/op_registry.h"
#include "common/ge/tbe_plugin_manager.h"
#include "common/util/error_manager/error_manager.h"
#include "toolchain/plog.h"

using domi::OpRegistry;
using std::map;
using std::string;
using std::vector;
namespace {
// Longest value (in characters) accepted for the OPTION_EXEC_JOB_ID option.
constexpr int32_t kMaxStrLen = 128;
}  // namespace

// Set after GEInitialize completes successfully; cleared again by a successful GEFinalize.
static bool g_ge_initialized = false;
// Serializes GEFinalize against ~Session.
static std::mutex g_ge_release_mutex;  // GEFinalize and ~Session use
namespace ge {
5 years ago
void GetOpsProtoPath(std::string &opsproto_path) {
GELOGI("Enter get ops proto path schedule");
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
std::string path = path_env;
opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
5 years ago
GELOGI("Get opsproto so path from env: %s", path.c_str());
return;
}
std::string path_base = PluginManager::GetPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
5 years ago
}
// Validate user-supplied GE options.
// Only OPTION_EXEC_JOB_ID is checked: when present, its value may be at most kMaxStrLen characters.
// Returns SUCCESS if the option is absent or short enough; otherwise logs the problem, reports
// input error E10051 and returns FAILED.
Status CheckOptionsValid(const std::map<string, string> &options) {
  const auto found = options.find(OPTION_EXEC_JOB_ID);
  if (found == options.end()) {
    return SUCCESS;
  }
  const std::string &job_id = found->second;
  if (job_id.length() <= static_cast<size_t>(kMaxStrLen)) {
    return SUCCESS;
  }
  GELOGE(PARAM_INVALID, "[Check][JobId]Failed,"
         "the job_id [%s] string length > max string length: %d",
         job_id.c_str(), kMaxStrLen);
  REPORT_INPUT_ERROR("E10051", std::vector<std::string>({"id", "length"}),
                     std::vector<std::string>({job_id, std::to_string(kMaxStrLen)}));
  return FAILED;
}
// Shared implementation behind both GEInitialize overloads.
// Steps, in order:
//   1. initialize the ErrorManager with GELib's path (its error code is returned on failure);
//   2. return SUCCESS immediately if GE is already initialized;
//   3. load the OpsProto plugin libraries;
//   4. validate the user options;
//   5. run the TBE plugin manager's preparation step;
//   6. initialize GELib and record that GE is now initialized.
// Steps 3, 4 and 6 return FAILED when they fail.
Status GEInitializeImpl(const std::map<string, string> &options) {
  ErrorManager::GetInstance().GenWorkStreamIdDefault();
  GELOGT(TRACE_INIT, "GEInitialize start");
  std::string path_base = ge::GELib::GetPath();
  auto ret = ErrorManager::GetInstance().Init(path_base);
  if (ret != SUCCESS) {
    GELOGE(GE_CLI_INIT_FAILED,
           "[Init][PathBase]Init failed when pass param path_base:%s", path_base.c_str());
    return ret;
  }

  // A second call is a no-op.
  if (g_ge_initialized) {
    GELOGW("GEInitialize is called more than once");
    return SUCCESS;
  }

  // Load the OpsProto plugin libraries.
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOpsProtoInit);
  std::string opsproto_path;
  GetOpsProtoPath(opsproto_path);
  OpsProtoManager *const proto_manager = OpsProtoManager::Instance();
  std::map<string, string> proto_options;
  proto_options.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  GE_TIMESTAMP_START(GEInitialize);
  const bool proto_loaded = proto_manager->Initialize(proto_options);
  GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize");
  if (!proto_loaded) {
    GELOGE(GE_CLI_INIT_FAILED,
           "[Init][OpsProtoPath]Loading OpsProto lib plugin failed, OpsProtoPath:%s invalid.",
           opsproto_path.c_str());
    return FAILED;
  }

  // Validate the caller's options.
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOther);
  GE_TIMESTAMP_START(CheckOptionsValid);
  if (CheckOptionsValid(options) != SUCCESS) {
    return FAILED;
  }
  GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid");

  // Prepare TBE plugins.
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOpsProtoInit);
  GE_TIMESTAMP_START(InitPreparation);
  TBEPluginManager::Instance().InitPreparation(options);
  GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation");

  // Initialize GELib itself.
  GELOGT(TRACE_RUNNING, "Initializing environment");
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOther);
  GE_TIMESTAMP_START(GELibInitialize);
  ret = ge::GELib::Initialize(options);
  GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize");
  if (ret != SUCCESS) {
    GELOGE(GE_CLI_INIT_FAILED, "[Init][GELib]Failed, error code = %u", ret);
    return FAILED;
  }

  // First successful initialization: later calls return early at the check above.
  g_ge_initialized = true;
  GELOGT(TRACE_STOP, "GEInitialize finished");
  return ret;
}
// Public entry point: initialize GE from std::string options.
// Device-log reporting is initialized first; a failure there only produces a warning. The real
// work is done by GEInitializeImpl, whose status is returned.
Status GEInitialize(const std::map<string, string> &options) {
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOther);
  const bool dlog_ready = (DlogReportInitialize() == SUCCESS);
  if (!dlog_ready) {
    GELOGW("Dlog report device log initialize failed.");
  }
  return GEInitializeImpl(options);
}
// Public entry point: initialize GE from AscendString options.
// Each key/value is converted to std::string. If any key or value has a null string, the call
// fails with FAILED before anything is initialized. Otherwise device-log reporting is set up
// (failure is only a warning) and GEInitializeImpl does the rest.
Status GEInitialize(const std::map<AscendString, AscendString> &options) {
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOther);
  std::map<std::string, std::string> converted;
  for (const auto &entry : options) {
    const char *const key = entry.first.GetString();
    const char *const value = entry.second.GetString();
    if ((key == nullptr) || (value == nullptr)) {
      GELOGE(FAILED, "[Check][Param]Options invalid, first or second option is nullptr.");
      REPORT_INNER_ERROR("E19999", "Check parameter's options invalid,"
                         "the first or second option is nullptr.");
      return FAILED;
    }
    converted[key] = value;
  }
  if (DlogReportInitialize() != SUCCESS) {
    GELOGW("Dlog report device log initialize failed.");
  }
  return GEInitializeImpl(converted);
}
5 years ago
// GE finalize, releasing all resources
Status GEFinalize() {
std::lock_guard<std::mutex> lock(g_ge_release_mutex);
5 years ago
// check init status
if (!g_ge_initialized) {
GELOGW("[FINAL][FINAL]GEFinalize is called before GEInitialize");
5 years ago
return SUCCESS;
}
ErrorManager::GetInstance().SetStage(ErrorMessage::kFinalize, ErrorMessage::kFinalize);
ErrorManager::GetInstance().GenWorkStreamIdDefault();
GELOGT(TRACE_INIT, "GEFinalize start");
5 years ago
// call Finalize
Status ret = SUCCESS;
Status middle_ret;
5 years ago
GELOGT(TRACE_RUNNING, "Finalizing environment");
std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance();
if (instancePtr == nullptr || !instancePtr->InitFlag()) {
GELOGW("GEFinalize Failed: GE not initialized.");
ret = GE_CLI_GE_NOT_INITIALIZED;
5 years ago
}
if (ret != GE_CLI_GE_NOT_INITIALIZED) {
middle_ret = instancePtr->Finalize();
GELOGI("GEFinalize finalize gelib ret=%u", middle_ret);
if (middle_ret != SUCCESS) {
ret = middle_ret;
}
5 years ago
}
middle_ret = TBEPluginManager::Instance().Finalize();
if (middle_ret != SUCCESS) {
ret = middle_ret;
}
if (g_ge_initialized && ret == SUCCESS) {
// Unified destruct rt_context
RtContextUtil::GetInstance().DestroyAllRtContexts();
g_ge_initialized = false;
5 years ago
}
// to avoid memory fragment, use malloc_trim to back free stack to system
malloc_trim(0);
if (DlogReportFinalize() != SUCCESS) {
GELOGW("Dlog report device log finalize failed.");
}
5 years ago
GELOGT(TRACE_STOP, "GEFinalize finished");
return ret;
}
// Return the error text collected so far by the ErrorManager.
std::string GEGetErrorMsg() {
  std::string error_text = ErrorManager::GetInstance().GetErrorMessage();
  return error_text;
}
// Return the warning text collected so far by the ErrorManager.
std::string GEGetWarningMsg() {
  std::string warning_text = ErrorManager::GetInstance().GetWarningMessage();
  return warning_text;
}
5 years ago
// Initialize sessionwhich calls innerSession
Session::Session(const std::map<string, string> &options) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOther);
5 years ago
GELOGT(TRACE_INIT, "Session Constructor start");
ErrorManager::GetInstance().GenWorkStreamIdDefault();
5 years ago
// check init status
sessionId_ = 0;
if (!g_ge_initialized) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Construct][Session]Failed because lack GEInitialize call before.");
4 years ago
REPORT_INNER_ERROR("E19999",
4 years ago
"Creating session failed because lack GEInitialize call before.");
5 years ago
return;
}
// call Initialize
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Construct][Session]Failed, GELib instance is nullptr or it is not InitFlag");
5 years ago
return;
}
GELOGT(TRACE_RUNNING, "Creating session");
uint64_t session_id = 0;
Status ret = instance_ptr->SessionManagerObj().CreateSession(options, session_id);
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id);
// check return status, return, update session id if success
if (ret == SUCCESS) {
sessionId_ = session_id;
} else {
4 years ago
GELOGE(ret, "[Construct][Session]Failed, error code:%u.", ret);
5 years ago
return;
}
GELOGT(TRACE_STOP, "Session Constructor finished");
}
// Create a GE session from AscendString options.
// After the GE/GELib checks, the options are converted to std::string. A null key or value
// aborts construction. On any failure the error is logged and sessionId_ stays 0.
Session::Session(const std::map<AscendString, AscendString> &options) {
  ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOther);
  GELOGT(TRACE_INIT, "Session Constructor start");
  ErrorManager::GetInstance().GenWorkStreamIdDefault();
  // Start from id 0 so that every early return leaves a known value.
  sessionId_ = 0;
  if (!g_ge_initialized) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED,
           "[Construct][Session]Failed because lack GEInitialize call before.");
    REPORT_INNER_ERROR("E19999",
                       "Creating session failed because lack GEInitialize call before.");
    return;
  }
  const std::shared_ptr<GELib> gelib = ge::GELib::GetInstance();
  if ((gelib == nullptr) || !gelib->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED,
           "[Construct][Session]Failed, the GELib instance is nullptr or is not InitFlag");
    return;
  }
  GELOGT(TRACE_RUNNING, "Creating session");
  std::map<std::string, std::string> plain_options;
  for (const auto &entry : options) {
    const char *const key = entry.first.GetString();
    const char *const value = entry.second.GetString();
    if ((key == nullptr) || (value == nullptr)) {
      GELOGE(FAILED, "[Construct][Session]Failed, the first or second option is nullptr.");
      REPORT_INNER_ERROR("E19999", "Creating session's options invalid,"
                         "the first or second option is nullptr.");
      return;
    }
    plain_options[key] = value;
  }
  uint64_t new_session_id = 0;
  const Status create_ret = gelib->SessionManagerObj().CreateSession(plain_options, new_session_id);
  GELOGT(TRACE_RUNNING, "Session id is %lu", new_session_id);
  if (create_ret != SUCCESS) {
    GELOGE(create_ret, "[Construct][Session]Failed, error code:%u.", create_ret);
    return;
  }
  sessionId_ = new_session_id;
  GELOGT(TRACE_STOP, "Session Constructor finished");
}
5 years ago
// session destructor
Session::~Session() {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kFinalize, ErrorMessage::kFinalize);
5 years ago
GELOGT(TRACE_INIT, "Session Destructor start");
// 0.check init status
if (!g_ge_initialized) {
5 years ago
GELOGW("GE is not yet initialized or is finalized.");
return;
}
Status ret = FAILED;
std::lock_guard<std::mutex> lock(g_ge_release_mutex);
5 years ago
try {
uint64_t session_id = sessionId_;
// call DestroySession
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGW("GE is not yet initialized or is finalized.");
return;
}
GELOGT(TRACE_RUNNING, "Session id is %lu", session_id);
GELOGT(TRACE_RUNNING, "Destroying session");
ret = instance_ptr->SessionManagerObj().DestroySession(session_id);
} catch (google::protobuf::FatalException &e) {
4 years ago
GELOGE(GE_CLI_SESS_DESTROY_FAILED, "[Destruct][Session]Failed because get fatalException.");
5 years ago
}
// check return status, return, update session id if success
if (ret != SUCCESS) {
4 years ago
GELOGE(ret, "[Destruct][Session]Failed, error code:%u.", ret);
5 years ago
}
GELOGT(TRACE_STOP, "Session Destructor finished");
}
4 years ago
// Add Graph
5 years ago
Status Session::AddGraph(uint32_t graph_id, const Graph &graph) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
std::map<std::string, std::string> options;
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
return AddGraph(graph_id, graph, options);
}
4 years ago
// Add Graph
Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<std::string, std::string> &options) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
5 years ago
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Add][Graph]Failed because GELib instance is nullptr or it is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"AddGraph Failed, GELib instance is nullptr or it is not InitFlag.");
5 years ago
return FAILED;
}
GELOGD("Adding graph to session");
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, options);
5 years ago
if (ret != SUCCESS) {
4 years ago
GELOGE(ret,
"[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
5 years ago
return FAILED;
}
GELOGD("AddGraph finished in Session.");
5 years ago
return ret;
}
4 years ago
//Add Graph
Status Session::AddGraph(uint32_t graph_id, const Graph &graph,
const std::map<AscendString, AscendString> &options) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"AddGraph Failed, GELib instance is nullptr or it is not InitFlag.");
return FAILED;
}
GELOGD("Adding graph to session");
std::map<std::string, std::string> str_options;
for (auto &option : options) {
if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) {
4 years ago
GELOGE(FAILED, "[Add][Graph]Failed, the first or second option is nullptr.");
REPORT_INNER_ERROR("E19999",
"Add Graph Failed, the first or second option is nullptr.");
return FAILED;
}
std::string key = option.first.GetString();
std::string val = option.second.GetString();
str_options[key] = val;
}
Status ret = instance_ptr->SessionManagerObj().AddGraph(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) {
4 years ago
GELOGE(ret,
"[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
return FAILED;
}
GELOGD("AddGraph finished in Session.");
return ret;
}
4 years ago
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
4 years ago
std::map<AscendString, AscendString> options;
return AddGraphWithCopy(graph_id, graph, options);
}
4 years ago
// Add Graph With Copy
4 years ago
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph,
const std::map<AscendString, AscendString> &options) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
4 years ago
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
4 years ago
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Add][Graph]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"AddGraph Failed, GELib instance is nullptr or is not InitFlag.");
4 years ago
return FAILED;
}
std::map<std::string, std::string> str_options;
for (auto it = options.begin(); it != options.end(); ++it) {
str_options.insert({it->first.GetString(), it->second.GetString()});
}
GELOGD("Adding graph to session");
Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) {
4 years ago
GELOGE(ret,
4 years ago
"[Add][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
4 years ago
ret, sessionId_, graph_id);
4 years ago
return FAILED;
}
GELOGD("AddGraph finished in Session.");
return ret;
}
4 years ago
// Remove Graph
5 years ago
Status Session::RemoveGraph(uint32_t graph_id) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
5 years ago
GELOGT(TRACE_INIT, "Session RemoveGraph start");
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
5 years ago
// call RemoveGraph
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (!instance_ptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Remove][Graph]Failed, GELib instance is nullptr or is not InitFlag ");
REPORT_INNER_ERROR("E19999",
"RemoveGraph Failed, GELib instance is nullptr or is not InitFlag.");
5 years ago
return FAILED;
}
GELOGT(TRACE_RUNNING, "Removing Graph from session");
Status ret = instance_ptr->SessionManagerObj().RemoveGraph(sessionId_, graph_id);
// check return status, return
if (ret != SUCCESS) {
4 years ago
GELOGE(ret,
"[Remove][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
5 years ago
return FAILED;
}
GELOGT(TRACE_STOP, "Session RemoveGraph finished");
return ret;
}
4 years ago
// Print Output Result
5 years ago
void PrintOutputResult(std::vector<Tensor> &outputs) {
if (outputs.empty() || outputs[0].GetData() == nullptr) {
GELOGW("outputs is empty or data is nullptr.");
return;
}
size_t out_buf_size = outputs[0].GetSize();
TensorDesc desc(outputs[0].GetTensorDesc());
DataType data_type = desc.GetDataType();
auto iter = CONST_OPDATA_TYPE_SIZE_MAP.find(data_type);
if (iter == CONST_OPDATA_TYPE_SIZE_MAP.end()) {
GELOGI("DataType %s has not defined size", TypeUtils::DataTypeToSerialString(data_type).c_str());
return;
}
size_t length = CONST_OPDATA_TYPE_SIZE_MAP[data_type];
for (size_t i = 0; i < 10 && i < (out_buf_size / length); ++i) { // take first 10 at most
switch (data_type) {
case DT_BOOL:
case DT_INT8:
case DT_UINT8:
GELOGI("output data[%zu]=%d", i, *(reinterpret_cast<int8_t *>(outputs[0].GetData()) + i));
break;
case DT_INT16:
case DT_UINT16:
GELOGI("output data[%zu]=%d", i, *(reinterpret_cast<int16_t *>(outputs[0].GetData()) + i));
break;
case DT_INT32:
case DT_UINT32:
GELOGI("output data[%zu]=%d", i, *(reinterpret_cast<int32_t *>(outputs[0].GetData()) + i));
break;
case DT_INT64:
case DT_UINT64:
GELOGI("output data[%zu]=%ld", i, *(reinterpret_cast<int64_t *>(outputs[0].GetData()) + i));
break;
case DT_FLOAT:
GELOGI("output data[%zu]=%f", i, *(reinterpret_cast<float *>(outputs[0].GetData()) + i));
break;
case DT_DOUBLE:
GELOGI("output data[%zu]=%lf", i, *(reinterpret_cast<double *>(outputs[0].GetData()) + i));
break;
default:
GELOGI("Output datatype %s is not supported.", TypeUtils::DataTypeToSerialString(data_type).c_str());
5 years ago
return;
}
}
}
4 years ago
// Run Graph
5 years ago
Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, std::vector<Tensor> &outputs) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
5 years ago
GELOGT(TRACE_INIT, "Session RunGraph start");
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
5 years ago
std::vector<Tensor> graph_inputs = inputs;
// call RunGraph
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Run][Graph]Failed, GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"RunGraph Failed, GELib instance is nullptr or is not InitFlag.");
5 years ago
return FAILED;
}
GELOGT(TRACE_RUNNING, "Running Graph");
Status ret = instance_ptr->SessionManagerObj().RunGraph(sessionId_, graph_id, graph_inputs, outputs);
// check return status
if (ret != SUCCESS) {
4 years ago
GELOGE(ret,
"[Run][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
5 years ago
return FAILED;
}
// print output
if (outputs.size() > 0) {
PrintOutputResult(outputs);
}
// return
GELOGT(TRACE_STOP, "Session RunGraph finished");
return ret;
}
4 years ago
// Register Call Back
5 years ago
Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) {
ErrorManager::GetInstance().GenWorkStreamIdDefault();
5 years ago
return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback);
}
// C-string key variant of RegisterCallBackFunc. A null `key` is treated as the empty string.
// Returns the session manager's status, or FAILED when the GELib instance is missing or not
// initialized. Previously a missing instance was dereferenced unconditionally.
Status Session::RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback) {
  ErrorManager::GetInstance().GenWorkStreamIdDefault();
  std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED,
           "[Register][CallBackFunc]Failed, the GELib instance is nullptr or is not InitFlag.");
    REPORT_INNER_ERROR("E19999",
                       "RegisterCallBackFunc failed, the GELib instance is nullptr or is not InitFlag.");
    return FAILED;
  }
  const std::string str_key = (key != nullptr) ? std::string(key) : std::string();
  return instance_ptr->SessionManagerObj().RegisterCallBackFunc(sessionId_, str_key, callback);
}
4 years ago
// Build Graph
Status Session::BuildGraph(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Build][Graph]Failed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"Build graph failed, the GELib instance is nullptr or is not InitFlag.");
return FAILED;
}
GELOGT(TRACE_RUNNING, "Building Graph");
Status ret = instance_ptr->SessionManagerObj().BuildGraph(sessionId_, graph_id, inputs);
if (ret != SUCCESS) {
4 years ago
GELOGE(ret,
"[Build][Graph]Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
return FAILED;
}
return SUCCESS;
}
4 years ago
// Run Graph Asynchronously
Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs,
RunAsyncCallback callback) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute);
ErrorManager::GetInstance().GenWorkStreamIdBySessionGraph(sessionId_, graph_id);
5 years ago
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
"[Run][Graph]RunGraphAsyncFailed, the GELib instance is nullptr or is not InitFlag.");
REPORT_INNER_ERROR("E19999",
"RunGraphAsync Failed, the GELib instance is nullptr or is not InitFlag.");
5 years ago
return FAILED;
}
GELOGT(TRACE_RUNNING, "Run Graph Asynchronously");
GELOGW(
4 years ago
"The callback function will not be checked. Please ensure that the implementation of the function is trusted.");
5 years ago
Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback);
5 years ago
if (ret != SUCCESS) {
4 years ago
GELOGE(ret, "[Run][Graph]RunGraphAsync Failed, error code:%u, session_id:%lu, graph_id:%u.",
ret, sessionId_, graph_id);
5 years ago
return FAILED;
}
return SUCCESS;
}
4 years ago
// Get Variables
Status Session::GetVariables(const std::vector<std::string> &var_names, std::vector<Tensor> &var_values) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute);
ErrorManager::GetInstance().GenWorkStreamIdDefault();
auto instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
4 years ago
"[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag.");
4 years ago
REPORT_INNER_ERROR("E19999",
4 years ago
"GetVariables failed, the GELib instance is nullptr or is not InitFlag.");
return FAILED;
}
GELOGT(TRACE_RUNNING, "Get Variables");
Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, var_names, var_values);
if (ret != SUCCESS) {
4 years ago
GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_);
return FAILED;
}
return SUCCESS;
}
4 years ago
// Get Variables
Status Session::GetVariables(const std::vector<AscendString> &var_names, std::vector<Tensor> &var_values) {
4 years ago
ErrorManager::GetInstance().SetStage(ErrorMessage::kModelExecute, ErrorMessage::kModelExecute);
ErrorManager::GetInstance().GenWorkStreamIdDefault();
auto instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
4 years ago
GELOGE(GE_CLI_GE_NOT_INITIALIZED,
4 years ago
"[Get][Variables]Failed, the GELib instance is nullptr or is not InitFlag.");
4 years ago
REPORT_INNER_ERROR("E19999",
4 years ago
"GetVariables failed, the GELib instance is nullptr or is not InitFlag.");
return FAILED;
}
GELOGT(TRACE_RUNNING, "Get Variables");
std::vector<ge::string> str_var_names;
for (auto &var_name : var_names) {
if (var_name.GetString() == nullptr) {
4 years ago
GELOGE(FAILED, "[Get][Variable]Failed, variables' names are nullptr.");
REPORT_INNER_ERROR("E19999", "GetVariables failed, variables' names are nullptr.");
return FAILED;
}
str_var_names.emplace_back(var_name.GetString());
}
Status ret = ge::GELib::GetInstance()->SessionManagerObj().GetVariables(sessionId_, str_var_names, var_values);
if (ret != SUCCESS) {
4 years ago
GELOGE(ret, "[Get][Variables]Failed, error code:%u, session_id:%lu.", ret, sessionId_);
return FAILED;
}
return SUCCESS;
}
5 years ago
bool Session::IsGraphNeedRebuild(uint32_t graph_id) {
return ge::GELib::GetInstance()->SessionManagerObj().IsGraphNeedRebuild(sessionId_, graph_id);
}
} // namespace ge