|
|
|
@ -14,33 +14,39 @@
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "src/model_common.h"
|
|
|
|
|
#include "src/ops/while.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
|
|
|
|
|
MS_ASSERT(model != nullptr);
|
|
|
|
|
if (model == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "model is null.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
|
|
|
|
|
sub_graph.nodeIndices() == nullptr || sub_graph.tensorIndices() == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "sub_graph is invalid.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *subgraph = new (std::nothrow) Model::SubGraph();
|
|
|
|
|
if (subgraph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new subGraph fail!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(sub_graph.name() != nullptr);
|
|
|
|
|
|
|
|
|
|
subgraph->name_ = sub_graph.name()->c_str();
|
|
|
|
|
MS_ASSERT(sub_graph.inputIndices() != nullptr);
|
|
|
|
|
auto in_count = sub_graph.inputIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < in_count; ++i) {
|
|
|
|
|
subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i));
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(sub_graph.outputIndices() != nullptr);
|
|
|
|
|
auto out_count = sub_graph.outputIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < out_count; ++i) {
|
|
|
|
|
subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i));
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(sub_graph.nodeIndices() != nullptr);
|
|
|
|
|
auto node_count = sub_graph.nodeIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < node_count; ++i) {
|
|
|
|
|
subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i));
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(sub_graph.tensorIndices() != nullptr);
|
|
|
|
|
auto tensor_count = sub_graph.tensorIndices()->size();
|
|
|
|
|
for (uint32_t i = 0; i < tensor_count; ++i) {
|
|
|
|
|
subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
|
|
|
|
@ -50,6 +56,10 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int VersionVerify(flatbuffers::Verifier *verify) {
|
|
|
|
|
if (verify == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "verify is null.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (schema::VerifyMetaGraphBuffer(*verify)) {
|
|
|
|
|
return SCHEMA_VERSION::SCHEMA_CUR;
|
|
|
|
|
} else if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
|
|
|
|
@ -58,8 +68,90 @@ int VersionVerify(flatbuffers::Verifier *verify) {
|
|
|
|
|
return SCHEMA_VERSION::SCHEMA_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int NodeVerify(const Model &model) {
|
|
|
|
|
auto tensor_size = model.all_tensors_.size();
|
|
|
|
|
uint32_t subGraph_size = model.sub_graphs_.size();
|
|
|
|
|
|
|
|
|
|
for (auto &node : model.all_nodes_) {
|
|
|
|
|
if (node == nullptr || node->primitive_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node or its primitive_ is null.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(),
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
|
|
|
|
|
MS_LOG(ERROR) << "Index of node->input_indices_ is beyond size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(),
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
|
|
|
|
|
MS_LOG(ERROR) << "Index of node->output_indices_ is beyond size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto prim = node->primitive_;
|
|
|
|
|
if (prim->Type() == schema::PrimitiveType_While) {
|
|
|
|
|
auto whileOp = reinterpret_cast<mindspore::lite::While *>(const_cast<mindspore::lite::PrimitiveC *>(prim));
|
|
|
|
|
if (whileOp == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "whileOp is null.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (static_cast<uint32_t>(whileOp->GetBodySubgraphIndex()) >= subGraph_size ||
|
|
|
|
|
static_cast<uint32_t>(whileOp->GetCondSubgraphIndex()) >= subGraph_size) {
|
|
|
|
|
MS_LOG(ERROR) << "index of subGraph is beyond subGraph_size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int SubGraphVerify(const Model &model) {
|
|
|
|
|
auto tensor_size = model.all_tensors_.size();
|
|
|
|
|
auto node_size = model.all_nodes_.size();
|
|
|
|
|
|
|
|
|
|
for (auto &graph : model.sub_graphs_) {
|
|
|
|
|
if (graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "graph is null.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(),
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(),
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(),
|
|
|
|
|
[&node_size](const uint32_t &idx) { return idx >= node_size; })) {
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ModelVerify(const Model &model, const int &schema_version) {
|
|
|
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
|
|
|
|
return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK;
|
|
|
|
|
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
|
|
|
|
|
return NodeVerify(model) == RET_OK;
|
|
|
|
|
}
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
|
|
|
|
|
MS_ASSERT(buf != nullptr);
|
|
|
|
|
if (buf == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "buf is null.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
|
|
|
|
return reinterpret_cast<const void *>(schema::GetMetaGraph(buf));
|
|
|
|
|
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
|
|
|
|
@ -69,8 +161,10 @@ const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) {
|
|
|
|
|
MS_ASSERT(meta_graph != nullptr);
|
|
|
|
|
MS_ASSERT(model != nullptr);
|
|
|
|
|
if (meta_graph == nullptr || model == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "meta_graph or model is null.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
int status = RET_ERROR;
|
|
|
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
|
|
|
|
|
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph),
|
|
|
|
@ -135,6 +229,7 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
|
|
|
|
delete (model);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return model;
|
|
|
|
|
|
|
|
|
|
return ModelVerify(*model, schema_version) ? model : nullptr;
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore::lite
|
|
|
|
|