From 98e46042a5f0e3810331edf67bad8a8811f24e90 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Fri, 6 Nov 2020 14:58:30 +0800 Subject: [PATCH] compatible --- mindspore/lite/src/common/common.h | 1 + mindspore/lite/src/model_common.cc | 174 ++++++------------------ mindspore/lite/src/model_common.h | 140 ++++++++++++++++++- mindspore/lite/src/train/train_model.cc | 11 +- 4 files changed, 183 insertions(+), 143 deletions(-) diff --git a/mindspore/lite/src/common/common.h b/mindspore/lite/src/common/common.h index 0dc90fe2cb..9664dea33d 100644 --- a/mindspore/lite/src/common/common.h +++ b/mindspore/lite/src/common/common.h @@ -32,6 +32,7 @@ enum CHWK_SHAPE { CHWK_C = 0, CHWK_H = 1, CHWK_W = 2, CHWK_K = 3 }; enum KHWC_SHAPE { KHWC_K = 0, KHWC_H = 1, KHWC_W = 2, KHWC_C = 3 }; enum CHW_SHAPE { CHW_C = 0, CHW_H = 1, CHW_W = 2 }; enum HWC_SHAPE { HWC_H = 0, HWC_W = 1, HWC_C = 2 }; +enum SCHEMA_VERSION { SCHEMA_CUR = 0 }; static constexpr int kNCHWDimNumber = 4; static constexpr int kNHWCDimNumber = 4; diff --git a/mindspore/lite/src/model_common.cc b/mindspore/lite/src/model_common.cc index 51062b07e5..a0073e41fd 100644 --- a/mindspore/lite/src/model_common.cc +++ b/mindspore/lite/src/model_common.cc @@ -14,121 +14,59 @@ * limitations under the License. */ #include "src/model_common.h" -#include "include/version.h" -#ifndef PRIMITIVE_WRITEABLE -#include "src/ops/ops_register.h" -#endif namespace mindspore::lite { -bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) { - for (size_t i = 0; i < meta_graph->nodes()->size(); ++i) { - auto *node = new (std::nothrow) Model::Node(); - if (node == nullptr) { - MS_LOG(ERROR) << "new node fail!"; - return false; - } - auto c_node = meta_graph->nodes()->GetAs(i); - auto src_prim = c_node->primitive(); -#ifdef PRIMITIVE_WRITEABLE - node->primitive_ = PrimitiveC::Create(const_cast(src_prim)); -#else - auto primitive = const_cast(src_prim); - node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); -#endif - if (node->primitive_ == nullptr) { - MS_LOG(ERROR) << "unpack primitive == nullptr!"; - delete node; - return false; - } - node->primitive_->SetQuantType(c_node->quantType()); - node->name_ = c_node->name()->c_str(); - node->node_type_ = c_node->nodeType(); - auto count = c_node->inputIndex()->size(); - for (uint32_t j = 0; j < count; ++j) { - node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs(j))); - } - if (c_node->outputIndex() != nullptr) { - count = c_node->outputIndex()->size(); - for (uint32_t j = 0; j < count; ++j) { - node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs(j))); - } - } - model->all_nodes_.push_back(node); - } - return true; -} - -bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) { - auto tensor_count = meta_graph->allTensors()->size(); - for (uint32_t i = 0; i < tensor_count; ++i) { - auto *tensor = meta_graph->allTensors()->GetAs(i); - if (tensor == nullptr) { - MS_LOG(ERROR) << i << "th tensor in model is nullptr"; - return false; - } - model->all_tensors_.push_back(const_cast(tensor)); - } - return true; -} - -int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model) { +int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { MS_ASSERT(model != nullptr); - MS_ASSERT(sub_graph != nullptr); - auto *sub_graph_temp = new (std::nothrow) Model::SubGraph(); - if (sub_graph_temp == nullptr) { + auto *subgraph = new (std::nothrow) Model::SubGraph(); + if (subgraph == nullptr) { MS_LOG(ERROR) << "new subGraph fail!"; return RET_ERROR; } - sub_graph_temp->name_ = sub_graph->name()->c_str(); - auto in_count = sub_graph->inputIndices()->size(); + subgraph->name_ = sub_graph.name()->c_str(); + auto in_count = sub_graph.inputIndices()->size(); for (uint32_t i = 0; i < in_count; ++i) { - sub_graph_temp->input_indices_.push_back(size_t(sub_graph->inputIndices()->GetAs(i))); + subgraph->input_indices_.push_back(size_t(sub_graph.inputIndices()->GetAs(i))); } - auto out_count = sub_graph->outputIndices()->size(); + auto out_count = sub_graph.outputIndices()->size(); for (uint32_t i = 0; i < out_count; ++i) { - sub_graph_temp->output_indices_.push_back(size_t(sub_graph->outputIndices()->GetAs(i))); + subgraph->output_indices_.push_back(size_t(sub_graph.outputIndices()->GetAs(i))); } - auto node_count = sub_graph->nodeIndices()->size(); + auto node_count = sub_graph.nodeIndices()->size(); for (uint32_t i = 0; i < node_count; ++i) { - sub_graph_temp->node_indices_.push_back(size_t(sub_graph->nodeIndices()->GetAs(i))); + subgraph->node_indices_.push_back(size_t(sub_graph.nodeIndices()->GetAs(i))); } - auto tensor_count = sub_graph->nodeIndices()->size(); + auto tensor_count = sub_graph.nodeIndices()->size(); for (uint32_t i = 0; i < tensor_count; ++i) { - sub_graph_temp->tensor_indices_.push_back(size_t(sub_graph->tensorIndices()->GetAs(i))); + subgraph->tensor_indices_.push_back(size_t(sub_graph.tensorIndices()->GetAs(i))); } - model->sub_graphs_.push_back(sub_graph_temp); + model->sub_graphs_.push_back(subgraph); return RET_OK; } -int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model) { - MS_ASSERT(model != nullptr); - MS_ASSERT(meta_graph != nullptr); - auto *sub_graph_temp = new (std::nothrow) Model::SubGraph(); - if (sub_graph_temp == nullptr) { - MS_LOG(ERROR) << "new subGraph fail!"; - return RET_ERROR; - } - if (meta_graph->name() != nullptr) { - sub_graph_temp->name_ = meta_graph->name()->c_str(); - } - auto in_count = meta_graph->inputIndex()->size(); - for (uint32_t i = 0; i < in_count; ++i) { - sub_graph_temp->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs(i))); +int VersionVerify(flatbuffers::Verifier *verify) { + if (schema::VerifyMetaGraphBuffer(*verify)) { + return SCHEMA_VERSION::SCHEMA_CUR; } - auto out_count = meta_graph->outputIndex()->size(); - for (uint32_t i = 0; i < out_count; ++i) { - sub_graph_temp->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs(i))); - } - auto node_count = meta_graph->nodes()->size(); - for (uint32_t i = 0; i < node_count; ++i) { - sub_graph_temp->node_indices_.push_back(i); + return -1; +} + +const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { + MS_ASSERT(buf != nullptr); + if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { + return reinterpret_cast(schema::GetMetaGraph(buf)); } - auto tensor_count = meta_graph->nodes()->size(); - for (uint32_t i = 0; i < tensor_count; ++i) { - sub_graph_temp->tensor_indices_.push_back(i); + return nullptr; +} + +int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) { + MS_ASSERT(model != nullptr); + int status = RET_ERROR; + if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { + status = GenerateModel(*reinterpret_cast(meta_graph), + model, schema_version); } - model->sub_graphs_.push_back(sub_graph_temp); - return RET_OK; + return status; } Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { @@ -137,7 +75,8 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { return nullptr; } flatbuffers::Verifier verify((const uint8_t *)model_buf, size); - if (!schema::VerifyMetaGraphBuffer(verify)) { + int schema_version = VersionVerify(&verify); + if (schema_version == -1) { MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; return nullptr; } @@ -162,54 +101,25 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { } memcpy(model->buf, model_buf, size); } - - auto meta_graph = schema::GetMetaGraph(model->buf); + const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version); if (meta_graph == nullptr) { MS_LOG(ERROR) << "meta_graph is nullptr!"; delete (model); return nullptr; } - if (meta_graph->name() != nullptr) { - model->name_ = meta_graph->name()->c_str(); - } - if (meta_graph->version() != nullptr) { - model->version_ = meta_graph->version()->c_str(); + int status = GenerateModelByVersion(meta_graph, model, schema_version); + if (status != RET_OK) { + delete (model); + MS_LOG(ERROR) << "fail to generate model"; + return nullptr; } if (model->version_ != Version()) { MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal"; } - - if (!ConvertNodes(meta_graph, model)) { - delete model; - return nullptr; - } - - if (!ConvertTensors(meta_graph, model)) { - delete model; - return nullptr; - } - - if (meta_graph->subGraph() == nullptr) { - int ret = MetaGraphMappingSubGraph(meta_graph, model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "converter old version model wrong."; - return nullptr; - } - } else { - auto sub_graphs = meta_graph->subGraph(); - auto sub_graph_size = sub_graphs->size(); - for (size_t i = 0; i < sub_graph_size; i++) { - auto sub_graph = sub_graphs->GetAs(i); - int ret = ConvertSubGraph(sub_graph, model); - if (ret != RET_OK) { - MS_LOG(ERROR) << "converter subgraph wrong."; - return nullptr; - } - } - } if (model->sub_graphs_.empty()) { + delete (model); return nullptr; } return model; diff --git a/mindspore/lite/src/model_common.h b/mindspore/lite/src/model_common.h index 2162328e01..54ad92f6fb 100644 --- a/mindspore/lite/src/model_common.h +++ b/mindspore/lite/src/model_common.h @@ -16,17 +16,149 @@ #ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_ #define MINDSPORE_LITE_SRC_MODEL_COMMON_H_ + +#include #include "src/ops/primitive_c.h" #include "include/model.h" +#include "include/version.h" +#include "schema/model_generated.h" +#include "src/common/common.h" +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif namespace mindspore::lite { -bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model); +int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model); + +template +bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = 0) { + MS_ASSERT(model != nullptr); + for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) { + auto *node = new (std::nothrow) Model::Node(); + if (node == nullptr) { + MS_LOG(ERROR) << "new node fail!"; + return false; + } + auto c_node = meta_graph.nodes()->template GetAs(i); + auto src_prim = reinterpret_cast(c_node->primitive()); +#ifdef PRIMITIVE_WRITEABLE + node->primitive_ = PrimitiveC::Create(const_cast(src_prim)); +#else + auto primitive = const_cast(src_prim); + node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); +#endif + if (node->primitive_ == nullptr) { + MS_LOG(ERROR) << "unpack primitive == nullptr!"; + delete node; + return false; + } + node->primitive_->SetQuantType(static_cast(c_node->quantType())); + node->name_ = c_node->name()->c_str(); + node->node_type_ = static_cast(c_node->nodeType()); + auto count = c_node->inputIndex()->size(); + for (uint32_t j = 0; j < count; ++j) { + node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs(j))); + } + if (c_node->outputIndex() != nullptr) { + count = c_node->outputIndex()->size(); + for (uint32_t j = 0; j < count; ++j) { + node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs(j))); + } + } + model->all_nodes_.push_back(node); + } + return true; +} + +template +bool ConvertTensors(const T &meta_graph, Model *model) { + MS_ASSERT(model != nullptr); + auto tensor_count = meta_graph.allTensors()->size(); + for (uint32_t i = 0; i < tensor_count; ++i) { + auto *tensor = meta_graph.allTensors()->template GetAs(i); + if (tensor == nullptr) { + MS_LOG(ERROR) << i << "th tensor in model is nullptr"; + return false; + } + model->all_tensors_.push_back(const_cast(tensor)); + } + return true; +} + +template +int MetaGraphMappingSubGraph(const T &meta_graph, Model *model) { + MS_ASSERT(model != nullptr); + auto *subgraph = new (std::nothrow) Model::SubGraph(); + if (subgraph == nullptr) { + MS_LOG(ERROR) << "new subGraph fail!"; + return RET_ERROR; + } + if (meta_graph.name() != nullptr) { + subgraph->name_ = meta_graph.name()->c_str(); + } + auto in_count = meta_graph.inputIndex()->size(); + for (uint32_t i = 0; i < in_count; ++i) { + subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs(i))); + } + auto out_count = meta_graph.outputIndex()->size(); + for (uint32_t i = 0; i < out_count; ++i) { + subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs(i))); + } + auto node_count = meta_graph.nodes()->size(); + for (uint32_t i = 0; i < node_count; ++i) { + subgraph->node_indices_.push_back(i); + } + auto tensor_count = meta_graph.nodes()->size(); + for (uint32_t i = 0; i < tensor_count; ++i) { + subgraph->tensor_indices_.push_back(i); + } + model->sub_graphs_.push_back(subgraph); + return RET_OK; +} + +template +int GenerateModel(const T &meta_graph, Model *model, int schema_version = 0) { + MS_ASSERT(model != nullptr); + if (meta_graph.name() != nullptr) { + model->name_ = meta_graph.name()->c_str(); + } + if (meta_graph.version() != nullptr) { + model->version_ = meta_graph.version()->c_str(); + } + if (!ConvertNodes(meta_graph, model, schema_version)) { + MS_LOG(ERROR) << "convert node failed"; + return RET_ERROR; + } + if (!ConvertTensors(meta_graph, model)) { + MS_LOG(ERROR) << "convert tensor failed"; + return RET_ERROR; + } + if (meta_graph.subGraph() == nullptr) { + int ret = MetaGraphMappingSubGraph(meta_graph, model); + if (ret != RET_OK) { + MS_LOG(ERROR) << "converter old version model wrong."; + return ret; + } + } else { + auto sub_graphs = meta_graph.subGraph(); + auto sub_graph_size = sub_graphs->size(); + for (size_t i = 0; i < sub_graph_size; i++) { + auto sub_graph = sub_graphs->template GetAs(i); + int ret = ConvertSubGraph(*sub_graph, model); + if (ret != RET_OK) { + MS_LOG(ERROR) << "converter subgraph wrong."; + return ret; + } + } + } + return RET_OK; +} -bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model); +int VersionVerify(flatbuffers::Verifier *verify); -int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model); +const void *GetMetaGraphByVerison(const char *buf, const int &schema_version); -int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model); +int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version); Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf); } // namespace mindspore::lite diff --git a/mindspore/lite/src/train/train_model.cc b/mindspore/lite/src/train/train_model.cc index eeca608eac..3576f9900b 100644 --- a/mindspore/lite/src/train/train_model.cc +++ b/mindspore/lite/src/train/train_model.cc @@ -22,9 +22,6 @@ namespace mindspore::lite { -bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model); -bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model); - TrainModel *TrainModel::Import(const char *model_buf, size_t size) { if (model_buf == nullptr) { MS_LOG(ERROR) << "The model buf is nullptr"; @@ -62,18 +59,18 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { if (meta_graph->version() != nullptr) { model->version_ = meta_graph->version()->c_str(); } - if (!ConvertNodes(meta_graph, model)) { + if (!ConvertNodes(*meta_graph, model)) { delete model; return nullptr; } - if (!ConvertTensors(meta_graph, model)) { + if (!ConvertTensors(*meta_graph, model)) { delete model; return nullptr; } if (meta_graph->subGraph() == nullptr) { - int ret = MetaGraphMappingSubGraph(meta_graph, model); + int ret = MetaGraphMappingSubGraph(*meta_graph, model); if (ret != RET_OK) { MS_LOG(ERROR) << "converter old version model wrong."; return nullptr; @@ -83,7 +80,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { auto sub_graph_size = sub_graphs->size(); for (size_t i = 0; i < sub_graph_size; i++) { auto sub_graph = sub_graphs->GetAs(i); - int ret = ConvertSubGraph(sub_graph, model); + int ret = ConvertSubGraph(*sub_graph, model); if (ret != RET_OK) { MS_LOG(ERROR) << "converter subgraph wrong."; return nullptr;