commit
3e9d95dca1
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
||||
#define MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
||||
|
||||
#include <string>
|
||||
#include "src/lite_model.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class VersionManager {
|
||||
public:
|
||||
static VersionManager *GetInstance() {
|
||||
static VersionManager instance;
|
||||
return &instance;
|
||||
}
|
||||
virtual ~VersionManager() = default;
|
||||
|
||||
void SetSchemaVersion(const int schema_version) { schema_version_ = schema_version; }
|
||||
int GetSchemaVersion() const { return schema_version_; }
|
||||
|
||||
private:
|
||||
VersionManager() = default;
|
||||
|
||||
private:
|
||||
int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,223 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_LITE_MODEL_H_
|
||||
#define MINDSPORE_LITE_SRC_LITE_MODEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/model.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "include/version.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
#ifdef ENABLE_V0
|
||||
#include "schema/model_v0_generated.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class LiteModel : public Model {
|
||||
public:
|
||||
int ConstructModel();
|
||||
|
||||
bool ModelVerify() const;
|
||||
|
||||
void Free() override;
|
||||
|
||||
~LiteModel() override;
|
||||
|
||||
private:
|
||||
#ifdef ENABLE_V0
|
||||
int ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim, std::vector<schema::Tensor *> *dst_tensor);
|
||||
|
||||
int ConvertAttrToTensors(const void *meta_graph);
|
||||
#endif
|
||||
|
||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||
bool ConvertNodes(const T &meta_graph) {
|
||||
if (meta_graph.nodes() == nullptr) {
|
||||
MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
|
||||
auto *node = new (std::nothrow) Model::Node();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "new node fail!";
|
||||
return false;
|
||||
}
|
||||
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
|
||||
auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
||||
#else
|
||||
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
||||
auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
|
||||
if (func_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
|
||||
<< schema::EnumNamePrimitiveType(primitive->value_type());
|
||||
delete node;
|
||||
return false;
|
||||
}
|
||||
node->primitive_ = func_pointer(primitive);
|
||||
#endif
|
||||
if (node->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
||||
delete node;
|
||||
return false;
|
||||
}
|
||||
node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
|
||||
node->name_ = c_node->name()->c_str();
|
||||
node->node_type_ = static_cast<NodeType>(c_node->nodeType());
|
||||
auto count = c_node->inputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
|
||||
}
|
||||
if (c_node->outputIndex() != nullptr) {
|
||||
count = c_node->outputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
|
||||
}
|
||||
}
|
||||
this->all_nodes_.push_back(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph>
|
||||
bool ConvertTensors(const T &meta_graph) {
|
||||
if (meta_graph.allTensors() == nullptr) {
|
||||
MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
|
||||
return false;
|
||||
}
|
||||
auto tensor_count = meta_graph.allTensors()->size();
|
||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||
auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr";
|
||||
return false;
|
||||
}
|
||||
this->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph>
|
||||
int MetaGraphMappingSubGraph(const T &meta_graph) {
|
||||
if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr ||
|
||||
meta_graph.allTensors() == nullptr) {
|
||||
MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto *subgraph = new (std::nothrow) Model::SubGraph();
|
||||
if (subgraph == nullptr) {
|
||||
MS_LOG(ERROR) << "new subGraph fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (meta_graph.name() != nullptr) {
|
||||
subgraph->name_ = meta_graph.name()->c_str();
|
||||
}
|
||||
auto in_count = meta_graph.inputIndex()->size();
|
||||
for (uint32_t i = 0; i < in_count; ++i) {
|
||||
subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
|
||||
}
|
||||
auto out_count = meta_graph.outputIndex()->size();
|
||||
for (uint32_t i = 0; i < out_count; ++i) {
|
||||
subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
|
||||
}
|
||||
auto node_count = meta_graph.nodes()->size();
|
||||
for (uint32_t i = 0; i < node_count; ++i) {
|
||||
subgraph->node_indices_.push_back(i);
|
||||
}
|
||||
auto tensor_count = meta_graph.allTensors()->size();
|
||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||
subgraph->tensor_indices_.push_back(i);
|
||||
}
|
||||
this->sub_graphs_.push_back(subgraph);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||
int GenerateModel(const T &meta_graph) {
|
||||
if (meta_graph.name() != nullptr) {
|
||||
this->name_ = meta_graph.name()->c_str();
|
||||
}
|
||||
if (meta_graph.version() != nullptr) {
|
||||
this->version_ = meta_graph.version()->c_str();
|
||||
}
|
||||
if (!ConvertNodes<T, U>(meta_graph)) {
|
||||
MS_LOG(ERROR) << "convert node failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!ConvertTensors<T>(meta_graph)) {
|
||||
MS_LOG(ERROR) << "convert tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (meta_graph.subGraph() == nullptr) {
|
||||
int ret = MetaGraphMappingSubGraph<T>(meta_graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "converter old version model wrong.";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
auto sub_graphs = meta_graph.subGraph();
|
||||
auto sub_graph_size = sub_graphs->size();
|
||||
for (size_t i = 0; i < sub_graph_size; i++) {
|
||||
auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
|
||||
int ret = ConvertSubGraph(*sub_graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "converter subgraph wrong.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_V0
|
||||
if (ConvertAttrToTensors(&meta_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "fail to convert attr to tensor.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
#endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int VersionVerify(flatbuffers::Verifier *verify) const;
|
||||
|
||||
const void *GetMetaGraphByVerison();
|
||||
|
||||
int GenerateModelByVersion(const void *meta_graph);
|
||||
|
||||
int ConvertSubGraph(const schema::SubGraph &sub_graph);
|
||||
|
||||
int NodeVerify() const;
|
||||
|
||||
int SubGraphVerify() const;
|
||||
|
||||
public:
|
||||
size_t buf_size_ = 0;
|
||||
|
||||
protected:
|
||||
std::vector<char *> attr_tensor_bufs_;
|
||||
};
|
||||
|
||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_LITE_MODEL_H_
|
@ -1,52 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "include/model.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/model_common.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
|
||||
|
||||
void Model::Free() {
|
||||
if (this->buf != nullptr) {
|
||||
free(this->buf);
|
||||
this->buf = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void Model::Destroy() {
|
||||
Free();
|
||||
auto nodes_size = this->all_nodes_.size();
|
||||
for (size_t i = 0; i < nodes_size; ++i) {
|
||||
auto node = this->all_nodes_[i];
|
||||
MS_ASSERT(node != nullptr);
|
||||
MS_ASSERT(node->primitive_ != nullptr);
|
||||
delete node->primitive_;
|
||||
node->primitive_ = nullptr;
|
||||
delete node;
|
||||
}
|
||||
this->all_nodes_.clear();
|
||||
|
||||
auto sub_graph_size = this->sub_graphs_.size();
|
||||
for (size_t i = 0; i < sub_graph_size; ++i) {
|
||||
auto sub_graph = this->sub_graphs_[i];
|
||||
delete sub_graph;
|
||||
}
|
||||
}
|
||||
|
||||
Model::~Model() { Destroy(); }
|
||||
} // namespace mindspore::lite
|
@ -1,192 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
||||
#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
||||
|
||||
#include <string>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "include/model.h"
|
||||
#include "include/version.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "schema/model_v0_generated.h"
|
||||
#include "src/common/common.h"
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model);
|
||||
|
||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||
bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA_CUR) {
|
||||
if (model == nullptr || meta_graph.nodes() == nullptr) {
|
||||
MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
|
||||
auto *node = new (std::nothrow) Model::Node();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "new node fail!";
|
||||
return false;
|
||||
}
|
||||
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
|
||||
auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
|
||||
#else
|
||||
auto primitive = const_cast<schema::Primitive *>(src_prim);
|
||||
auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
|
||||
if (func_pointer == nullptr) {
|
||||
MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
|
||||
<< schema::EnumNamePrimitiveType(primitive->value_type());
|
||||
delete node;
|
||||
return false;
|
||||
}
|
||||
node->primitive_ = func_pointer(primitive);
|
||||
#endif
|
||||
if (node->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "unpack primitive == nullptr!";
|
||||
delete node;
|
||||
return false;
|
||||
}
|
||||
node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
|
||||
node->name_ = c_node->name()->c_str();
|
||||
node->node_type_ = static_cast<NodeType>(c_node->nodeType());
|
||||
auto count = c_node->inputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
|
||||
}
|
||||
if (c_node->outputIndex() != nullptr) {
|
||||
count = c_node->outputIndex()->size();
|
||||
for (uint32_t j = 0; j < count; ++j) {
|
||||
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
|
||||
}
|
||||
}
|
||||
model->all_nodes_.push_back(node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph>
|
||||
bool ConvertTensors(const T &meta_graph, Model *model) {
|
||||
if (model == nullptr || meta_graph.allTensors() == nullptr) {
|
||||
MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
|
||||
return false;
|
||||
}
|
||||
auto tensor_count = meta_graph.allTensors()->size();
|
||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||
auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
|
||||
if (tensor == nullptr) {
|
||||
MS_LOG(ERROR) << i << "th tensor in model is nullptr";
|
||||
return false;
|
||||
}
|
||||
model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph>
|
||||
int MetaGraphMappingSubGraph(const T &meta_graph, Model *model) {
|
||||
if (model == nullptr || meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
|
||||
meta_graph.nodes() == nullptr || meta_graph.allTensors() == nullptr) {
|
||||
MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto *subgraph = new (std::nothrow) Model::SubGraph();
|
||||
if (subgraph == nullptr) {
|
||||
MS_LOG(ERROR) << "new subGraph fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (meta_graph.name() != nullptr) {
|
||||
subgraph->name_ = meta_graph.name()->c_str();
|
||||
}
|
||||
auto in_count = meta_graph.inputIndex()->size();
|
||||
for (uint32_t i = 0; i < in_count; ++i) {
|
||||
subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
|
||||
}
|
||||
auto out_count = meta_graph.outputIndex()->size();
|
||||
for (uint32_t i = 0; i < out_count; ++i) {
|
||||
subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
|
||||
}
|
||||
auto node_count = meta_graph.nodes()->size();
|
||||
for (uint32_t i = 0; i < node_count; ++i) {
|
||||
subgraph->node_indices_.push_back(i);
|
||||
}
|
||||
auto tensor_count = meta_graph.allTensors()->size();
|
||||
for (uint32_t i = 0; i < tensor_count; ++i) {
|
||||
subgraph->tensor_indices_.push_back(i);
|
||||
}
|
||||
model->sub_graphs_.push_back(subgraph);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||
int GenerateModel(const T &meta_graph, Model *model, int schema_version = 0) {
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "model is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (meta_graph.name() != nullptr) {
|
||||
model->name_ = meta_graph.name()->c_str();
|
||||
}
|
||||
if (meta_graph.version() != nullptr) {
|
||||
model->version_ = meta_graph.version()->c_str();
|
||||
}
|
||||
if (!ConvertNodes<T, U>(meta_graph, model, schema_version)) {
|
||||
MS_LOG(ERROR) << "convert node failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!ConvertTensors<T>(meta_graph, model)) {
|
||||
MS_LOG(ERROR) << "convert tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (meta_graph.subGraph() == nullptr) {
|
||||
int ret = MetaGraphMappingSubGraph<T>(meta_graph, model);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "converter old version model wrong.";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
auto sub_graphs = meta_graph.subGraph();
|
||||
auto sub_graph_size = sub_graphs->size();
|
||||
for (size_t i = 0; i < sub_graph_size; i++) {
|
||||
auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
|
||||
int ret = ConvertSubGraph(*sub_graph, model);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "converter subgraph wrong.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int VersionVerify(flatbuffers::Verifier *verify);
|
||||
|
||||
int NodeVerify(const Model &model);
|
||||
|
||||
int SubGraphVerify(const Model &model);
|
||||
|
||||
bool ModelVerify(const Model &model);
|
||||
|
||||
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);
|
||||
|
||||
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version);
|
||||
|
||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_
|
@ -0,0 +1,65 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/compat/attr_transfer_common.h"
|
||||
#include <vector>
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id,
|
||||
std::vector<char *> *tensor_bufs) {
|
||||
if (data == nullptr || tensor_bufs == nullptr) {
|
||||
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto dst_tensor =
|
||||
(is_array ? new (std::nothrow) Tensor(type_id, {data_size}, schema::Format_NHWC, Tensor::Category::CONST_TENSOR)
|
||||
: new (std::nothrow) Tensor(type_id, {}, schema::Format_NHWC, Tensor::Category::CONST_SCALAR));
|
||||
auto dst_data = dst_tensor->MutableData();
|
||||
if (dst_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data from tensor is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<uint8_t> uint8_data;
|
||||
uint8_data.resize(dst_tensor->Size());
|
||||
memcpy(uint8_data.data(), data, dst_tensor->Size());
|
||||
auto shape = dst_tensor->shape();
|
||||
flatbuffers::FlatBufferBuilder fbb(1024);
|
||||
auto tensor_offset = schema::CreateTensorDirect(fbb, schema::NodeType_ValueNode, type_id, &shape, schema::Format_NHWC,
|
||||
0, 0, &uint8_data);
|
||||
fbb.Finish(tensor_offset);
|
||||
delete dst_tensor;
|
||||
auto buf = fbb.GetBufferPointer();
|
||||
if (buf == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer return nullptr";
|
||||
fbb.Clear();
|
||||
return nullptr;
|
||||
}
|
||||
auto tensor_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
|
||||
if (tensor_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc primitive_buf_ failed";
|
||||
fbb.Clear();
|
||||
return nullptr;
|
||||
}
|
||||
memcpy(tensor_buf, buf, fbb.GetSize());
|
||||
auto tensor = flatbuffers::GetRoot<schema::Tensor>(tensor_buf);
|
||||
tensor_bufs->push_back(tensor_buf);
|
||||
fbb.Clear();
|
||||
return const_cast<schema::Tensor *>(tensor);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
|
||||
|
||||
#include <vector>
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "src/tensor.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/model_v0_generated.h"
|
||||
#include "src/common/common.h"
|
||||
#include "src/ops/compat/compat_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id,
|
||||
std::vector<char *> *tensor_bufs);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
|
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/model.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// compatibility, transfer attr to input tensor.
|
||||
typedef int (*TransferAttrFunc)(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *tensor,
|
||||
std::vector<char *> *tensor_bufs);
|
||||
class CompatRegistry {
|
||||
public:
|
||||
static CompatRegistry *GetInstance() {
|
||||
static CompatRegistry registry;
|
||||
return ®istry;
|
||||
}
|
||||
|
||||
void InsertTransferAttrFuncMap(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) {
|
||||
int key = primitive_type * 10 + schema_version;
|
||||
transfer_attr_funcs_[key] = transfer_attr_func;
|
||||
}
|
||||
|
||||
TransferAttrFunc GetTransferAttrFunc(int schema_version, int primitive_type) {
|
||||
int key = primitive_type * 10 + schema_version;
|
||||
if (transfer_attr_funcs_.find(key) != transfer_attr_funcs_.end()) {
|
||||
return transfer_attr_funcs_[key];
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Unsupported transformer type in Create : " << key;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unordered_map<int, TransferAttrFunc> transfer_attr_funcs_;
|
||||
};
|
||||
|
||||
class Register {
|
||||
public:
|
||||
Register(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) {
|
||||
CompatRegistry::GetInstance()->InsertTransferAttrFuncMap(schema_version, primitive_type, transfer_attr_func);
|
||||
}
|
||||
virtual ~Register() = default;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
|
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/compat/attr_transfer_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int TransferBroadcastToAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
|
||||
std::vector<char *> *tensor_bufs) {
|
||||
if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
|
||||
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (node->input_indices_.size() != 1) {
|
||||
MS_LOG(DEBUG) << "broadcast_to don't need to convert attr to tensor.";
|
||||
return RET_OK;
|
||||
}
|
||||
dst_tensors->clear();
|
||||
tensor_bufs->clear();
|
||||
auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
|
||||
auto dst_shape_attr = prim->value_as_BroadcastTo()->dst_shape();
|
||||
std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end());
|
||||
auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||
if (dst_shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
dst_tensors->push_back(dst_shape_tensor);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
Register BroadcastToTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_BroadcastTo,
|
||||
TransferBroadcastToAttr);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/compat/attr_transfer_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int TransferReshapeAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
|
||||
std::vector<char *> *tensor_bufs) {
|
||||
if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
|
||||
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (node->input_indices_.size() != 1) {
|
||||
MS_LOG(DEBUG) << "reshape need to convert attr to tensor.";
|
||||
return RET_OK;
|
||||
}
|
||||
dst_tensors->clear();
|
||||
tensor_bufs->clear();
|
||||
auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
|
||||
auto dst_shape_attr = prim->value_as_Reshape()->shape();
|
||||
std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end());
|
||||
auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||
if (dst_shape_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
dst_tensors->push_back(dst_shape_tensor);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
Register ReshapeTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_Reshape, TransferReshapeAttr);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,67 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/compat/attr_transfer_common.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
int TransferStridedSliceAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
|
||||
std::vector<char *> *tensor_bufs) {
|
||||
if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
|
||||
MS_LOG(ERROR) << "the parameter of this function is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
dst_tensors->clear();
|
||||
tensor_bufs->clear();
|
||||
auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
|
||||
int inputs_size = node->input_indices_.size();
|
||||
switch (inputs_size) {
|
||||
case 1: {
|
||||
auto begins_attr = prim->value_as_StridedSlice()->begin();
|
||||
std::vector<int> dst_begins = std::vector<int>(begins_attr->begin(), begins_attr->end());
|
||||
auto dst_begins_tensor = AttrToTensor(dst_begins.data(), dst_begins.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||
dst_tensors->push_back(dst_begins_tensor);
|
||||
}
|
||||
case 2: {
|
||||
auto ends_attr = prim->value_as_StridedSlice()->end();
|
||||
std::vector<int> dst_ends = std::vector<int>(ends_attr->begin(), ends_attr->end());
|
||||
auto dst_ends_tensor = AttrToTensor(dst_ends.data(), dst_ends.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||
dst_tensors->push_back(dst_ends_tensor);
|
||||
}
|
||||
case 3: {
|
||||
auto strides_attr = prim->value_as_StridedSlice()->stride();
|
||||
std::vector<int> dst_strides = std::vector<int>(strides_attr->begin(), strides_attr->end());
|
||||
auto dst_strides_tensor =
|
||||
AttrToTensor(dst_strides.data(), dst_strides.size(), true, kNumberTypeInt32, tensor_bufs);
|
||||
dst_tensors->push_back(dst_strides_tensor);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(DEBUG) << "stride_slice don't need to convert attr to tensor.";
|
||||
return RET_OK;
|
||||
}
|
||||
}
|
||||
if (std::any_of(dst_tensors->begin(), dst_tensors->end(), [](schema::Tensor *tensor) { return tensor == nullptr; })) {
|
||||
MS_LOG(ERROR) << "convert attr to tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
Register StridedSliceTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_StridedSlice,
|
||||
TransferStridedSliceAttr);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
Loading…
Reference in new issue