|
|
|
@ -14,121 +14,59 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "src/model_common.h"
|
|
|
|
|
#include "include/version.h"
|
|
|
|
|
#ifndef PRIMITIVE_WRITEABLE
|
|
|
|
|
#include "src/ops/ops_register.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
|
|
|
|
|
for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
|
|
|
|
|
auto *node = new (std::nothrow) Model::Node();
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new node fail!";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
|
|
|
|
|
auto src_prim = c_node->primitive();
|
|
|
|
|
#ifdef PRIMITIVE_WRITEABLE
|
|
|
|
|
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
|
|
|
|
#else
|
|
|
|
|
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
|
|
|
|
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
|
|
|
|
|
#endif
|
|
|
|
|
if (node->primitive_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
|
|
|
|
delete node;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
node->primitive_->SetQuantType(c_node->quantType());
|
|
|
|
|
node->name_ = c_node->name()->c_str();
|
|
|
|
|
node->node_type_ = c_node->nodeType();
|
|
|
|
|
auto count = c_node->inputIndex()->size();
|
|
|
|
|
for (uint32_t j = 0; j < count; ++j) {
|
|
|
|
|
node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
|
|
|
|
|
}
|
|
|
|
|
if (c_node->outputIndex() != nullptr) {
|
|
|
|
|
count = c_node->outputIndex()->size();
|
|
|
|
|
for (uint32_t j = 0; j < count; ++j) {
|
|
|
|
|
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
model->all_nodes_.push_back(node);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
|
|
|
|
|
auto tensor_count = meta_graph->allTensors()->size();
|
|
|
|
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
|
|
|
|
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
|
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model) {
|
|
|
|
|
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
|
|
|
|
|
MS_ASSERT(model != nullptr);
|
|
|
|
|
MS_ASSERT(sub_graph != nullptr);
|
|
|
|
|
auto *sub_graph_temp = new (std::nothrow) Model::SubGraph();
|
|
|
|
|
if (sub_graph_temp == nullptr) {
|
|
|
|
|
auto *subgraph = new (std::nothrow) Model::SubGraph();
|
|
|
|
|
if (subgraph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new subGraph fail!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
sub_graph_temp->name_ = sub_graph->name()->c_str();
|
|
|
|
|
auto in_count = sub_graph->inputIndices()->size();
|
|
|
|
|
subgraph->name_ = sub_graph.name()->c_str();
|
|
|
|
|
auto in_count = sub_graph.inputIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < in_count; ++i) {
|
|
|
|
|
sub_graph_temp->input_indices_.push_back(size_t(sub_graph->inputIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
subgraph->input_indices_.push_back(size_t(sub_graph.inputIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
}
|
|
|
|
|
auto out_count = sub_graph->outputIndices()->size();
|
|
|
|
|
auto out_count = sub_graph.outputIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < out_count; ++i) {
|
|
|
|
|
sub_graph_temp->output_indices_.push_back(size_t(sub_graph->outputIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
subgraph->output_indices_.push_back(size_t(sub_graph.outputIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
}
|
|
|
|
|
auto node_count = sub_graph->nodeIndices()->size();
|
|
|
|
|
auto node_count = sub_graph.nodeIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < node_count; ++i) {
|
|
|
|
|
sub_graph_temp->node_indices_.push_back(size_t(sub_graph->nodeIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
subgraph->node_indices_.push_back(size_t(sub_graph.nodeIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
}
|
|
|
|
|
auto tensor_count = sub_graph->nodeIndices()->size();
|
|
|
|
|
auto tensor_count = sub_graph.nodeIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
|
|
|
|
sub_graph_temp->tensor_indices_.push_back(size_t(sub_graph->tensorIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
subgraph->tensor_indices_.push_back(size_t(sub_graph.tensorIndices()->GetAs<uint32_t>(i)));
|
|
|
|
|
}
|
|
|
|
|
model->sub_graphs_.push_back(sub_graph_temp);
|
|
|
|
|
model->sub_graphs_.push_back(subgraph);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model) {
|
|
|
|
|
MS_ASSERT(model != nullptr);
|
|
|
|
|
MS_ASSERT(meta_graph != nullptr);
|
|
|
|
|
auto *sub_graph_temp = new (std::nothrow) Model::SubGraph();
|
|
|
|
|
if (sub_graph_temp == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new subGraph fail!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (meta_graph->name() != nullptr) {
|
|
|
|
|
sub_graph_temp->name_ = meta_graph->name()->c_str();
|
|
|
|
|
}
|
|
|
|
|
auto in_count = meta_graph->inputIndex()->size();
|
|
|
|
|
for (uint32_t i = 0; i < in_count; ++i) {
|
|
|
|
|
sub_graph_temp->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
|
|
|
|
|
int VersionVerify(flatbuffers::Verifier *verify) {
|
|
|
|
|
if (schema::VerifyMetaGraphBuffer(*verify)) {
|
|
|
|
|
return SCHEMA_VERSION::SCHEMA_CUR;
|
|
|
|
|
}
|
|
|
|
|
auto out_count = meta_graph->outputIndex()->size();
|
|
|
|
|
for (uint32_t i = 0; i < out_count; ++i) {
|
|
|
|
|
sub_graph_temp->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
|
|
|
|
|
}
|
|
|
|
|
auto node_count = meta_graph->nodes()->size();
|
|
|
|
|
for (uint32_t i = 0; i < node_count; ++i) {
|
|
|
|
|
sub_graph_temp->node_indices_.push_back(i);
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
|
|
|
|
|
MS_ASSERT(buf != nullptr);
|
|
|
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
|
|
|
|
return reinterpret_cast<const void *>(schema::GetMetaGraph(buf));
|
|
|
|
|
}
|
|
|
|
|
auto tensor_count = meta_graph->nodes()->size();
|
|
|
|
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
|
|
|
|
sub_graph_temp->tensor_indices_.push_back(i);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) {
|
|
|
|
|
MS_ASSERT(model != nullptr);
|
|
|
|
|
int status = RET_ERROR;
|
|
|
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
|
|
|
|
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph),
|
|
|
|
|
model, schema_version);
|
|
|
|
|
}
|
|
|
|
|
model->sub_graphs_.push_back(sub_graph_temp);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
|
|
|
@ -137,7 +75,8 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
|
|
|
|
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
|
|
|
|
int schema_version = VersionVerify(&verify);
|
|
|
|
|
if (schema_version == -1) {
|
|
|
|
|
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
@ -162,54 +101,25 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
|
|
|
|
}
|
|
|
|
|
memcpy(model->buf, model_buf, size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto meta_graph = schema::GetMetaGraph(model->buf);
|
|
|
|
|
const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
|
|
|
|
|
if (meta_graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "meta_graph is nullptr!";
|
|
|
|
|
delete (model);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (meta_graph->name() != nullptr) {
|
|
|
|
|
model->name_ = meta_graph->name()->c_str();
|
|
|
|
|
}
|
|
|
|
|
if (meta_graph->version() != nullptr) {
|
|
|
|
|
model->version_ = meta_graph->version()->c_str();
|
|
|
|
|
int status = GenerateModelByVersion(meta_graph, model, schema_version);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
delete (model);
|
|
|
|
|
MS_LOG(ERROR) << "fail to generate model";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (model->version_ != Version()) {
|
|
|
|
|
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!ConvertNodes(meta_graph, model)) {
|
|
|
|
|
delete model;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!ConvertTensors(meta_graph, model)) {
|
|
|
|
|
delete model;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (meta_graph->subGraph() == nullptr) {
|
|
|
|
|
int ret = MetaGraphMappingSubGraph(meta_graph, model);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "converter old version model wrong.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto sub_graphs = meta_graph->subGraph();
|
|
|
|
|
auto sub_graph_size = sub_graphs->size();
|
|
|
|
|
for (size_t i = 0; i < sub_graph_size; i++) {
|
|
|
|
|
auto sub_graph = sub_graphs->GetAs<schema::SubGraph>(i);
|
|
|
|
|
int ret = ConvertSubGraph(sub_graph, model);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "converter subgraph wrong.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (model->sub_graphs_.empty()) {
|
|
|
|
|
delete (model);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return model;
|
|
|
|
|