!8297 [lite] mindspore lite model compatibility

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/8297/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9c23fa0a6c

@ -32,6 +32,7 @@ enum CHWK_SHAPE { CHWK_C = 0, CHWK_H = 1, CHWK_W = 2, CHWK_K = 3 };
enum KHWC_SHAPE { KHWC_K = 0, KHWC_H = 1, KHWC_W = 2, KHWC_C = 3 };
enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 };
enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 };
enum SCHEMA_VERSION { SCHEMA_CUR = 0 };
static constexpr int kNCHWDimNumber = 4;
static constexpr int kNHWCDimNumber = 4;

@ -14,121 +14,59 @@
* limitations under the License.
*/
#include "src/model_common.h"
#include "include/version.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore::lite {
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) {
auto *node = new (std::nothrow) Model::Node();
if (node == nullptr) {
MS_LOG(ERROR) << "new node fail!";
return false;
}
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
auto src_prim = c_node->primitive();
#ifdef PRIMITIVE_WRITEABLE
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
#else
auto primitive = const_cast<schema::Primitive *>(src_prim);
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
#endif
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "unpack primitive == nullptr!";
delete node;
return false;
}
node->primitive_->SetQuantType(c_node->quantType());
node->name_ = c_node->name()->c_str();
node->node_type_ = c_node->nodeType();
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
}
if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
}
}
model->all_nodes_.push_back(node);
}
return true;
}
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
auto tensor_count = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
if (tensor == nullptr) {
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
return false;
}
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
}
return true;
}
int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model) {
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(sub_graph != nullptr);
auto *sub_graph_temp = new (std::nothrow) Model::SubGraph();
if (sub_graph_temp == nullptr) {
auto *subgraph = new (std::nothrow) Model::SubGraph();
if (subgraph == nullptr) {
MS_LOG(ERROR) << "new subGraph fail!";
return RET_ERROR;
}
sub_graph_temp->name_ = sub_graph->name()->c_str();
auto in_count = sub_graph->inputIndices()->size();
subgraph->name_ = sub_graph.name()->c_str();
auto in_count = sub_graph.inputIndices()->size();
for (uint32_t i = 0; i < in_count; ++i) {
sub_graph_temp->input_indices_.push_back(size_t(sub_graph->inputIndices()->GetAs<uint32_t>(i)));
subgraph->input_indices_.push_back(size_t(sub_graph.inputIndices()->GetAs<uint32_t>(i)));
}
auto out_count = sub_graph->outputIndices()->size();
auto out_count = sub_graph.outputIndices()->size();
for (uint32_t i = 0; i < out_count; ++i) {
sub_graph_temp->output_indices_.push_back(size_t(sub_graph->outputIndices()->GetAs<uint32_t>(i)));
subgraph->output_indices_.push_back(size_t(sub_graph.outputIndices()->GetAs<uint32_t>(i)));
}
auto node_count = sub_graph->nodeIndices()->size();
auto node_count = sub_graph.nodeIndices()->size();
for (uint32_t i = 0; i < node_count; ++i) {
sub_graph_temp->node_indices_.push_back(size_t(sub_graph->nodeIndices()->GetAs<uint32_t>(i)));
subgraph->node_indices_.push_back(size_t(sub_graph.nodeIndices()->GetAs<uint32_t>(i)));
}
auto tensor_count = sub_graph->nodeIndices()->size();
auto tensor_count = sub_graph.nodeIndices()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
sub_graph_temp->tensor_indices_.push_back(size_t(sub_graph->tensorIndices()->GetAs<uint32_t>(i)));
subgraph->tensor_indices_.push_back(size_t(sub_graph.tensorIndices()->GetAs<uint32_t>(i)));
}
model->sub_graphs_.push_back(sub_graph_temp);
model->sub_graphs_.push_back(subgraph);
return RET_OK;
}
int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(meta_graph != nullptr);
auto *sub_graph_temp = new (std::nothrow) Model::SubGraph();
if (sub_graph_temp == nullptr) {
MS_LOG(ERROR) << "new subGraph fail!";
return RET_ERROR;
}
if (meta_graph->name() != nullptr) {
sub_graph_temp->name_ = meta_graph->name()->c_str();
}
auto in_count = meta_graph->inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
sub_graph_temp->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
int VersionVerify(flatbuffers::Verifier *verify) {
if (schema::VerifyMetaGraphBuffer(*verify)) {
return SCHEMA_VERSION::SCHEMA_CUR;
}
auto out_count = meta_graph->outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
sub_graph_temp->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
}
auto node_count = meta_graph->nodes()->size();
for (uint32_t i = 0; i < node_count; ++i) {
sub_graph_temp->node_indices_.push_back(i);
return -1;
}
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
MS_ASSERT(buf != nullptr);
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
return reinterpret_cast<const void *>(schema::GetMetaGraph(buf));
}
auto tensor_count = meta_graph->nodes()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
sub_graph_temp->tensor_indices_.push_back(i);
return nullptr;
}
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) {
MS_ASSERT(model != nullptr);
int status = RET_ERROR;
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph),
model, schema_version);
}
model->sub_graphs_.push_back(sub_graph_temp);
return RET_OK;
return status;
}
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
@ -137,7 +75,8 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
return nullptr;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
if (!schema::VerifyMetaGraphBuffer(verify)) {
int schema_version = VersionVerify(&verify);
if (schema_version == -1) {
MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
return nullptr;
}
@ -162,54 +101,25 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
}
memcpy(model->buf, model_buf, size);
}
auto meta_graph = schema::GetMetaGraph(model->buf);
const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "meta_graph is nullptr!";
delete (model);
return nullptr;
}
if (meta_graph->name() != nullptr) {
model->name_ = meta_graph->name()->c_str();
}
if (meta_graph->version() != nullptr) {
model->version_ = meta_graph->version()->c_str();
int status = GenerateModelByVersion(meta_graph, model, schema_version);
if (status != RET_OK) {
delete (model);
MS_LOG(ERROR) << "fail to generate model";
return nullptr;
}
if (model->version_ != Version()) {
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
}
if (!ConvertNodes(meta_graph, model)) {
delete model;
return nullptr;
}
if (!ConvertTensors(meta_graph, model)) {
delete model;
return nullptr;
}
if (meta_graph->subGraph() == nullptr) {
int ret = MetaGraphMappingSubGraph(meta_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter old version model wrong.";
return nullptr;
}
} else {
auto sub_graphs = meta_graph->subGraph();
auto sub_graph_size = sub_graphs->size();
for (size_t i = 0; i < sub_graph_size; i++) {
auto sub_graph = sub_graphs->GetAs<schema::SubGraph>(i);
int ret = ConvertSubGraph(sub_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter subgraph wrong.";
return nullptr;
}
}
}
if (model->sub_graphs_.empty()) {
delete (model);
return nullptr;
}
return model;

@ -16,17 +16,149 @@
#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
#include <string>
#include "src/ops/primitive_c.h"
#include "include/model.h"
#include "include/version.h"
#include "schema/model_generated.h"
#include "src/common/common.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore::lite {
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model);
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model);
template <typename T = schema::MetaGraph, typename U = schema::CNode>
bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = 0) {
MS_ASSERT(model != nullptr);
for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
auto *node = new (std::nothrow) Model::Node();
if (node == nullptr) {
MS_LOG(ERROR) << "new node fail!";
return false;
}
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
#ifdef PRIMITIVE_WRITEABLE
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
#else
auto primitive = const_cast<schema::Primitive *>(src_prim);
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
#endif
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "unpack primitive == nullptr!";
delete node;
return false;
}
node->primitive_->SetQuantType(static_cast<schema::QuantType>(c_node->quantType()));
node->name_ = c_node->name()->c_str();
node->node_type_ = static_cast<NodeType>(c_node->nodeType());
auto count = c_node->inputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
}
if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
}
}
model->all_nodes_.push_back(node);
}
return true;
}
template <typename T = schema::MetaGraph>
bool ConvertTensors(const T &meta_graph, Model *model) {
MS_ASSERT(model != nullptr);
auto tensor_count = meta_graph.allTensors()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
if (tensor == nullptr) {
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
return false;
}
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
}
return true;
}
template <typename T = schema::MetaGraph>
int MetaGraphMappingSubGraph(const T &meta_graph, Model *model) {
MS_ASSERT(model != nullptr);
auto *subgraph = new (std::nothrow) Model::SubGraph();
if (subgraph == nullptr) {
MS_LOG(ERROR) << "new subGraph fail!";
return RET_ERROR;
}
if (meta_graph.name() != nullptr) {
subgraph->name_ = meta_graph.name()->c_str();
}
auto in_count = meta_graph.inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
}
auto out_count = meta_graph.outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
}
auto node_count = meta_graph.nodes()->size();
for (uint32_t i = 0; i < node_count; ++i) {
subgraph->node_indices_.push_back(i);
}
auto tensor_count = meta_graph.nodes()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
subgraph->tensor_indices_.push_back(i);
}
model->sub_graphs_.push_back(subgraph);
return RET_OK;
}
template <typename T = schema::MetaGraph, typename U = schema::CNode>
int GenerateModel(const T &meta_graph, Model *model, int schema_version = 0) {
MS_ASSERT(model != nullptr);
if (meta_graph.name() != nullptr) {
model->name_ = meta_graph.name()->c_str();
}
if (meta_graph.version() != nullptr) {
model->version_ = meta_graph.version()->c_str();
}
if (!ConvertNodes<T, U>(meta_graph, model, schema_version)) {
MS_LOG(ERROR) << "convert node failed";
return RET_ERROR;
}
if (!ConvertTensors<T>(meta_graph, model)) {
MS_LOG(ERROR) << "convert tensor failed";
return RET_ERROR;
}
if (meta_graph.subGraph() == nullptr) {
int ret = MetaGraphMappingSubGraph<T>(meta_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter old version model wrong.";
return ret;
}
} else {
auto sub_graphs = meta_graph.subGraph();
auto sub_graph_size = sub_graphs->size();
for (size_t i = 0; i < sub_graph_size; i++) {
auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
int ret = ConvertSubGraph(*sub_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter subgraph wrong.";
return ret;
}
}
}
return RET_OK;
}
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model);
int VersionVerify(flatbuffers::Verifier *verify);
int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model);
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);
int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model);
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version);
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
} // namespace mindspore::lite

@ -22,9 +22,6 @@
namespace mindspore::lite {
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model);
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model);
TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
@ -62,18 +59,18 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
if (meta_graph->version() != nullptr) {
model->version_ = meta_graph->version()->c_str();
}
if (!ConvertNodes(meta_graph, model)) {
if (!ConvertNodes(*meta_graph, model)) {
delete model;
return nullptr;
}
if (!ConvertTensors(meta_graph, model)) {
if (!ConvertTensors(*meta_graph, model)) {
delete model;
return nullptr;
}
if (meta_graph->subGraph() == nullptr) {
int ret = MetaGraphMappingSubGraph(meta_graph, model);
int ret = MetaGraphMappingSubGraph(*meta_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter old version model wrong.";
return nullptr;
@ -83,7 +80,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
auto sub_graph_size = sub_graphs->size();
for (size_t i = 0; i < sub_graph_size; i++) {
auto sub_graph = sub_graphs->GetAs<schema::SubGraph>(i);
int ret = ConvertSubGraph(sub_graph, model);
int ret = ConvertSubGraph(*sub_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter subgraph wrong.";
return nullptr;

Loading…
Cancel
Save