free meta_graph after compile graph

pull/5369/head
hangq 5 years ago
parent 9d8fb786cb
commit 194253635d

@@ -61,5 +61,22 @@ std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph) {
}
return ret;
}
std::vector<size_t> GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx) {
std::vector<size_t> post_node_idxes;
for (size_t i = 0; i < graph.nodes()->size(); i++) {
auto node = graph.nodes()->GetAs<schema::CNode>(i);
if (node == nullptr) {
continue;
}
auto node_input_idxes = node->inputIndex();
auto is_contain = std::any_of(node_input_idxes->begin(), node_input_idxes->end(),
[&](const uint32_t &node_input_idx) { return node_input_idx == tensor_idx; });
if (is_contain) {
post_node_idxes.emplace_back(i);
}
}
return post_node_idxes;
}
} // namespace lite
} // namespace mindspore
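
For context, a hedged usage sketch of the new GetLinkedPostNodeIdx helper (hypothetical caller, not part of this commit), which returns the indices of all nodes whose inputIndex contains the given tensor:

#include <cstdio>
#include "src/common/graph_util.h"
// Assumes meta_graph is a valid, loaded schema::MetaGraph.
void DumpConsumers(const mindspore::schema::MetaGraph &meta_graph, size_t tensor_idx) {
  for (size_t node_idx : mindspore::lite::GetLinkedPostNodeIdx(meta_graph, tensor_idx)) {
    auto node = meta_graph.nodes()->GetAs<mindspore::schema::CNode>(node_idx);
    printf("tensor %zu feeds node %s\n", tensor_idx, node->name()->c_str());
  }
}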

@@ -34,215 +34,8 @@ std::vector<size_t> GetGraphInputNodes(const schema::MetaGraph *meta_graph);
std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph);
class OpNode {
public:
explicit OpNode(const NODE_ID &nodeId) : id(nodeId) {}
NODE_ID ID() { return id; };
void AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); }
void AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); }
std::unordered_set<NODE_ID> GetAllInEdges() { return inEdges; }
std::unordered_set<NODE_ID> GetAllOutEdges() { return outEdges; }
protected:
NODE_ID id;
std::unordered_set<NODE_ID> inEdges;
std::unordered_set<NODE_ID> outEdges;
};
template <typename NODE_T>
class OpGraph {
public:
OpGraph() {}
~OpGraph();
int Build(const schema::MetaGraph *subGraphDef);
NODE_T *GetNode(NODE_ID nodeId);
NODE_T *AddNode(NODE_ID nodeId);
std::unordered_set<NODE_T *> GetInputNode();
std::unordered_set<NODE_T *> GetOutputNode();
void AddNodes(std::vector<NODE_T *> addNodes);
void DeleteNodes(std::vector<NODE_T *> deleteNodes);
void AddEdge(NODE_ID nodeId);
int AddEdge(NODE_ID srcId, NODE_ID dstId);
int AddEdge(const schema::CNode *srcNodeDef, const flatbuffers::Vector<flatbuffers::Offset<schema::CNode>> *opDefs);
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> GetDepends();
protected:
std::unordered_map<NODE_ID, NODE_T *> nodes;
};
template <typename NODE_T>
int OpGraph<NODE_T>::Build(const schema::MetaGraph *subGraphDef) {
if (subGraphDef == nullptr) {
// MS_LOGE("subGraphDef is nullptr");
return RET_ERROR;
}
auto opDefs = subGraphDef->nodes();
uint32_t opCount = opDefs->size();
for (uint32_t i = 0; i < opCount; i++) {
auto opDef = opDefs->GetAs<schema::CNode>(i);
auto node = AddNode(std::string(opDef->name()->c_str()));
if (node == nullptr) {
// MS_LOGE("add srcNode failed,name %s", opDef->name()->c_str());
return RET_ERROR;
}
auto ret = AddEdge(opDef, opDefs);
if (ret != RET_OK) {
// MS_LOGE("%s add edge failed. ret:%d", opDef->name()->c_str(), ret);
return RET_ERROR;
}
}
return RET_OK;
}
template <typename NODE_T>
int OpGraph<NODE_T>::AddEdge(const schema::CNode *srcNodeDef,
const flatbuffers::Vector<flatbuffers::Offset<schema::CNode>> *nodeDefs) {
MS_ASSERT(srcNodeDef != nullptr);
MS_ASSERT(nodeDefs != nullptr);
NODE_ID srcId = std::string(srcNodeDef->name()->c_str());
uint32_t opCount = nodeDefs->size();
// register the source node even when it has no out-edges (single-op graph)
AddNode(srcId);
for (auto index : *(srcNodeDef->outputIndex())) {
for (uint32_t i = 0; i < opCount; i++) {
auto dstNodeDef = nodeDefs->GetAs<schema::CNode>(i);
bool find = false;
auto inputIndex = dstNodeDef->inputIndex();
if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) {
find = true;
}
if (!find) {
continue;
}
NODE_ID dstId = std::string(dstNodeDef->name()->c_str());
auto ret = AddEdge(srcId, dstId);
if (ret != RET_OK) {
return ret;
}
}
}
return RET_OK;
}
template <typename NODE_T>
int OpGraph<NODE_T>::AddEdge(NODE_ID srcId, NODE_ID dstId) {
auto srcNode = AddNode(srcId);
if (srcNode == nullptr) {
// MS_LOGE("add srcNode failed");
return RET_ERROR;
}
auto dstNode = AddNode(dstId);
if (dstNode == nullptr) {
// MS_LOGE("add dstNode failed");
return RET_ERROR;
}
srcNode->AddOutEdge(dstNode);
dstNode->AddInEdge(srcNode);
return RET_OK;
}
template <typename NODE_T>
NODE_T *OpGraph<NODE_T>::GetNode(NODE_ID nodeId) {
auto node = nodes.find(nodeId);
if (node == nodes.end()) {
return nullptr;
}
return node->second;
}
template <typename NODE_T>
NODE_T *OpGraph<NODE_T>::AddNode(NODE_ID nodeId) {
auto node = GetNode(nodeId);
if (node != nullptr) {
return node;
}
node = new (std::nothrow) NODE_T(nodeId);
if (node == nullptr) {
// MS_LOGE("new node failed");
return nullptr;
}
nodes[nodeId] = node;
return node;
}
template <typename NODE_T>
void OpGraph<NODE_T>::AddNodes(std::vector<NODE_T *> addNodes) {
for (auto node : addNodes) {
if (node == nullptr) {
return;
}
nodes[node->ID()] = node;
}
}
template <typename NODE_T>
void OpGraph<NODE_T>::DeleteNodes(std::vector<NODE_T *> deleteNodes) {
for (auto deletenode : deleteNodes) {
if (deletenode == nullptr) {
continue;
}
auto node = GetNode(deletenode->ID());
if (node == nullptr) {
continue;
}
nodes.erase(deletenode->ID());
}
}
template <typename NODE_T>
std::unordered_set<NODE_T *> OpGraph<NODE_T>::GetInputNode() {
std::unordered_set<NODE_T *> inputNodes;
for (const auto &iter : nodes) {
auto node = iter.second;
if (node->GetAllInEdges().empty()) {
inputNodes.insert(node);
}
}
return inputNodes;
}
template <typename NODE_T>
std::unordered_set<NODE_T *> OpGraph<NODE_T>::GetOutputNode() {
std::unordered_set<NODE_T *> outputNodes;
for (const auto &iter : nodes) {
auto node = iter.second;
if (node->GetAllOutEdges().empty()) {
outputNodes.insert(node);
}
}
return outputNodes;
}
template <typename NODE_T>
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> OpGraph<NODE_T>::GetDepends() {
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> depends;
for (auto nodeIter : nodes) {
depends[nodeIter.second] = nodeIter.second->GetAllInEdges();
}
return depends;
}
template <typename NODE_T>
OpGraph<NODE_T>::~OpGraph() {
for (auto iter : nodes) {
delete iter.second;
}
nodes.clear();
}
std::vector<size_t> GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_
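
A hedged sketch of driving the OpGraph template above (caller code assumed, not from this commit): build the dependency graph from a MetaGraph and list its entry ops, i.e. nodes with no in-edges.

#include <cstdio>
#include "src/common/graph_util.h"
void PrintEntryOps(const mindspore::schema::MetaGraph *sub_graph) {
  mindspore::lite::OpGraph<mindspore::lite::OpNode> op_graph;
  if (op_graph.Build(sub_graph) != mindspore::lite::RET_OK) {
    return;  // Build rejects a null graph or a node that cannot be allocated.
  }
  for (auto *node : op_graph.GetInputNode()) {
    printf("entry op: %s\n", node->ID().c_str());  // NODE_ID is a std::string here
  }
}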

@@ -32,10 +32,29 @@
namespace mindspore {
namespace lite {
static std::vector<schema::PrimitiveType> packed_op = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_MatMul};
// This method does not check whether tensor_idx refers to a weight tensor; the caller must ensure it does.
static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
MS_ASSERT(nullptr != model);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(nullptr != meta_graph);
auto post_node_idxes = GetLinkedPostNodeIdx(*meta_graph, tensor_idx);
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(post_node_idx);
MS_ASSERT(cNode != nullptr);
return IsContain(packed_op, cNode->primitive()->value_type());
});
}
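
In short, a weight tensor is copied only when none of its consumers is a packed op: packed-op kernels repack the weight into their own buffers during Init (before the meta_graph is freed), while any other consumer would read the weight at run time, so its data must outlive the meta_graph. A standalone analogue of the rule (sketch, not MindSpore API):

#include <algorithm>
#include <vector>
// True when the weight must be copied, i.e. when no consumer repacks it itself.
static bool NeedCopy(const std::vector<bool> &consumer_is_packed_op) {
  return std::none_of(consumer_is_packed_op.begin(), consumer_is_packed_op.end(),
                      [](bool packed) { return packed; });
}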
int LiteSession::ConvertTensors(const lite::Model *model) {
MS_ASSERT(nullptr != model);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(nullptr != meta_graph);
copyed_tensor_idxes_.clear();
uint32_t tensorCount = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensorCount; i++) {
auto *srcTensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
@@ -54,16 +73,30 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
}
}
int dataType = srcTensor->dataType();
auto *dstTensor = new tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType());
auto *dstTensor =
new (std::nothrow) tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType());
if (dstTensor == nullptr) {
MS_LOG(ERROR) << "new " << i << "th tensor failed";
return RET_NULL_PTR;
}
if (srcTensor->nodeType() == schema::NodeType_ValueNode && srcTensor->data() != nullptr &&
srcTensor->data()->size() > 0) {
if (shape.empty()) {
shape.push_back(1);
dstTensor->set_shape(shape);
}
MS_ASSERT(dstTensor != nullptr);
MS_ASSERT(dstTensor->Size() == srcTensor->data()->size());
// do not copy the data here; it is copied when LiteKernel::Init is called
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
if (WeightTensorNeedCopy(model, i)) {
auto ret = dstTensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc data for " << i << "th tensor failed";
return RET_ERROR;
}
memcpy(dstTensor->Data(), srcTensor->data()->data(), dstTensor->Size());
copyed_tensor_idxes_.emplace_back(i);
} else {
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
}
}
auto quant_params = srcTensor->quantParams();
if (quant_params != nullptr) {
@@ -74,7 +107,6 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
dstTensor->AddQuantParam(quant_arg);
}
}
this->tensors_.emplace_back(dstTensor);
}
@@ -240,6 +272,7 @@ int LiteSession::CompileGraph(Model *model) {
}
executor->Prepare(this->kernels_);
model->FreeMetaGraph();
return RET_OK;
}
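
The model->FreeMetaGraph() call above is the heart of this commit: once the kernels are prepared, the flatbuffer backing the meta_graph is released. A hedged sketch of the resulting lifecycle (API spelling assumed, error handling omitted):

// auto model = mindspore::lite::Model::Import(model_buf, buf_size);
// auto *session = new mindspore::lite::LiteSession();
// session->Init(&context);
// session->CompileGraph(model);  // prepares kernels, then frees the meta_graph
// From here on, weight tensors either own copied data (tracked in copyed_tensor_idxes_)
// or were already repacked by packed-op kernels; none may still alias the freed buffer.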
@@ -277,7 +310,10 @@ int LiteSession::Init(Context *context) {
}
#endif
executor = new Executor();
MS_ASSERT(nullptr != executor);
if (nullptr == executor) {
MS_LOG(ERROR) << "new Executor failed";
return RET_ERROR;
}
return RET_OK;
}
@@ -288,9 +324,12 @@ void LiteSession::BindThread(bool if_bind) {
}
LiteSession::~LiteSession() {
for (auto *tensor : tensors_) {
// weight data must not be freed here; it is freed when the meta_graph is freed
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor)) {
for (size_t i = 0; i < tensors_.size(); i++) {
auto *tensor = tensors_.at(i);
MS_ASSERT(tensor != nullptr);
// data of weight tensors feeding nodes in packed_op must not be freed here; it is freed when the meta_graph is freed
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor) &&
!IsContain(copyed_tensor_idxes_, i)) {
tensor->SetData(nullptr);
}
delete tensor;
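
The rewritten destructor encodes the resulting ownership rule; a comment-level sketch of the cases (illustration only):

// weight copied in ConvertTensors (index in copyed_tensor_idxes_): the tensor owns its
//   data, so plain delete frees it;
// any other ValueNode weight (excluding graph inputs): its data pointer aliased the
//   meta_graph buffer, so it is cleared with SetData(nullptr) before delete to avoid
//   freeing memory the tensor never owned.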

@@ -87,6 +87,7 @@ class LiteSession : public session::LiteSession {
Context *context_ = nullptr;
std::vector<kernel::LiteKernel *> kernels_;
std::vector<tensor::Tensor *> tensors_;
std::vector<size_t> copyed_tensor_idxes_;
// graph input tensors
std::vector<tensor::Tensor *> inputs_;
// graph output tensors

@@ -135,7 +135,7 @@ mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
void Model::FreeMetaGraph() {
MS_ASSERT(nullptr != model_impl_);
return model_impl_->FreeMetaGraph();
model_impl_->FreeMetaGraph();
}
const schema::MetaGraph *Model::GetMetaGraph() const {

@@ -0,0 +1,32 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/abs.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateAbs(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Abs, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
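
UnPackToFlatBuilder replaces the per-op Init() methods removed below: instead of each op new-ing a backup buffer and handing raw ownership to the caller, the op now serializes itself into a caller-owned FlatBufferBuilder. A hedged caller-side sketch (construction details assumed):

// flatbuffers::FlatBufferBuilder fbb(1024);
// if (abs_op.UnPackToFlatBuilder(src_primitive, &fbb) != RET_OK) { /* handle error */ }
// auto *prim = flatbuffers::GetRoot<mindspore::schema::Primitive>(fbb.GetBufferPointer());
// prim stays valid only while fbb, which owns the buffer, is alive; no manual new[]/delete[].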

@@ -32,27 +32,9 @@ class Abs : public ArithmeticSelf {
Abs() = default;
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateAbs(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Abs, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Abs() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite

@@ -55,7 +55,19 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
return RET_OK;
}
#else
int Activation::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Activation();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Activation return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateActivation(*fbb, attr->type(), attr->alpha());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Activation, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
#endif

@@ -30,34 +30,13 @@ class Activation : public PrimitiveC {
MS_DECLARE_PARENT(Activation, PrimitiveC);
Activation() = default;
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetType(int type);
void SetAlpha(float alpha);
#else
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Activation();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateActivation(fbb, attr->type(), attr->alpha());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Activation, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Activation() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetType() const;
float GetAlpha() const;

@@ -26,7 +26,19 @@ void ActivationGrad::SetType(int type) {
}
#else
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ActivationGrad();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ActivationGrad return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateActivationGrad(*fbb, attr->type());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
#endif

@@ -33,30 +33,9 @@ class ActivationGrad : public PrimitiveC {
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
#else
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ActivationGrad();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateActivationGrad(fbb, attr->type());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ActivationGrad() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetType() const;
};

@@ -50,7 +50,19 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
}
#else
int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Add();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Add return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAdd(*fbb, attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Add, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
#endif

@@ -31,33 +31,12 @@ class Add : public Arithmetic {
MS_DECLARE_PARENT(Add, Arithmetic);
Add() = default;
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetActivationType(int activation_type);
#else
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Add();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateAdd(fbb, attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Add, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Add() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetActivationType() const;
};

@@ -24,7 +24,19 @@ int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; }
void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
#else
int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_AddN();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_AddN return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAddN(*fbb, attr->N());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AddN, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
#endif

@@ -33,30 +33,9 @@ class AddN : public PrimitiveC {
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetN(int n);
#else
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_AddN();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateAddN(fbb, attr->N());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_AddN, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
AddN() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;

@@ -32,7 +32,20 @@ void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->k
void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }
#else
int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ArgMax();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ArgMax return nullptr";
return RET_ERROR;
}
auto val_offset =
schema::CreateArgMax(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMax, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); }

@@ -37,31 +37,9 @@ class ArgMax : public PrimitiveC {
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ArgMax();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateArgMax(fbb, attr->axis(), attr->outMaxValue(),
attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMax, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ArgMax() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;

@@ -32,7 +32,20 @@ void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->k
void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; }
#else
int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ArgMin();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ArgMin return nullptr";
return RET_ERROR;
}
auto val_offset =
schema::CreateArgMin(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMin, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); }

@@ -37,31 +37,9 @@ class ArgMin : public PrimitiveC {
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ArgMin();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateArgMin(fbb, attr->axis(), attr->outMaxValue(),
attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMin, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ArgMin() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;

@@ -32,7 +32,11 @@ class Arithmetic : public PrimitiveC {
Arithmetic() = default;
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
// explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
Arithmetic() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;
}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }

@@ -29,7 +29,11 @@ class ArithmeticSelf : public PrimitiveC {
ArithmeticSelf() = default;
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
// explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
ArithmeticSelf() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;
}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};

@@ -49,7 +49,14 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
}
#else
int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateBatchNorm(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchNorm, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
#endif

@@ -31,30 +31,12 @@ class BatchNorm : public PrimitiveC {
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
BatchNorm() = default;
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetEpsilon(float epsilon);
#else
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateBatchNorm(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchNorm, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BatchNorm() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEpsilon() const;
};

@@ -32,7 +32,31 @@ void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive_->value.AsBatchToSpace()->crops = crops; }
#else
int BatchToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BatchToSpace();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BatchToSpace return nullptr";
return RET_ERROR;
}
std::vector<int32_t> blockShape;
if (attr->blockShape() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->blockShape()->size()); i++) {
blockShape.push_back(attr->blockShape()->data()[i]);
}
}
std::vector<int32_t> crops;
if (attr->crops() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->crops()->size()); i++) {
crops.push_back(attr->crops()->data()[i]);
}
}
auto val_offset = schema::CreateBatchToSpaceDirect(*fbb, &blockShape, &crops);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchToSpace, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> BatchToSpace::GetBlockShape() const {
auto fb_vector = this->primitive_->value_as_BatchToSpace()->blockShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());

@@ -35,39 +35,9 @@ class BatchToSpace : public PrimitiveC {
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
#else
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_BatchToSpace();
MS_ASSERT(attr != nullptr);
auto blockShape = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->blockShape()->size()); i++) {
blockShape->push_back(attr->blockShape()->data()[i]);
}
auto crops = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->crops()->size()); i++) {
crops->push_back(attr->crops()->data()[i]);
}
auto val_offset = schema::CreateBatchToSpaceDirect(fbb, blockShape.release(), crops.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchToSpace, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BatchToSpace() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;

Some files were not shown because too many files have changed in this diff.
