You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/predict/src/session.cc

155 lines
3.8 KiB

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/session.h"
#include <map>
#include <atomic>
#include "include/errorcode.h"
#include "common/mslog.h"
#include "src/graph.h"
#include "src/graph_execution.h"
namespace mindspore {
namespace predict {
// File-scope Context instance, default-constructed here.
// NOTE(review): declared without `static`, so it has external linkage; confirm
// whether other translation units depend on it before changing it.
Context m_ctx;
// NOTE(review): presumably records whether m_ctx has been configured; nothing in
// the visible part of this file reads or writes it -- confirm against other users.
bool m_isConfig = false;
// Upper bound on the graph buffer size accepted by CreateSession.
// sizeof(int32_t) * 8 - 1 is 31, so this evaluates to 2^31 - 1 (2GB - 1),
// i.e. the largest positive int32_t value.
static constexpr auto MAX_BUFFER_SIZE = ((1ULL << (sizeof(int32_t) * 8 - 1)) - 1);
std::shared_ptr<Session> CreateSession(const char *graphBuf, size_t size, const Context &ctx) {
if (graphBuf == nullptr) {
MS_LOGE("the graphBuf is nullptr");
return nullptr;
}
if (size > MAX_BUFFER_SIZE) {
MS_LOGE("the size is invalid");
return nullptr;
}
auto session = std::make_shared<Session>(ctx);
MS_ASSERT(session != nullptr);
auto ret = session->Init(graphBuf, size);
if (ret != RET_OK) {
MS_LOGE("Init session failed.");
return nullptr;
}
return session;
}
// Constructs a session whose _ctx member is initialized from the supplied context.
//
// NOTE(review): this constructor used to copy ctx into a local Context, lower
// that copy's threadNum to m_ctx.threadNum when it was larger, and then discard
// the copy. The clamp therefore never reached _ctx. The dead computation has
// been removed, which leaves the resulting _ctx unchanged. If capping the
// thread count is actually intended, the cap must be applied to _ctx itself.
Session::Session(const Context &ctx) : _ctx(ctx) {}
int Session::Init(const char *graphBuf, size_t size) {
_graph = Graph::CreateFromBuf(graphBuf, size, _ctx);
if (_graph == nullptr) {
MS_LOGE("Graph create from buf failed.");
return RET_NULL_PTR;
}
auto ret = this->InitExecutor();
if (ret != RET_OK) {
MS_LOGE("Init Executor failed");
return ret;
}
return ret;
}
// Discards any existing executor and creates a new GraphExecution for _graph.
//
// Returns RET_OK on success. Returns RET_ERROR when _graph is nullptr or when
// the allocation fails; in both cases _executor is left as nullptr.
int Session::InitExecutor() {
  // Deleting a null pointer is a no-op, so no guard is required here.
  delete _executor;
  _executor = nullptr;

  if (_graph == nullptr) {
    MS_LOGE("the graph is nullptr");
    return RET_ERROR;
  }

  // nothrow new: allocation failure is reported as nullptr.
  _executor = new (std::nothrow) GraphExecution(_ctx, _graph);
  if (_executor == nullptr) {
    MS_LOGE("new GraphExecution fail");
    return RET_ERROR;
  }
  return RET_OK;
}
// Releases the executor and the graph owned by this session.
Session::~Session() {
  // `delete` on a null pointer does nothing, so explicit null checks are not
  // needed. The executor is released before the graph, as before.
  delete _executor;
  delete _graph;
}
// Runs the executor on the given input tensors.
//
// When reinitExecutor is set, the executor is rebuilt through InitExecutor
// first, and a failure there is returned to the caller unchanged.
//
// Returns:
//   - the InitExecutor status, if re-initialization fails;
//   - RET_NULL_PTR, if there is no executor to run (this path used to return
//     RET_OK, which reported the failure as success);
//   - otherwise the status returned by the executor's Run.
int Session::Run(const std::vector<Tensor *> &inputs) {
  if (reinitExecutor) {
    const auto initStatus = this->InitExecutor();
    if (initStatus != RET_OK) {
      MS_LOGE("Init Executor failed");
      return initStatus;
    }
  }
  if (_executor == nullptr) {
    MS_LOGE("_executor is nullptr");
    // Must be a failure code: nothing was executed.
    return RET_NULL_PTR;
  }
  return _executor->Run(inputs);
}
// Returns the input tensors reported by the executor.
//
// Returns an empty vector (and logs an error) when no executor exists. An empty
// result from the executor is logged at info level and returned as-is.
std::vector<Tensor *> Session::GetInput() {
  if (_executor == nullptr) {
    MS_LOGE("_executor is nullptr");
    return std::vector<Tensor *>{};
  }
  auto inputs = _executor->GetInput();
  if (inputs.empty()) {
    // The message previously said "output", copied from the output getters.
    MS_LOGI("input is empty.");
  }
  return inputs;
}
// Returns the output tensors the executor reports for the node `nodeName`.
//
// Returns an empty vector (and logs an error) when no executor exists. An empty
// result from the executor is logged at info level and returned as-is.
std::vector<Tensor *> Session::GetOutput(const std::string &nodeName) {
  if (_executor == nullptr) {
    MS_LOGE("graph's executor is nullptr.");
    return {};
  }

  auto nodeOutputs = _executor->GetOutput(nodeName);
  if (nodeOutputs.empty()) {
    MS_LOGI("output is empty.");
  }
  return nodeOutputs;
}
// Returns every output the executor reports, as a map from a string key to a
// list of tensors.
//
// Returns an empty map (and logs an error) when no executor exists. An empty
// result from the executor is logged at info level and returned as-is.
std::map<std::string, std::vector<Tensor *>> Session::GetAllOutput() {
  if (_executor == nullptr) {
    MS_LOGE("graph's executor is nullptr.");
    return {};
  }

  auto allOutputs = _executor->GetAllOutput();
  if (allOutputs.empty()) {
    MS_LOGI("outputs is empty.");
  }
  return allOutputs;
}
} // namespace predict
} // namespace mindspore