@@ -37,8 +37,8 @@ namespace mindspore::inference {
 std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
   try {
     auto session = std::make_shared<MSInferSession>();
-    bool ret = session->InitEnv(device, device_id);
-    if (!ret) {
+    Status ret = session->InitEnv(device, device_id);
+    if (ret != SUCCESS) {
       return nullptr;
     }
     return session;
@@ -84,21 +84,21 @@ std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &f
   return buf;
 }
 
-bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
+Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
   auto graphBuf = ReadFile(file_name);
   if (graphBuf == nullptr) {
     MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
-    return false;
+    return FAILED;
   }
   auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
   if (graph == nullptr) {
     MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
-    return false;
+    return FAILED;
   }
-  bool ret = CompileGraph(graph, model_id);
-  if (!ret) {
+  Status ret = CompileGraph(graph, model_id);
+  if (ret != SUCCESS) {
     MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
-    return false;
+    return FAILED;
   }
   MS_LOG(INFO) << "Load model from file " << file_name << " success";
 
@@ -107,14 +107,14 @@ bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &m
   rtError_t rt_ret = rtCtxGetCurrent(&context_);
   if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
     MS_LOG(ERROR) << "the ascend device context is null";
-    return false;
+    return FAILED;
   }
 #endif
 
-  return true;
+  return SUCCESS;
 }
 
-bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
+Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
 
 tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
   std::vector<int> shape;
@@ -170,16 +170,16 @@ void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_te
   out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
 }
 
-bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
+Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
 #ifdef ENABLE_D
   if (context_ == nullptr) {
     MS_LOG(ERROR) << "rtCtx is nullptr";
-    return false;
+    return FAILED;
   }
   rtError_t rt_ret = rtCtxSetCurrent(context_);
   if (rt_ret != RT_ERROR_NONE) {
     MS_LOG(ERROR) << "set Ascend rtCtx failed";
-    return false;
+    return FAILED;
   }
 #endif
 
@@ -187,47 +187,47 @@ bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
   for (size_t i = 0; i < request.size(); i++) {
     if (request[i] == nullptr) {
       MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
-      return false;
+      return FAILED;
     }
     auto input = ServingTensor2MSTensor(*request[i]);
     if (input == nullptr) {
       MS_LOG(ERROR) << "Tensor convert failed";
-      return false;
+      return FAILED;
     }
     inputs.push_back(input);
   }
   if (!CheckModelInputs(model_id, inputs)) {
     MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
-    return false;
+    return INVALID_INPUTS;
   }
   vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
   if (outputs.empty()) {
     MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
-    return false;
+    return FAILED;
   }
   reply.clear();
   for (const auto &tensor : outputs) {
     auto out_tensor = reply.add();
     if (out_tensor == nullptr) {
       MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
-      return false;
+      return FAILED;
     }
     MSTensor2ServingTensor(tensor, *out_tensor);
   }
-  return true;
+  return SUCCESS;
 }
 
-bool MSInferSession::FinalizeEnv() {
+Status MSInferSession::FinalizeEnv() {
   auto ms_context = MsContext::GetInstance();
   if (ms_context == nullptr) {
     MS_LOG(ERROR) << "Get Context failed!";
-    return false;
+    return FAILED;
   }
   if (!ms_context->CloseTsd()) {
     MS_LOG(ERROR) << "Inference CloseTsd failed!";
-    return false;
+    return FAILED;
   }
-  return true;
+  return SUCCESS;
 }
 
 std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
@@ -292,16 +292,16 @@ void MSInferSession::RegAllOp() {
   return;
 }
 
-bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
+Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
   MS_ASSERT(session_impl_ != nullptr);
   try {
     auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
     py::gil_scoped_release gil_release;
     model_id = graph_id;
-    return true;
+    return SUCCESS;
   } catch (std::exception &e) {
     MS_LOG(ERROR) << "Inference CompileGraph failed";
-    return false;
+    return FAILED;
   }
 }
 
@@ -327,31 +327,31 @@ string MSInferSession::AjustTargetName(const std::string &device) {
   }
 }
 
-bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
+Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
   RegAllOp();
   auto ms_context = MsContext::GetInstance();
   ms_context->set_execution_mode(kGraphMode);
   ms_context->set_device_id(device_id);
   auto ajust_device = AjustTargetName(device);
   if (ajust_device == "") {
-    return false;
+    return FAILED;
   }
   ms_context->set_device_target(device);
   session_impl_ = session::SessionFactory::Get().Create(ajust_device);
   if (session_impl_ == nullptr) {
     MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
-    return false;
+    return FAILED;
   }
   session_impl_->Init(device_id);
   if (ms_context == nullptr) {
     MS_LOG(ERROR) << "Get Context failed!";
-    return false;
+    return FAILED;
   }
   if (!ms_context->OpenTsd()) {
     MS_LOG(ERROR) << "Session init OpenTsd failed!";
-    return false;
+    return FAILED;
   }
-  return true;
+  return SUCCESS;
 }
 
 bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
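
Note on the Status values: the new return statements above use SUCCESS, FAILED and INVALID_INPUTS, so the change presupposes a Status enum with at least those three members, and the ExecuteModel hunk now distinguishes a rejected request (INVALID_INPUTS) from a runtime failure (FAILED), which a bare bool could not. The sketch below is a minimal, self-contained illustration of that pattern; the exact enum layout and the Execute() helper are assumptions for illustration, not taken from this diff.

#include <cstdint>
#include <iostream>

// Minimal sketch of the Status type the hunks above rely on; only
// SUCCESS, FAILED and INVALID_INPUTS actually appear in the diff,
// and the exact enum layout here is an assumption.
enum Status : uint32_t {
  SUCCESS = 0,
  FAILED,
  INVALID_INPUTS,
};

// Hypothetical stand-in for a call such as ExecuteModel: with the old
// bool returns, a bad request and a runtime failure both collapsed to
// `false`; with Status the caller can map them to different errors.
Status Execute(bool inputs_ok, bool run_ok) {
  if (!inputs_ok) return INVALID_INPUTS;
  if (!run_ok) return FAILED;
  return SUCCESS;
}

int main() {
  Status status = Execute(/*inputs_ok=*/false, /*run_ok=*/true);
  if (status == INVALID_INPUTS) {
    std::cout << "reject request: invalid inputs" << std::endl;  // client-side error
  } else if (status != SUCCESS) {
    std::cout << "execution failed" << std::endl;  // server-side error
  }
  return 0;
}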