*model_buf, size_t size, lite::Context *context) interfacepull/7577/head
parent
e805051c1f
commit
e19a3e3926
@ -0,0 +1,138 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/model_common.h"
|
||||
#include "include/version.h"
|
||||
#include "src/ops/ops_register.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
|
||||
for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
|
||||
Model::Node *node = new (std::nothrow) Model::Node();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "new node fail!";
|
||||
return false;
|
||||
}
|
||||
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
|
||||
auto src_prim = c_node->primitive();
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
||||
#else
|
||||
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
||||
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
|
||||
#endif
|
||||
if (node->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
||||
delete node;
|
||||
return false;
|
||||
}
|
||||
node->primitive_->SetQuantType(c_node->quantType());
|
||||
node->name_ = c_node->name()->c_str();
|
||||
node->node_type_ = c_node->nodeType();
|
||||
auto count = c_node->inputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
|
||||
}
|
||||
if (c_node->outputIndex() != nullptr) {
|
||||
count = c_node->outputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
|
||||
}
|
||||
}
|
||||
model->nodes_.push_back(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
|
||||
auto tensor_count = meta_graph->allTensors()->size();
|
||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
|
||||
return false;
|
||||
}
|
||||
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
||||
if (model_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "The model buf is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
||||
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
||||
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *model = new (std::nothrow) Model();
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "new model fail!";
|
||||
return nullptr;
|
||||
}
|
||||
if (take_buf) {
|
||||
model->buf = const_cast<char *>(model_buf);
|
||||
} else {
|
||||
model->buf = reinterpret_cast<char *>(malloc(size));
|
||||
if (model->buf == nullptr) {
|
||||
MS_LOG(ERROR) << "new inner model buf fail!";
|
||||
delete (model);
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(model->buf, model_buf, size);
|
||||
}
|
||||
|
||||
auto meta_graph = schema::GetMetaGraph(model->buf);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "meta_graph is nullptr!";
|
||||
delete (model);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (meta_graph->name() != nullptr) {
|
||||
model->name_ = meta_graph->name()->c_str();
|
||||
}
|
||||
if (meta_graph->version() != nullptr) {
|
||||
model->version_ = meta_graph->version()->c_str();
|
||||
}
|
||||
|
||||
if (model->version_ != Version()) {
|
||||
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
|
||||
}
|
||||
|
||||
auto in_count = meta_graph->inputIndex()->size();
|
||||
for (uint32_t i = 0; i < in_count; ++i) {
|
||||
model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
|
||||
}
|
||||
|
||||
auto out_count = meta_graph->outputIndex()->size();
|
||||
for (uint32_t i = 0; i < out_count; ++i) {
|
||||
model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
|
||||
}
|
||||
if (!ConvertNodes(meta_graph, model)) {
|
||||
delete model;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ConvertTensors(meta_graph, model)) {
|
||||
delete model;
|
||||
return nullptr;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
} // namespace mindspore::lite
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
||||
#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "include/model.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model);
|
||||
|
||||
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model);
|
||||
|
||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue