diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc index e463948889..835f9b9e22 100644 --- a/mindspore/lite/src/common/graph_util.cc +++ b/mindspore/lite/src/common/graph_util.cc @@ -61,5 +61,22 @@ std::vector GetGraphOutputNodes(const schema::MetaGraph *meta_graph) { } return ret; } + +std::vector GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx) { + std::vector post_node_idxes; + for (size_t i = 0; i < graph.nodes()->size(); i++) { + auto node = graph.nodes()->GetAs(i); + if (node == nullptr) { + continue; + } + auto node_input_idxes = node->inputIndex(); + auto is_contain = std::any_of(node_input_idxes->begin(), node_input_idxes->end(), + [&](const uint32_t &node_input_idx) { return node_input_idx == tensor_idx; }); + if (is_contain) { + post_node_idxes.emplace_back(i); + } + } + return post_node_idxes; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/common/graph_util.h b/mindspore/lite/src/common/graph_util.h old mode 100755 new mode 100644 index 7b1abf36b7..5f494aa202 --- a/mindspore/lite/src/common/graph_util.h +++ b/mindspore/lite/src/common/graph_util.h @@ -34,215 +34,8 @@ std::vector GetGraphInputNodes(const schema::MetaGraph *meta_graph); std::vector GetGraphOutputNodes(const schema::MetaGraph *meta_graph); -class OpNode { - public: - explicit OpNode(const NODE_ID &nodeId) : id(nodeId) {} - NODE_ID ID() { return id; }; - void AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); } - void AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); } - std::unordered_set GetAllInEdges() { return inEdges; } - std::unordered_set GetAllOutEdges() { return outEdges; } - - protected: - NODE_ID id; - std::unordered_set inEdges; - std::unordered_set outEdges; -}; - - -template -class OpGraph { - public: - OpGraph() {} - - ~OpGraph(); - - int Build(const schema::MetaGraph *subGraphDef); - NODE_T *GetNode(NODE_ID nodeId); - NODE_T *AddNode(NODE_ID nodeId); - std::unordered_set GetInputNode(); - std::unordered_set GetOutputNode(); - - void AddNodes(std::vector addNodes); - void DeleteNodes(std::vector deleteNodes); - - void AddEdge(NODE_ID nodeId); - int AddEdge(NODE_ID srcId, NODE_ID dstId); - int AddEdge(const schema::CNode *srcNodeDef, const flatbuffers::Vector> *opDefs); - std::unordered_map> GetDepends(); - - protected: - std::unordered_map nodes; -}; - -template -int OpGraph::Build(const schema::MetaGraph *subGraphDef) { - if (subGraphDef == nullptr) { - // MS_LOGE("subGraphDef is nullptr"); - return RET_ERROR; - } - - auto opDefs = subGraphDef->nodes(); - - uint32_t opCount = opDefs->size(); - for (uint32_t i = 0; i < opCount; i++) { - auto opDef = opDefs->GetAs(i); - auto node = AddNode(std::string(opDef->name()->c_str())); - if (node == nullptr) { - // MS_LOGE("add srcNode failed,name %s", opDef->name()->c_str()); - return RET_ERROR; - } - auto ret = AddEdge(opDef, opDefs); - if (ret != RET_OK) { - // MS_LOGE("%s add edge failed. ret:%d", opDef->name()->c_str(), ret); - return RET_ERROR; - } - } - - return RET_OK; -} -template -int OpGraph::AddEdge(const schema::CNode *srcNodeDef, - const flatbuffers::Vector> *nodeDefs) { - MS_ASSERT(srcNodeDef != nullptr); - MS_ASSERT(nodeDefs != nullptr); - NODE_ID srcId = std::string(srcNodeDef->name()->c_str()); - uint32_t opCount = nodeDefs->size(); - // for single op condition - AddNode(srcId); - for (auto index : *(srcNodeDef->outputIndex())) { - for (uint32_t i = 0; i < opCount; i++) { - auto dstNodeDef = nodeDefs->GetAs(i); - bool find = false; - auto inputIndex = dstNodeDef->inputIndex(); - if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) { - find = true; - } - - if (!find) { - continue; - } - NODE_ID dstId = std::string(dstNodeDef->name()->c_str()); - auto ret = AddEdge(srcId, dstId); - if (ret != RET_OK) { - return ret; - } - } - } - - return RET_OK; -} - -template -int OpGraph::AddEdge(NODE_ID srcId, NODE_ID dstId) { - auto srcNode = AddNode(srcId); - if (srcNode == nullptr) { - // MS_LOGE("add srcNode failed"); - return RET_ERROR; - } - auto dstNode = AddNode(dstId); - if (dstNode == nullptr) { - // MS_LOGE("add dstNode failed"); - return RET_ERROR; - } - - srcNode->AddOutEdge(dstNode); - - dstNode->AddInEdge(srcNode); - return RET_OK; -} - -template -NODE_T *OpGraph::GetNode(NODE_ID nodeId) { - auto node = nodes.find(nodeId); - if (node == nodes.end()) { - return nullptr; - } - return node->second; -} - -template -NODE_T *OpGraph::AddNode(NODE_ID nodeId) { - auto node = GetNode(nodeId); - if (node != nullptr) { - return node; - } - node = new (std::nothrow) NODE_T(nodeId); - if (node == nullptr) { - // MS_LOGE("new node failed"); - return nullptr; - } - nodes[nodeId] = node; - return node; -} - -template -void OpGraph::AddNodes(std::vector addNodes) { - for (auto node : addNodes) { - if (node == nullptr) { - return; - } - - nodes[node->ID()] = node; - } -} - -template -void OpGraph::DeleteNodes(std::vector deleteNodes) { - for (auto deletenode : deleteNodes) { - if (deletenode == nullptr) { - continue; - } - auto node = GetNode(deletenode->ID()); - if (node == nullptr) { - continue; - } - nodes.erase(deletenode->ID()); - } -} - -template -std::unordered_set OpGraph::GetInputNode() { - std::unordered_set inputNodes; - for (const auto &iter : nodes) { - auto node = iter.second; - if (node->GetAllInEdges().empty()) { - inputNodes.insert(node); - } - } - return inputNodes; -} - -template -std::unordered_set OpGraph::GetOutputNode() { - std::unordered_set outputNodes; - for (const auto &iter : nodes) { - auto node = iter.second; - if (node->GetAllOutEdges().empty()) { - outputNodes.insert(node); - } - } - return outputNodes; -} - -template -std::unordered_map> OpGraph::GetDepends() { - std::unordered_map> depends; - for (auto nodeIter : nodes) { - depends[nodeIter.second] = nodeIter.second->GetAllInEdges(); - } - return depends; -} - -template -OpGraph::~OpGraph() { - for (auto iter : nodes) { - delete iter.second; - } - nodes.clear(); -} +std::vector GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx); } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_ - diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 5cea874c20..2f3b1a7951 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -32,10 +32,29 @@ namespace mindspore { namespace lite { +static std::vector packed_op = { + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_MatMul}; + +// this method will not check whether tensor_idx is a weight tensor index, caller should ensure this. +static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) { + MS_ASSERT(nullptr != model); + auto meta_graph = model->GetMetaGraph(); + MS_ASSERT(nullptr != meta_graph); + auto post_node_idxes = GetLinkedPostNodeIdx(*meta_graph, tensor_idx); + return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) { + auto cNode = meta_graph->nodes()->GetAs(post_node_idx); + MS_ASSERT(cNode != nullptr); + return IsContain(packed_op, cNode->primitive()->value_type()); + }); +} + int LiteSession::ConvertTensors(const lite::Model *model) { MS_ASSERT(nullptr != model); auto meta_graph = model->GetMetaGraph(); MS_ASSERT(nullptr != meta_graph); + copyed_tensor_idxes_.clear(); uint32_t tensorCount = meta_graph->allTensors()->size(); for (uint32_t i = 0; i < tensorCount; i++) { auto *srcTensor = meta_graph->allTensors()->GetAs(i); @@ -54,16 +73,30 @@ int LiteSession::ConvertTensors(const lite::Model *model) { } } int dataType = srcTensor->dataType(); - auto *dstTensor = new tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType()); + auto *dstTensor = + new (std::nothrow) tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType()); + if (dstTensor == nullptr) { + MS_LOG(ERROR) << "new " << i << "th tensor failed"; + return RET_NULL_PTR; + } if (srcTensor->nodeType() == schema::NodeType_ValueNode && srcTensor->data() != nullptr && srcTensor->data()->size() > 0) { if (shape.empty()) { shape.push_back(1); + dstTensor->set_shape(shape); } - MS_ASSERT(dstTensor != nullptr); MS_ASSERT(dstTensor->Size() == srcTensor->data()->size()); - // no copy data, do copy when call LiteKernel::Init - dstTensor->SetData(const_cast(srcTensor->data()->data())); + if (WeightTensorNeedCopy(model, i)) { + auto ret = dstTensor->MallocData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Malloc data for " << i << "th tensor failed"; + return RET_ERROR; + } + memcpy(dstTensor->Data(), srcTensor->data()->data(), dstTensor->Size()); + copyed_tensor_idxes_.emplace_back(i); + } else { + dstTensor->SetData(const_cast(srcTensor->data()->data())); + } } auto quant_params = srcTensor->quantParams(); if (quant_params != nullptr) { @@ -74,7 +107,6 @@ int LiteSession::ConvertTensors(const lite::Model *model) { dstTensor->AddQuantParam(quant_arg); } } - this->tensors_.emplace_back(dstTensor); } @@ -240,6 +272,7 @@ int LiteSession::CompileGraph(Model *model) { } executor->Prepare(this->kernels_); + model->FreeMetaGraph(); return RET_OK; } @@ -277,7 +310,10 @@ int LiteSession::Init(Context *context) { } #endif executor = new Executor(); - MS_ASSERT(nullptr != executor); + if (nullptr == executor) { + MS_LOG(ERROR) << "new Executor failed"; + return RET_ERROR; + } return RET_OK; } @@ -288,9 +324,12 @@ void LiteSession::BindThread(bool if_bind) { } LiteSession::~LiteSession() { - for (auto *tensor : tensors_) { - // weight data can not be to free, we will free weight data when freeing meta_graph - if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor)) { + for (size_t i = 0; i < tensors_.size(); i++) { + auto *tensor = tensors_.at(i); + MS_ASSERT(tensor != nullptr); + // data of weight tensor of node in packed_op can not be to free, we will free weight data when freeing meta_graph + if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor) && + !IsContain(copyed_tensor_idxes_, i)) { tensor->SetData(nullptr); } delete tensor; diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index 325b25c1e3..53a205cbe5 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -87,6 +87,7 @@ class LiteSession : public session::LiteSession { Context *context_ = nullptr; std::vector kernels_; std::vector tensors_; + std::vector copyed_tensor_idxes_; // graph input tensors std::vector inputs_; // graph output tensors diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 797f65b2bd..cdc7c5f1c8 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -135,7 +135,7 @@ mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const { void Model::FreeMetaGraph() { MS_ASSERT(nullptr != model_impl_); - return model_impl_->FreeMetaGraph(); + model_impl_->FreeMetaGraph(); } const schema::MetaGraph *Model::GetMetaGraph() const { diff --git a/mindspore/lite/src/ops/abs.cc b/mindspore/lite/src/ops/abs.cc new file mode 100644 index 0000000000..1416513b06 --- /dev/null +++ b/mindspore/lite/src/ops/abs.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/abs.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateAbs(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Abs, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/abs.h b/mindspore/lite/src/ops/abs.h index 4e9fd3c12c..7b4fdc1e45 100644 --- a/mindspore/lite/src/ops/abs.h +++ b/mindspore/lite/src/ops/abs.h @@ -32,27 +32,9 @@ class Abs : public ArithmeticSelf { Abs() = default; explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateAbs(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Abs, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Abs() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/activation.cc b/mindspore/lite/src/ops/activation.cc index 10e9b287cd..b4af4c053c 100644 --- a/mindspore/lite/src/ops/activation.cc +++ b/mindspore/lite/src/ops/activation.cc @@ -55,7 +55,19 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector return RET_OK; } #else - +int Activation::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Activation(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Activation return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateActivation(*fbb, attr->type(), attr->alpha()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Activation, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); } float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); } #endif diff --git a/mindspore/lite/src/ops/activation.h b/mindspore/lite/src/ops/activation.h index 30fdf511d5..3934572b9c 100644 --- a/mindspore/lite/src/ops/activation.h +++ b/mindspore/lite/src/ops/activation.h @@ -30,34 +30,13 @@ class Activation : public PrimitiveC { MS_DECLARE_PARENT(Activation, PrimitiveC); Activation() = default; explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetType(int type); void SetAlpha(float alpha); #else - explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Activation(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateActivation(fbb, attr->type(), attr->alpha()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Activation, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Activation() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetType() const; float GetAlpha() const; diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc index e60fcf65c6..a82479a6a6 100644 --- a/mindspore/lite/src/ops/activation_grad.cc +++ b/mindspore/lite/src/ops/activation_grad.cc @@ -26,7 +26,19 @@ void ActivationGrad::SetType(int type) { } #else - +int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ActivationGrad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ActivationGrad return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateActivationGrad(*fbb, attr->type()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); } #endif diff --git a/mindspore/lite/src/ops/activation_grad.h b/mindspore/lite/src/ops/activation_grad.h index 9c4d0dc7f9..f4461d30c2 100644 --- a/mindspore/lite/src/ops/activation_grad.h +++ b/mindspore/lite/src/ops/activation_grad.h @@ -33,30 +33,9 @@ class ActivationGrad : public PrimitiveC { explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetType(int type); #else - explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ActivationGrad(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateActivationGrad(fbb, attr->type()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ActivationGrad, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ActivationGrad() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetType() const; }; diff --git a/mindspore/lite/src/ops/add.cc b/mindspore/lite/src/ops/add.cc index 4ec08096e3..d01d1e16e3 100644 --- a/mindspore/lite/src/ops/add.cc +++ b/mindspore/lite/src/ops/add.cc @@ -50,7 +50,19 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector &inputs } #else - +int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Add(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Add return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateAdd(*fbb, attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Add, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); } #endif diff --git a/mindspore/lite/src/ops/add.h b/mindspore/lite/src/ops/add.h index 7653e2fb11..78583b909b 100644 --- a/mindspore/lite/src/ops/add.h +++ b/mindspore/lite/src/ops/add.h @@ -31,33 +31,12 @@ class Add : public Arithmetic { MS_DECLARE_PARENT(Add, Arithmetic); Add() = default; explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetActivationType(int activation_type); #else - explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Add(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateAdd(fbb, attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Add, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Add() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetActivationType() const; }; diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index 4587f61f7c..9c82f8c271 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -24,7 +24,19 @@ int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; } void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; } #else - +int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_AddN(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_AddN return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateAddN(*fbb, attr->N()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AddN, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); } #endif diff --git a/mindspore/lite/src/ops/addn.h b/mindspore/lite/src/ops/addn.h index a0861b7a27..8b5c61d060 100644 --- a/mindspore/lite/src/ops/addn.h +++ b/mindspore/lite/src/ops/addn.h @@ -33,30 +33,9 @@ class AddN : public PrimitiveC { explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetN(int n); #else - explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_AddN(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateAddN(fbb, attr->N()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_AddN, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + AddN() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetN() const; diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index 297293d4cc..3005409cc3 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -32,7 +32,20 @@ void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->k void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; } #else - +int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ArgMax(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ArgMax return nullptr"; + return RET_ERROR; + } + auto val_offset = + schema::CreateArgMax(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMax, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); } bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); } int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); } diff --git a/mindspore/lite/src/ops/argmax.h b/mindspore/lite/src/ops/argmax.h index d114195d11..4b58916abc 100644 --- a/mindspore/lite/src/ops/argmax.h +++ b/mindspore/lite/src/ops/argmax.h @@ -37,31 +37,9 @@ class ArgMax : public PrimitiveC { void SetKeepDims(bool keep_dims); void SetAxisType(int axis_type); #else - explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ArgMax(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateArgMax(fbb, attr->axis(), attr->outMaxValue(), - attr->topK(), attr->keepDims(), attr->axisType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMax, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ArgMax() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index 21a2f8e3fa..c3e300130d 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -32,7 +32,20 @@ void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->k void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; } #else - +int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ArgMin(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ArgMin return nullptr"; + return RET_ERROR; + } + auto val_offset = + schema::CreateArgMin(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMin, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); } bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); } int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); } diff --git a/mindspore/lite/src/ops/argmin.h b/mindspore/lite/src/ops/argmin.h index ae1b644e11..a62fff3917 100644 --- a/mindspore/lite/src/ops/argmin.h +++ b/mindspore/lite/src/ops/argmin.h @@ -37,31 +37,9 @@ class ArgMin : public PrimitiveC { void SetKeepDims(bool keep_dims); void SetAxisType(int axis_type); #else - explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ArgMin(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateArgMin(fbb, attr->axis(), attr->outMaxValue(), - attr->topK(), attr->keepDims(), attr->axisType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMin, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ArgMin() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/arithmetic.h b/mindspore/lite/src/ops/arithmetic.h index 98189e3578..bcc516c1fe 100644 --- a/mindspore/lite/src/ops/arithmetic.h +++ b/mindspore/lite/src/ops/arithmetic.h @@ -32,7 +32,11 @@ class Arithmetic : public PrimitiveC { Arithmetic() = default; explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {} + // explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {} + Arithmetic() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { + return RET_ERROR; + } #endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool Broadcasting() { return this->broadcasting_; } diff --git a/mindspore/lite/src/ops/arithmetic_self.h b/mindspore/lite/src/ops/arithmetic_self.h index 3a9acaa44f..57e8a108ef 100644 --- a/mindspore/lite/src/ops/arithmetic_self.h +++ b/mindspore/lite/src/ops/arithmetic_self.h @@ -29,7 +29,11 @@ class ArithmeticSelf : public PrimitiveC { ArithmeticSelf() = default; explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {} + // explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {} + ArithmeticSelf() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { + return RET_ERROR; + } #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/batch_norm.cc b/mindspore/lite/src/ops/batch_norm.cc index 9c7fb2d8b8..736e6a9441 100644 --- a/mindspore/lite/src/ops/batch_norm.cc +++ b/mindspore/lite/src/ops/batch_norm.cc @@ -49,7 +49,14 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector & } #else - +int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateBatchNorm(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchNorm, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); } #endif diff --git a/mindspore/lite/src/ops/batch_norm.h b/mindspore/lite/src/ops/batch_norm.h index f030e2c32a..03dac3a71e 100644 --- a/mindspore/lite/src/ops/batch_norm.h +++ b/mindspore/lite/src/ops/batch_norm.h @@ -31,30 +31,12 @@ class BatchNorm : public PrimitiveC { MS_DECLARE_PARENT(BatchNorm, PrimitiveC); BatchNorm() = default; explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetEpsilon(float epsilon); #else - explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateBatchNorm(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchNorm, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + BatchNorm() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetEpsilon() const; }; diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index f8139aae4d..683508b628 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -32,7 +32,31 @@ void BatchToSpace::SetBlockShape(const std::vector &block_shape) { void BatchToSpace::SetCrops(const std::vector &crops) { this->primitive_->value.AsBatchToSpace()->crops = crops; } #else - +int BatchToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BatchToSpace(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BatchToSpace return nullptr"; + return RET_ERROR; + } + std::vector blockShape; + if (attr->blockShape() != nullptr) { + for (int i = 0; i < static_cast(attr->blockShape()->size()); i++) { + blockShape.push_back(attr->blockShape()->data()[i]); + } + } + std::vector crops; + if (attr->crops() != nullptr) { + for (int i = 0; i < static_cast(attr->crops()->size()); i++) { + crops.push_back(attr->crops()->data()[i]); + } + } + auto val_offset = schema::CreateBatchToSpaceDirect(*fbb, &blockShape, &crops); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchToSpace, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector BatchToSpace::GetBlockShape() const { auto fb_vector = this->primitive_->value_as_BatchToSpace()->blockShape(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/batch_to_space.h b/mindspore/lite/src/ops/batch_to_space.h index 6d0c333cfe..9c9632fc37 100644 --- a/mindspore/lite/src/ops/batch_to_space.h +++ b/mindspore/lite/src/ops/batch_to_space.h @@ -35,39 +35,9 @@ class BatchToSpace : public PrimitiveC { void SetBlockShape(const std::vector &block_shape); void SetCrops(const std::vector &crops); #else - explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_BatchToSpace(); - MS_ASSERT(attr != nullptr); - - auto blockShape = std::make_unique>(); - for (int i = 0; i < static_cast(attr->blockShape()->size()); i++) { - blockShape->push_back(attr->blockShape()->data()[i]); - } - auto crops = std::make_unique>(); - for (int i = 0; i < static_cast(attr->crops()->size()); i++) { - crops->push_back(attr->crops()->data()[i]); - } - - auto val_offset = schema::CreateBatchToSpaceDirect(fbb, blockShape.release(), crops.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchToSpace, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + BatchToSpace() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetBlockShape() const; diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index 5a042e60eb..bb7059e1ab 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -54,7 +54,25 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector &in } #else - +int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BiasAdd(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BiasAdd return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector BiasAdd::GetAxis() const { auto fb_vector = this->primitive_->value_as_BiasAdd()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/bias_add.h b/mindspore/lite/src/ops/bias_add.h index e8a1354b45..1298cf28f1 100644 --- a/mindspore/lite/src/ops/bias_add.h +++ b/mindspore/lite/src/ops/bias_add.h @@ -32,38 +32,12 @@ class BiasAdd : public PrimitiveC { MS_DECLARE_PARENT(BiasAdd, PrimitiveC); BiasAdd() = default; explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetAxis(const std::vector &axis); #else - explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_BiasAdd(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateBiasAddDirect(fbb, axis.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BiasAdd, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + BiasAdd() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetAxis() const; }; diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 4bce760d96..23c01adc3e 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -24,7 +24,25 @@ std::vector BiasGrad::GetAxis() const { return this->primitive_->value.AsBi void BiasGrad::SetAxis(const std::vector &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; } #else - +int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BiasGrad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BiasGrad return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateBiasGradDirect(*fbb, &axis); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector BiasGrad::GetAxis() const { auto fb_vector = this->primitive_->value_as_BiasGrad()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/bias_grad.h b/mindspore/lite/src/ops/bias_grad.h index 7aaa70db64..c3729764c1 100644 --- a/mindspore/lite/src/ops/bias_grad.h +++ b/mindspore/lite/src/ops/bias_grad.h @@ -35,35 +35,9 @@ class BiasGrad : public PrimitiveC { void SetAxis(const std::vector &axis); #else - explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_BiasGrad(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateBiasGradDirect(fbb, axis.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BiasGrad, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + BiasGrad() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetAxis() const; }; diff --git a/mindspore/lite/src/ops/bn_grad_input.cc b/mindspore/lite/src/ops/bn_grad_input.cc index 08bf36d465..9aee03f81d 100644 --- a/mindspore/lite/src/ops/bn_grad_input.cc +++ b/mindspore/lite/src/ops/bn_grad_input.cc @@ -26,7 +26,19 @@ void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->e void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradInput()->channels = channels; } #else - +int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BNGradInput(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BNGradInput return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->channels()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); } int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); } diff --git a/mindspore/lite/src/ops/bn_grad_input.h b/mindspore/lite/src/ops/bn_grad_input.h index f79fb2b631..aa22933f8a 100644 --- a/mindspore/lite/src/ops/bn_grad_input.h +++ b/mindspore/lite/src/ops/bn_grad_input.h @@ -34,30 +34,9 @@ class BNGradInput : public PrimitiveC { void SetEps(float eps); void SetChannels(int channels); #else - explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_BNGradInput(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateBNGradInput(fbb, attr->eps(), attr->channels()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BNGradInput, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + BNGradInput() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetEps() const; int GetChannels() const; diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 16a2d827df..1c4e5875cd 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -26,7 +26,25 @@ void BroadcastTo::SetDstShape(const std::vector &dst_shape) { } #else - +int BroadcastTo::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_BroadcastTo(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_BroadcastTo return nullptr"; + return RET_ERROR; + } + std::vector dst_shape; + if (attr->dst_shape() != nullptr) { + for (int i = 0; i < static_cast(attr->dst_shape()->size()); i++) { + dst_shape.push_back(attr->dst_shape()->data()[i]); + } + } + auto val_offset = schema::CreateBroadcastToDirect(*fbb, &dst_shape); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BroadcastTo, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector BroadcastTo::GetDstShape() const { auto fb_vector = this->primitive_->value_as_BroadcastTo()->dst_shape(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/broadcast_to.h b/mindspore/lite/src/ops/broadcast_to.h index 7aa64c96b4..d0181da165 100644 --- a/mindspore/lite/src/ops/broadcast_to.h +++ b/mindspore/lite/src/ops/broadcast_to.h @@ -35,35 +35,9 @@ class BroadcastTo : public PrimitiveC { void SetDstShape(const std::vector &dst_shape); #else - explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_BroadcastTo(); - MS_ASSERT(attr != nullptr); - - auto dst_shape = std::make_unique>(); - for (int i = 0; i < static_cast(attr->dst_shape()->size()); i++) { - dst_shape->push_back(attr->dst_shape()->data()[i]); - } - - auto val_offset = schema::CreateBroadcastToDirect(fbb, dst_shape.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BroadcastTo, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + BroadcastTo() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetDstShape() const; diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 83bafb9141..10cf0f63c2 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -26,7 +26,19 @@ void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t; void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; } #else - +int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Cast(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Cast return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateCast(*fbb, attr->srcT(), attr->dstT()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Cast, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); } int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); } diff --git a/mindspore/lite/src/ops/cast.h b/mindspore/lite/src/ops/cast.h index 9c9738b7b3..973e405a18 100644 --- a/mindspore/lite/src/ops/cast.h +++ b/mindspore/lite/src/ops/cast.h @@ -34,30 +34,9 @@ class Cast : public PrimitiveC { void SetSrcT(int src_t); void SetDstT(int dst_t); #else - explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Cast(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateCast(fbb, attr->srcT(), attr->dstT()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Cast, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Cast() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSrcT() const; diff --git a/mindspore/lite/src/ops/ceil.h b/mindspore/lite/src/ops/ceil.h index f1c241339a..6af9ef2910 100644 --- a/mindspore/lite/src/ops/ceil.h +++ b/mindspore/lite/src/ops/ceil.h @@ -32,26 +32,15 @@ class Ceil : public ArithmeticSelf { Ceil() = default; explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateCeil(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Ceil, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Ceil() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateCeil(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Ceil, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; } #endif }; diff --git a/mindspore/lite/src/ops/clip.cc b/mindspore/lite/src/ops/clip.cc index c5f96b6f99..08e654337a 100644 --- a/mindspore/lite/src/ops/clip.cc +++ b/mindspore/lite/src/ops/clip.cc @@ -26,7 +26,19 @@ void Clip::SetMax(float max) { this->primitive_->value.AsClip()->max = max; } void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; } #else - +int Clip::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Clip(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Clip return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateClip(*fbb, attr->max(), attr->min()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Clip, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); } float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); } diff --git a/mindspore/lite/src/ops/clip.h b/mindspore/lite/src/ops/clip.h index 453863baae..7cd343dbad 100644 --- a/mindspore/lite/src/ops/clip.h +++ b/mindspore/lite/src/ops/clip.h @@ -34,30 +34,9 @@ class Clip : public PrimitiveC { void SetMax(float max); void SetMin(float min); #else - explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Clip(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateClip(fbb, attr->max(), attr->min()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Clip, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Clip() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetMax() const; float GetMin() const; diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index f89a721419..db73bac434 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -60,7 +60,19 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector &inp } #else - +int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Concat(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Concat return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); } diff --git a/mindspore/lite/src/ops/concat.h b/mindspore/lite/src/ops/concat.h index 11540624be..c12d98fb0f 100644 --- a/mindspore/lite/src/ops/concat.h +++ b/mindspore/lite/src/ops/concat.h @@ -31,34 +31,13 @@ class Concat : public PrimitiveC { MS_DECLARE_PARENT(Concat, PrimitiveC); Concat() = default; explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetAxis(int axis); void SetN(int n); #else - explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Concat(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateConcat(fbb, attr->axis(), attr->n()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Concat, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Concat() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index 023ce63ba5..587bc17278 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ b/mindspore/lite/src/ops/constant_of_shape.cc @@ -30,7 +30,19 @@ float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConst void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; } #else - +int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ConstantOfShape(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ConstantOfShape return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateConstantOfShape(*fbb, attr->value()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); } #endif diff --git a/mindspore/lite/src/ops/constant_of_shape.h b/mindspore/lite/src/ops/constant_of_shape.h index c03fe077cd..ab96e088b3 100644 --- a/mindspore/lite/src/ops/constant_of_shape.h +++ b/mindspore/lite/src/ops/constant_of_shape.h @@ -33,30 +33,9 @@ class ConstantOfShape : public PrimitiveC { explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} void SetValue(float value); #else - explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ConstantOfShape(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateConstantOfShape(fbb, attr->value()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ConstantOfShape() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetValue() const; diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index bd000233f3..38151eef0f 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -338,7 +338,23 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inp } #else +int Conv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Conv2D(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Conv2D return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateConv2D( + *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2D, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Conv2D::GetFormat() const { return this->primitive_->value_as_Conv2D()->format(); } int Conv2D::GetGroup() const { return this->primitive_->value_as_Conv2D()->group(); } int Conv2D::GetChannelIn() const { return this->primitive_->value_as_Conv2D()->channelIn(); } diff --git a/mindspore/lite/src/ops/conv2d.h b/mindspore/lite/src/ops/conv2d.h index 992f6af107..21367dcdc5 100644 --- a/mindspore/lite/src/ops/conv2d.h +++ b/mindspore/lite/src/ops/conv2d.h @@ -34,7 +34,7 @@ class Conv2D : public PrimitiveC { Conv2D() = default; explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetFormat(int format); void SetGroup(int group); void SetChannelIn(int channel_in); @@ -63,34 +63,9 @@ class Conv2D : public PrimitiveC { #else public: - explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Conv2D(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateConv2D(fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), - attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), - attr->padRight(), attr->dilateW(), attr->dilateH(), - attr->hasBias(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2D, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Conv2D() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif public: diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index 5ac161ffaa..f3ef4d36e1 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -68,7 +68,22 @@ void Conv2DGradFilter::SetActivationType(int activation_type) { } #else - +int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Conv2DGradFilter(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Conv2DGradFilter return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateConv2DGradFilter( + *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2DGradFilter, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Conv2DGradFilter::GetFormat() const { return this->primitive_->value_as_Conv2DGradFilter()->format(); } int Conv2DGradFilter::GetGroup() const { return this->primitive_->value_as_Conv2DGradFilter()->group(); } int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradFilter()->channelIn(); } diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.h b/mindspore/lite/src/ops/conv2d_grad_filter.h index 9de5064a6e..54fd9a3bf0 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.h +++ b/mindspore/lite/src/ops/conv2d_grad_filter.h @@ -49,35 +49,9 @@ class Conv2DGradFilter : public PrimitiveC { void SetHasBias(bool has_bias); void SetActivationType(int activation_type); #else - explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Conv2DGradFilter(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateConv2DGradFilter(fbb, attr->format(), attr->group(), - attr->channelIn(), attr->channelOut(), - attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), - attr->padRight(), attr->dilateW(), attr->dilateH(), - attr->hasBias(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2DGradFilter, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Conv2DGradFilter() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetFormat() const; int GetGroup() const; diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index 896d49e94e..a8a26d2bc2 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -66,7 +66,22 @@ void Conv2DGradInput::SetActivationType(int activation_type) { } #else - +int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Conv2DGradInput(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Conv2DGradInput return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateConv2DGradInput( + *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2DGradInput, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Conv2DGradInput::GetFormat() const { return this->primitive_->value_as_Conv2DGradInput()->format(); } int Conv2DGradInput::GetGroup() const { return this->primitive_->value_as_Conv2DGradInput()->group(); } int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradInput()->channelIn(); } diff --git a/mindspore/lite/src/ops/conv2d_grad_input.h b/mindspore/lite/src/ops/conv2d_grad_input.h index bf37f42dbe..7d8cd2582a 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.h +++ b/mindspore/lite/src/ops/conv2d_grad_input.h @@ -49,35 +49,9 @@ class Conv2DGradInput : public PrimitiveC { void SetHasBias(bool has_bias); void SetActivationType(int activation_type); #else - explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Conv2DGradInput(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateConv2DGradInput(fbb, attr->format(), attr->group(), - attr->channelIn(), attr->channelOut(), - attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), - attr->padRight(), attr->dilateW(), attr->dilateH(), - attr->hasBias(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2DGradInput, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Conv2DGradInput() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetFormat() const; int GetGroup() const; diff --git a/mindspore/lite/src/ops/cos.cc b/mindspore/lite/src/ops/cos.cc new file mode 100644 index 0000000000..373b121d97 --- /dev/null +++ b/mindspore/lite/src/ops/cos.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/cos.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateCos(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Cos, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/cos.h b/mindspore/lite/src/ops/cos.h index 3f3890872d..b88a41f13b 100644 --- a/mindspore/lite/src/ops/cos.h +++ b/mindspore/lite/src/ops/cos.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -31,27 +31,9 @@ class Cos : public ArithmeticSelf { Cos() = default; explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateCos(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Cos, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Cos() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index 1b5d040392..514a5975d2 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -26,7 +26,25 @@ void Crop::SetAxis(int64_t axis) { this->primitive_->value.AsCrop()->axis = axis void Crop::SetOffsets(const std::vector &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; } #else - +int Crop::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Crop(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Crop return nullptr"; + return RET_ERROR; + } + std::vector offsets; + if (attr->offsets() != nullptr) { + for (int i = 0; i < static_cast(attr->offsets()->size()); i++) { + offsets.push_back(attr->offsets()->data()[i]); + } + } + auto val_offset = schema::CreateCropDirect(*fbb, attr->axis(), &offsets); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Crop, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int64_t Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); } std::vector Crop::GetOffsets() const { auto fb_vector = this->primitive_->value_as_Crop()->offsets(); diff --git a/mindspore/lite/src/ops/crop.h b/mindspore/lite/src/ops/crop.h index 503bc67a19..0650f7925f 100644 --- a/mindspore/lite/src/ops/crop.h +++ b/mindspore/lite/src/ops/crop.h @@ -35,35 +35,9 @@ class Crop : public PrimitiveC { void SetAxis(int64_t axis); void SetOffsets(const std::vector &offsets); #else - explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Crop(); - MS_ASSERT(attr != nullptr); - - auto offsets = std::make_unique>(); - for (int i = 0; i < static_cast(attr->offsets()->size()); i++) { - offsets->push_back(attr->offsets()->data()[i]); - } - - auto val_offset = schema::CreateCropDirect(fbb, attr->axis(), offsets.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Crop, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Crop() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int64_t GetAxis() const; diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 0e3fc98928..892207c8cc 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -58,7 +58,22 @@ void DeConv2D::SetActivationType(int activation_type) { } #else - +int DeConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_DeConv2D(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_DeConv2D return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateDeConv2D( + *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeConv2D, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int DeConv2D::GetFormat() const { return this->primitive_->value_as_DeConv2D()->format(); } int DeConv2D::GetGroup() const { return this->primitive_->value_as_DeConv2D()->group(); } int DeConv2D::GetChannelIn() const { return this->primitive_->value_as_DeConv2D()->channelIn(); } diff --git a/mindspore/lite/src/ops/deconv2d.h b/mindspore/lite/src/ops/deconv2d.h index 1d2ebd8b4f..020b5a95d4 100644 --- a/mindspore/lite/src/ops/deconv2d.h +++ b/mindspore/lite/src/ops/deconv2d.h @@ -49,34 +49,9 @@ class DeConv2D : public PrimitiveC { void SetHasBias(bool has_bias); void SetActivationType(int activation_type); #else - explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_DeConv2D(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDeConv2D(fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), - attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), - attr->padRight(), attr->dilateW(), attr->dilateH(), - attr->hasBias(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DeConv2D, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + DeConv2D() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index 042b578892..b2ac622907 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -70,7 +70,24 @@ void DeDepthwiseConv2D::SetActivationType(int activation_type) { } #else +int DeDepthwiseConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_DeDepthwiseConv2D(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_DeDepthwiseConv2D return nullptr"; + return RET_ERROR; + } + + auto val_offset = schema::CreateDeDepthwiseConv2D( + *fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeDepthwiseConv2D, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DeDepthwiseConv2D()->format(); } int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DeDepthwiseConv2D()->channelIn(); } int DeDepthwiseConv2D::GetChannelMultiplier() const { diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.h b/mindspore/lite/src/ops/dedepthwise_conv2d.h index c01427c42a..142ce5b1f4 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.h +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.h @@ -48,34 +48,9 @@ class DeDepthwiseConv2D : public PrimitiveC { void SetHasBias(bool has_bias); void SetActivationType(int activation_type); #else - explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_DeDepthwiseConv2D(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDeDepthwiseConv2D(fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), - attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), - attr->padRight(), attr->dilateW(), attr->dilateH(), - attr->hasBias(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DeDepthwiseConv2D, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + DeDepthwiseConv2D() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index ef45f6fce1..e090fe0dc5 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -26,7 +26,19 @@ void DepthToSpace::SetBlockSize(int block_size) { this->primitive_->value.AsDept void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpace()->format = (schema::Format)format; } #else - +int DepthToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_DepthToSpace(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_DepthToSpace return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateDepthToSpace(*fbb, attr->blockSize(), attr->format()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DepthToSpace, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); } int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); } diff --git a/mindspore/lite/src/ops/depth_to_space.h b/mindspore/lite/src/ops/depth_to_space.h index 37a0ea4d5a..5320b27c4a 100644 --- a/mindspore/lite/src/ops/depth_to_space.h +++ b/mindspore/lite/src/ops/depth_to_space.h @@ -34,30 +34,9 @@ class DepthToSpace : public PrimitiveC { void SetBlockSize(int block_size); void SetFormat(int format); #else - explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_DepthToSpace(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDepthToSpace(fbb, attr->blockSize(), attr->format()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DepthToSpace, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + DepthToSpace() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBlockSize() const; diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index e7b2d55d43..a66e33f82a 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -232,7 +232,22 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vectorvalue_as_DepthwiseConv2D(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_DepthwiseConv2D return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateDepthwiseConv2D( + *fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), attr->kernelW(), attr->kernelH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), + attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DepthwiseConv2D, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); } int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DepthwiseConv2D()->channelIn(); } int DepthwiseConv2D::GetChannelMultiplier() const { diff --git a/mindspore/lite/src/ops/depthwise_conv2d.h b/mindspore/lite/src/ops/depthwise_conv2d.h index 1e5c52ed85..aada41f542 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.h +++ b/mindspore/lite/src/ops/depthwise_conv2d.h @@ -33,7 +33,7 @@ class DepthwiseConv2D : public PrimitiveC { DepthwiseConv2D() = default; explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetFormat(int format); void SetChannelIn(int channel_in); void SetChannelMultiplier(int channel_multiplier); @@ -58,35 +58,9 @@ class DepthwiseConv2D : public PrimitiveC { #else public: - explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_DepthwiseConv2D(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDepthwiseConv2D(fbb, attr->format(), - attr->channelIn(), attr->channelMultiplier(), - attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), - attr->padRight(), attr->dilateW(), attr->dilateH(), - attr->hasBias(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DepthwiseConv2D, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + DepthwiseConv2D() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif public: diff --git a/mindspore/lite/src/ops/dequant.h b/mindspore/lite/src/ops/dequant.h index 74e031f19f..73fd1391f3 100644 --- a/mindspore/lite/src/ops/dequant.h +++ b/mindspore/lite/src/ops/dequant.h @@ -28,9 +28,9 @@ class Dequant : public PrimitiveC { MS_DECLARE_PARENT(Dequant, PrimitiveC); Dequant() = default; explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit Dequant(schema::Primitive *primitive) : PrimitiveC(primitive) {} + Dequant() = default; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/detection_post_process.cc b/mindspore/lite/src/ops/detection_post_process.cc index c0b5642a1e..5e4754f077 100644 --- a/mindspore/lite/src/ops/detection_post_process.cc +++ b/mindspore/lite/src/ops/detection_post_process.cc @@ -88,7 +88,22 @@ void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) { } #else - +int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_DetectionPostProcess(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_DetectionPostProcess return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateDetectionPostProcess( + *fbb, attr->format(), attr->inputSize(), attr->hScale(), attr->wScale(), attr->xScale(), attr->yScale(), + attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPreClass(), + attr->MaxClassesPreDetection(), attr->NumClasses(), attr->UseRegularNms()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int DetectionPostProcess::GetFormat() const { return this->primitive_->value_as_DetectionPostProcess()->format(); } int DetectionPostProcess::GetInputSize() const { return this->primitive_->value_as_DetectionPostProcess()->inputSize(); diff --git a/mindspore/lite/src/ops/detection_post_process.h b/mindspore/lite/src/ops/detection_post_process.h index ae8bd111ca..4fb9dea282 100644 --- a/mindspore/lite/src/ops/detection_post_process.h +++ b/mindspore/lite/src/ops/detection_post_process.h @@ -45,36 +45,9 @@ class DetectionPostProcess : public PrimitiveC { void SetNumClasses(int64_t num_classes); void SetUseRegularNms(bool use_regular_nms); #else - explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_DetectionPostProcess(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDetectionPostProcess(fbb, attr->format(), attr->inputSize(), - attr->hScale(), attr->wScale(), - attr->xScale(), attr->yScale(), - attr->NmsIouThreshold(), attr->NmsScoreThreshold(), - attr->MaxDetections(), attr->DetectionsPreClass(), - attr->MaxClassesPreDetection(), attr->NumClasses(), - attr->UseRegularNms()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + DetectionPostProcess() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetFormat() const; int GetInputSize() const; diff --git a/mindspore/lite/src/ops/div.cc b/mindspore/lite/src/ops/div.cc index 3428e05742..12300fe9ec 100644 --- a/mindspore/lite/src/ops/div.cc +++ b/mindspore/lite/src/ops/div.cc @@ -26,7 +26,19 @@ void Div::SetActivationType(int activation_type) { } #else - +int Div::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Div(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Div return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateDiv(*fbb, attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Div, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); } #endif diff --git a/mindspore/lite/src/ops/div.h b/mindspore/lite/src/ops/div.h index 026f41b309..6ca390b091 100644 --- a/mindspore/lite/src/ops/div.h +++ b/mindspore/lite/src/ops/div.h @@ -34,30 +34,9 @@ class Div : public Arithmetic { void SetActivationType(int activation_type); #else - explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Div(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDiv(fbb, attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Div, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Div() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetActivationType() const; }; diff --git a/mindspore/lite/src/ops/dropout.cc b/mindspore/lite/src/ops/dropout.cc index 3bd69467a6..7cce7ad1cc 100644 --- a/mindspore/lite/src/ops/dropout.cc +++ b/mindspore/lite/src/ops/dropout.cc @@ -24,7 +24,19 @@ float Dropout::GetRatio() const { return this->primitive_->value.AsDropout()->ra void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio = ratio; } #else - +int Dropout::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Dropout(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Dropout return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateDropout(*fbb, attr->ratio()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Dropout, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); } #endif diff --git a/mindspore/lite/src/ops/dropout.h b/mindspore/lite/src/ops/dropout.h index f8eccec19d..f2a756eddd 100644 --- a/mindspore/lite/src/ops/dropout.h +++ b/mindspore/lite/src/ops/dropout.h @@ -34,30 +34,9 @@ class Dropout : public PrimitiveC { void SetRatio(float ratio); #else - explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Dropout(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateDropout(fbb, attr->ratio()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Dropout, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Dropout() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetRatio() const; }; diff --git a/mindspore/lite/src/ops/eltwise.cc b/mindspore/lite/src/ops/eltwise.cc index 9e25d73671..3760cd9f85 100644 --- a/mindspore/lite/src/ops/eltwise.cc +++ b/mindspore/lite/src/ops/eltwise.cc @@ -24,7 +24,19 @@ int Eltwise::GetMode() const { return this->primitive_->value.AsEltwise()->mode; void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (schema::EltwiseMode)mode; } #else - +int Eltwise::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Eltwise(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Eltwise return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateEltwise(*fbb, attr->mode()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Eltwise, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); } #endif diff --git a/mindspore/lite/src/ops/eltwise.h b/mindspore/lite/src/ops/eltwise.h index 7227f804f9..720724b94f 100644 --- a/mindspore/lite/src/ops/eltwise.h +++ b/mindspore/lite/src/ops/eltwise.h @@ -34,30 +34,9 @@ class Eltwise : public PrimitiveC { void SetMode(int mode); #else - explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Eltwise(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateEltwise(fbb, attr->mode()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Eltwise, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Eltwise() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetMode() const; }; diff --git a/mindspore/lite/src/ops/elu.cc b/mindspore/lite/src/ops/elu.cc index d433d191bf..9a1e16991f 100644 --- a/mindspore/lite/src/ops/elu.cc +++ b/mindspore/lite/src/ops/elu.cc @@ -24,7 +24,19 @@ float Elu::GetAlpha() const { return this->primitive_->value.AsElu()->alpha; } void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha; } #else - +int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Elu(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Elu return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateElu(*fbb, attr->alpha()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Elu, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); } #endif diff --git a/mindspore/lite/src/ops/elu.h b/mindspore/lite/src/ops/elu.h index 5490a2e1dc..e0f3a5f576 100644 --- a/mindspore/lite/src/ops/elu.h +++ b/mindspore/lite/src/ops/elu.h @@ -34,30 +34,9 @@ class Elu : public PrimitiveC { void SetAlpha(float alpha); #else - explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Elu(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateElu(fbb, attr->alpha()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Elu, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Elu() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetAlpha() const; }; diff --git a/mindspore/lite/src/ops/embedding_lookup.cc b/mindspore/lite/src/ops/embedding_lookup.cc index 8808dc8cb2..270f22f310 100644 --- a/mindspore/lite/src/ops/embedding_lookup.cc +++ b/mindspore/lite/src/ops/embedding_lookup.cc @@ -24,7 +24,21 @@ float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value.AsEmb void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmbeddingLookup()->maxNorm = max_norm; } #else +int EmbeddingLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_EmbeddingLookup(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_EmbeddingLookup return nullptr"; + return RET_ERROR; + } + + auto val_offset = schema::CreateEmbeddingLookup(*fbb, attr->maxNorm()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_EmbeddingLookup, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); } #endif diff --git a/mindspore/lite/src/ops/embedding_lookup.h b/mindspore/lite/src/ops/embedding_lookup.h index 091140e29e..c51441b9b0 100644 --- a/mindspore/lite/src/ops/embedding_lookup.h +++ b/mindspore/lite/src/ops/embedding_lookup.h @@ -34,30 +34,9 @@ class EmbeddingLookup : public PrimitiveC { void SetMaxNorm(float max_norm); #else - explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_EmbeddingLookup(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateEmbeddingLookup(fbb, attr->maxNorm()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_EmbeddingLookup, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + EmbeddingLookup() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetMaxNorm() const; diff --git a/mindspore/lite/src/ops/embedding_lookup_sparse.cc b/mindspore/lite/src/ops/embedding_lookup_sparse.cc index 9501ebe752..b981defde2 100644 --- a/mindspore/lite/src/ops/embedding_lookup_sparse.cc +++ b/mindspore/lite/src/ops/embedding_lookup_sparse.cc @@ -38,7 +38,32 @@ void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) { } #else - +int EmbeddingLookupSparse::UnPackToFlatBuilder(const schema::Primitive *primitive, + flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_EmbeddingLookupSparse(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_EmbeddingLookupSparse return nullptr"; + return RET_ERROR; + } + std::vector spIds; + if (attr->spIds() != nullptr) { + for (int i = 0; i < static_cast(attr->spIds()->size()); i++) { + spIds.push_back(attr->spIds()->data()[i]); + } + } + std::vector spWeights; + if (attr->spWeights() != nullptr) { + for (int i = 0; i < static_cast(attr->spWeights()->size()); i++) { + spWeights.push_back(attr->spWeights()->data()[i]); + } + } + auto val_offset = schema::CreateEmbeddingLookupSparseDirect(*fbb, &spIds, &spWeights); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_EmbeddingLookupSparse, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector EmbeddingLookupSparse::GetSpIds() const { auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spIds(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/embedding_lookup_sparse.h b/mindspore/lite/src/ops/embedding_lookup_sparse.h index d58462a5c1..8ca4cd1177 100644 --- a/mindspore/lite/src/ops/embedding_lookup_sparse.h +++ b/mindspore/lite/src/ops/embedding_lookup_sparse.h @@ -36,39 +36,9 @@ class EmbeddingLookupSparse : public PrimitiveC { void SetSpWeights(const std::vector &sp_weights); void SetMaxNortm(float max_nortm); #else - explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_EmbeddingLookupSparse(); - MS_ASSERT(attr != nullptr); - - auto spIds = std::make_unique>(); - for (int i = 0; i < static_cast(attr->spIds()->size()); i++) { - spIds->push_back(attr->spIds()->data()[i]); - } - auto spWeights = std::make_unique>(); - for (int i = 0; i < static_cast(attr->spWeights()->size()); i++) { - spWeights->push_back(attr->spWeights()->data()[i]); - } - - auto val_offset = schema:: CreateEmbeddingLookupSparseDirect(fbb, spIds.release(), spWeights.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_EmbeddingLookupSparse, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + EmbeddingLookupSparse() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetSpIds() const; std::vector GetSpWeights() const; diff --git a/mindspore/lite/src/ops/equal.cc b/mindspore/lite/src/ops/equal.cc new file mode 100644 index 0000000000..9732211061 --- /dev/null +++ b/mindspore/lite/src/ops/equal.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/equal.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateEqual(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Equal, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/equal.h b/mindspore/lite/src/ops/equal.h index d3e297d892..69194b5528 100644 --- a/mindspore/lite/src/ops/equal.h +++ b/mindspore/lite/src/ops/equal.h @@ -32,27 +32,9 @@ class Equal : public Arithmetic { Equal() = default; explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateEqual(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Equal, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Equal() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/exp.cc b/mindspore/lite/src/ops/exp.cc new file mode 100644 index 0000000000..1c5acbba01 --- /dev/null +++ b/mindspore/lite/src/ops/exp.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/exp.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateExp(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Exp, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/exp.h b/mindspore/lite/src/ops/exp.h index 9d47375086..c79648725e 100644 --- a/mindspore/lite/src/ops/exp.h +++ b/mindspore/lite/src/ops/exp.h @@ -32,27 +32,9 @@ class Exp : public ArithmeticSelf { Exp() = default; explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateExp(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Exp, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Exp() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index a5fd088cfb..7961a6385b 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -24,7 +24,20 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()-> void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; } #else +int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ExpandDims(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ExpandDims return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateExpandDims(*fbb, attr->dim()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ExpandDims, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); } #endif diff --git a/mindspore/lite/src/ops/expand_dims.h b/mindspore/lite/src/ops/expand_dims.h index f2bd17360b..52007b874d 100644 --- a/mindspore/lite/src/ops/expand_dims.h +++ b/mindspore/lite/src/ops/expand_dims.h @@ -34,30 +34,9 @@ class ExpandDims : public PrimitiveC { void SetDim(int dim); #else - explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ExpandDims(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateExpandDims(fbb, attr->dim()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ExpandDims, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ExpandDims() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetDim() const; diff --git a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc index d1a2fd5ff4..5bc8b3eb26 100644 --- a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc +++ b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc @@ -32,7 +32,21 @@ void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) { } #else +int FakeQuantWithMinMaxVars::UnPackToFlatBuilder(const schema::Primitive *primitive, + flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_FakeQuantWithMinMaxVars(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_FakeQuantWithMinMaxVars return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateFakeQuantWithMinMaxVars(*fbb, attr->narrowRange(), attr->numBits()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FakeQuantWithMinMaxVars, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} bool FakeQuantWithMinMaxVars::GetNarrowRange() const { return this->primitive_->value_as_FakeQuantWithMinMaxVars()->narrowRange(); } diff --git a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h index 16017db716..a6f85f7b92 100644 --- a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h +++ b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.h @@ -34,31 +34,9 @@ class FakeQuantWithMinMaxVars : public PrimitiveC { void SetNarrowRange(bool narrow_range); void SetNumBits(int num_bits); #else - explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_FakeQuantWithMinMaxVars(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateFakeQuantWithMinMaxVars(fbb, attr->narrowRange(), attr->numBits()); - auto prim_offset = schema::CreatePrimitive(fbb, - schema::PrimitiveType_FakeQuantWithMinMaxVars, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + FakeQuantWithMinMaxVars() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif bool GetNarrowRange() const; int GetNumBits() const; diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 3e6be23153..9ae1fb4305 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -24,7 +24,25 @@ std::vector Fill::GetDims() const { return this->primitive_->value.AsFill() void Fill::SetDims(const std::vector &dims) { this->primitive_->value.AsFill()->dims = dims; } #else - +int Fill::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Fill(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Fill return nullptr"; + return RET_ERROR; + } + std::vector dims; + if (attr->dims() != nullptr) { + for (int i = 0; i < static_cast(attr->dims()->size()); i++) { + dims.push_back(attr->dims()->data()[i]); + } + } + auto val_offset = schema::CreateFillDirect(*fbb, &dims); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Fill, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector Fill::GetDims() const { auto fb_vector = this->primitive_->value_as_Fill()->dims(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/fill.h b/mindspore/lite/src/ops/fill.h index 5b1d52a15c..f95d22542c 100644 --- a/mindspore/lite/src/ops/fill.h +++ b/mindspore/lite/src/ops/fill.h @@ -35,35 +35,9 @@ class Fill : public PrimitiveC { void SetDims(const std::vector &dims); #else - explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Fill(); - MS_ASSERT(attr != nullptr); - - auto dims = std::make_unique>(); - for (int i = 0; i < static_cast(attr->dims()->size()); i++) { - dims->push_back(attr->dims()->data()[i]); - } - - auto val_offset = schema::CreateFillDirect(fbb, dims.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Fill, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Fill() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetDims() const; diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index 8e852721e5..949bf3d92d 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -77,6 +77,15 @@ int Flatten::UnPackAttr(const Primitive &prim, const std::vector &in } return RET_OK; } +#else +int Flatten::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateFlatten(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Flatten, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/flatten.h b/mindspore/lite/src/ops/flatten.h index b10bfca5d4..ce60608cff 100644 --- a/mindspore/lite/src/ops/flatten.h +++ b/mindspore/lite/src/ops/flatten.h @@ -31,32 +31,13 @@ class Flatten : public PrimitiveC { MS_DECLARE_PARENT(Flatten, PrimitiveC); Flatten() = default; explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateFlatten(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Flatten, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Flatten() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; - - int UnPackAttr(const Primitive &prim, const std::vector &inputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor.cc b/mindspore/lite/src/ops/floor.cc new file mode 100644 index 0000000000..d284b102d9 --- /dev/null +++ b/mindspore/lite/src/ops/floor.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/floor.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE + +int Floor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateFloor(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Floor, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/floor.h b/mindspore/lite/src/ops/floor.h index 8baf469a6f..f970218cc4 100644 --- a/mindspore/lite/src/ops/floor.h +++ b/mindspore/lite/src/ops/floor.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class Floor : public ArithmeticSelf { Floor() = default; explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateFloor(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Floor, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Floor() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/floor_div.cc b/mindspore/lite/src/ops/floor_div.cc new file mode 100644 index 0000000000..0aa4610d3e --- /dev/null +++ b/mindspore/lite/src/ops/floor_div.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/floor_div.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE + +int FloorDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateFloor(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Floor, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_div.h b/mindspore/lite/src/ops/floor_div.h index 4fec27263a..5525218708 100644 --- a/mindspore/lite/src/ops/floor_div.h +++ b/mindspore/lite/src/ops/floor_div.h @@ -32,27 +32,9 @@ class FloorDiv : public Arithmetic { FloorDiv() = default; explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateFloorDiv(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FloorDiv, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + FloorDiv() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/floor_mod.cc b/mindspore/lite/src/ops/floor_mod.cc new file mode 100644 index 0000000000..f903620655 --- /dev/null +++ b/mindspore/lite/src/ops/floor_mod.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/floor_mod.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE + +int FloorMod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateFloorMod(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FloorMod, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_mod.h b/mindspore/lite/src/ops/floor_mod.h index 32c9f772db..adbca5e52c 100644 --- a/mindspore/lite/src/ops/floor_mod.h +++ b/mindspore/lite/src/ops/floor_mod.h @@ -32,27 +32,9 @@ class FloorMod : public Arithmetic { FloorMod() = default; explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateFloorMod(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FloorMod, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + FloorMod() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/full_connection.cc b/mindspore/lite/src/ops/full_connection.cc index 5550655610..4be4e45b4b 100644 --- a/mindspore/lite/src/ops/full_connection.cc +++ b/mindspore/lite/src/ops/full_connection.cc @@ -31,7 +31,21 @@ void FullConnection::SetActivationType(int activationType) { this->primitive_->value.AsFullConnection()->activationType = (schema::ActivationType)activationType; } #else +int FullConnection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_FullConnection(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_FullConnection return nullptr"; + return RET_ERROR; + } + auto val_offset = + schema::CreateFullConnection(*fbb, attr->hasBias(), attr->axis(), attr->useAxis(), attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FullConnection, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} bool FullConnection::GetHasBias() const { return this->primitive_->value_as_FullConnection()->hasBias(); } int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConnection()->axis(); } bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); } diff --git a/mindspore/lite/src/ops/full_connection.h b/mindspore/lite/src/ops/full_connection.h index c60ff080b7..c4d5e980ff 100644 --- a/mindspore/lite/src/ops/full_connection.h +++ b/mindspore/lite/src/ops/full_connection.h @@ -36,31 +36,9 @@ class FullConnection : public PrimitiveC { void SetUseAxis(bool use_axis); void SetActivationType(int activationType); #else - explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_FullConnection(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateFullConnection(fbb, attr->hasBias(), attr->axis(), - attr->useAxis(), attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FullConnection, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + FullConnection() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetHasBias() const; diff --git a/mindspore/lite/src/ops/fused_batchnorm.cc b/mindspore/lite/src/ops/fused_batchnorm.cc index c78bb2f91d..a5d3a31929 100644 --- a/mindspore/lite/src/ops/fused_batchnorm.cc +++ b/mindspore/lite/src/ops/fused_batchnorm.cc @@ -28,7 +28,20 @@ void FusedBatchNorm::SetMomentum(float momentum) { this->primitive_->value.AsFus void FusedBatchNorm::SetSpatial(int spatial) { this->primitive_->value.AsFusedBatchNorm()->spatial = spatial; } #else +int FusedBatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_FusedBatchNorm(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_FusedBatchNorm return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateFusedBatchNorm(*fbb, attr->epsilon(), attr->momentum(), attr->spatial()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FusedBatchNorm, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_FusedBatchNorm()->epsilon(); } float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); } int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); } diff --git a/mindspore/lite/src/ops/fused_batchnorm.h b/mindspore/lite/src/ops/fused_batchnorm.h index a812ec343b..b95314def8 100644 --- a/mindspore/lite/src/ops/fused_batchnorm.h +++ b/mindspore/lite/src/ops/fused_batchnorm.h @@ -35,30 +35,9 @@ class FusedBatchNorm : public PrimitiveC { void SetMomentum(float momentum); void SetSpatial(int spatial); #else - explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_FusedBatchNorm(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateFusedBatchNorm(fbb, attr->epsilon(), attr->momentum(), attr->spatial()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FusedBatchNorm, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + FusedBatchNorm() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetEpsilon() const; float GetMomentum() const; diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index a899bbb0a6..4d48933d19 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -29,7 +29,20 @@ void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; } #else +int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Gather(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Gather return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateGather(*fbb, attr->axis(), attr->batchDims()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gather, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); } int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); } diff --git a/mindspore/lite/src/ops/gather.h b/mindspore/lite/src/ops/gather.h index 81401c6f7c..088a736efa 100644 --- a/mindspore/lite/src/ops/gather.h +++ b/mindspore/lite/src/ops/gather.h @@ -34,30 +34,9 @@ class Gather : public PrimitiveC { void SetAxis(int axis); void SetBatchDims(int batch_dims); #else - explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Gather(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateGather(fbb, attr->axis(), attr->batchDims()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Gather, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Gather() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index 5fe4f09bbe..e88e913339 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -24,7 +24,20 @@ int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd() void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; } #else +int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_GatherNd(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_GatherNd return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateGatherNd(*fbb, attr->batchDims()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); } #endif diff --git a/mindspore/lite/src/ops/gather_nd.h b/mindspore/lite/src/ops/gather_nd.h index 6e67f8c826..f578b55ae4 100644 --- a/mindspore/lite/src/ops/gather_nd.h +++ b/mindspore/lite/src/ops/gather_nd.h @@ -34,30 +34,9 @@ class GatherNd : public PrimitiveC { void SetBatchDims(int batch_dims); #else - explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_GatherNd(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateGatherNd(fbb, attr->batchDims()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_GatherNd, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + GatherNd() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBatchDims() const; diff --git a/mindspore/lite/src/ops/greater.cc b/mindspore/lite/src/ops/greater.cc new file mode 100644 index 0000000000..bd92f1a1b1 --- /dev/null +++ b/mindspore/lite/src/ops/greater.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/greater.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE + +int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateGreater(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Greater, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/greater.h b/mindspore/lite/src/ops/greater.h index 96c044b670..611025152d 100644 --- a/mindspore/lite/src/ops/greater.h +++ b/mindspore/lite/src/ops/greater.h @@ -31,27 +31,9 @@ class Greater : public Arithmetic { Greater() = default; explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateGreater(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Greater, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Greater() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/greater_equal.cc b/mindspore/lite/src/ops/greater_equal.cc new file mode 100644 index 0000000000..bd2e5b1c45 --- /dev/null +++ b/mindspore/lite/src/ops/greater_equal.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/greater_equal.h" + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateGreaterEqual(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GreaterEqual, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.h b/mindspore/lite/src/ops/greater_equal.h index 8363ae3990..c6a6001764 100644 --- a/mindspore/lite/src/ops/greater_equal.h +++ b/mindspore/lite/src/ops/greater_equal.h @@ -32,27 +32,9 @@ class GreaterEqual : public Arithmetic { GreaterEqual() = default; explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateGreaterEqual(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_GreaterEqual, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + GreaterEqual() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/l2_norm.cc b/mindspore/lite/src/ops/l2_norm.cc index 507c9fa9eb..4c8431ae68 100644 --- a/mindspore/lite/src/ops/l2_norm.cc +++ b/mindspore/lite/src/ops/l2_norm.cc @@ -26,7 +26,26 @@ void L2Norm::SetAxis(const std::vector &axis) { this->primitive_->value.AsL void L2Norm::SetEpsilon(float epsilon) { this->primitive_->value.AsL2Norm()->epsilon = epsilon; } #else +int L2Norm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_L2Norm(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_L2Norm return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateL2NormDirect(*fbb, &axis, attr->epsilon()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_L2Norm, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} std::vector L2Norm::GetAxis() const { auto fb_vector = this->primitive_->value_as_L2Norm()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); diff --git a/mindspore/lite/src/ops/l2_norm.h b/mindspore/lite/src/ops/l2_norm.h index 17da92ba65..e44579d574 100644 --- a/mindspore/lite/src/ops/l2_norm.h +++ b/mindspore/lite/src/ops/l2_norm.h @@ -35,35 +35,9 @@ class L2Norm : public PrimitiveC { void SetAxis(const std::vector &axis); void SetEpsilon(float epsilon); #else - explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_L2Norm(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateL2NormDirect(fbb, axis.release(), attr->epsilon()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_L2Norm, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + L2Norm() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetAxis() const; float GetEpsilon() const; diff --git a/mindspore/lite/src/ops/leaky_relu.cc b/mindspore/lite/src/ops/leaky_relu.cc index 25ac7b50b5..7164372f6b 100644 --- a/mindspore/lite/src/ops/leaky_relu.cc +++ b/mindspore/lite/src/ops/leaky_relu.cc @@ -29,6 +29,19 @@ void LeakyReLU::SetNegativeSlope(float negative_slope) { float LeakyReLU::GetNegativeSlope() const { return this->primitive_->value_as_LeakyReLU()->negativeSlope(); } +int LeakyReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_LeakyReLU(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_LeakyReLU return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateLeakyReLU(*fbb, attr->negativeSlope()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LeakyReLU, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/leaky_relu.h b/mindspore/lite/src/ops/leaky_relu.h index 6723f66e6c..f72b75516a 100644 --- a/mindspore/lite/src/ops/leaky_relu.h +++ b/mindspore/lite/src/ops/leaky_relu.h @@ -34,30 +34,9 @@ class LeakyReLU : public PrimitiveC { void SetNegativeSlope(float negative_slope); #else - explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_LeakyReLU(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateLeakyReLU(fbb, attr->negativeSlope()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_LeakyReLU, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + LeakyReLU() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetNegativeSlope() const; }; diff --git a/mindspore/lite/src/ops/less.cc b/mindspore/lite/src/ops/less.cc new file mode 100644 index 0000000000..57a98d87a9 --- /dev/null +++ b/mindspore/lite/src/ops/less.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/less.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateLess(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Less, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/less.h b/mindspore/lite/src/ops/less.h index 5e44f5194a..e230f89412 100644 --- a/mindspore/lite/src/ops/less.h +++ b/mindspore/lite/src/ops/less.h @@ -32,27 +32,9 @@ class Less : public Arithmetic { Less() = default; explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateLess(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Less, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Less() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/less_equal.cc b/mindspore/lite/src/ops/less_equal.cc new file mode 100644 index 0000000000..7274f8cc22 --- /dev/null +++ b/mindspore/lite/src/ops/less_equal.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/less_equal.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateLessEqual(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LessEqual, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.h b/mindspore/lite/src/ops/less_equal.h index ee91141c95..43a906713d 100644 --- a/mindspore/lite/src/ops/less_equal.h +++ b/mindspore/lite/src/ops/less_equal.h @@ -32,27 +32,9 @@ class LessEqual : public Arithmetic { LessEqual() = default; explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateLessEqual(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_LessEqual, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + LessEqual() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/local_response_normalization.cc b/mindspore/lite/src/ops/local_response_normalization.cc index 2d758b7de6..518567d785 100644 --- a/mindspore/lite/src/ops/local_response_normalization.cc +++ b/mindspore/lite/src/ops/local_response_normalization.cc @@ -60,6 +60,22 @@ float LocalResponseNormalization::GetBeta() const { return this->primitive_->value_as_LocalResponseNormalization()->beta(); } +int LocalResponseNormalization::UnPackToFlatBuilder(const schema::Primitive *primitive, + flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_LocalResponseNormalization(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_LocalResponseNormalization return nullptr"; + return RET_ERROR; + } + auto val_offset = + schema::CreateLocalResponseNormalization(*fbb, attr->depth_radius(), attr->bias(), attr->alpha(), attr->beta()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LocalResponseNormalization, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/local_response_normalization.h b/mindspore/lite/src/ops/local_response_normalization.h index 9200d71aa7..108e0c8c83 100644 --- a/mindspore/lite/src/ops/local_response_normalization.h +++ b/mindspore/lite/src/ops/local_response_normalization.h @@ -36,31 +36,9 @@ class LocalResponseNormalization : public PrimitiveC { void SetAlpha(float alpha); void SetBeta(float beta); #else - explicit LocalResponseNormalization(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_LocalResponseNormalization(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateLocalResponseNormalization(fbb, attr->depth_radius(), attr->bias(), - attr->alpha(), attr->beta()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_LocalResponseNormalization, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + LocalResponseNormalization() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetDepthRadius() const; float GetBias() const; diff --git a/mindspore/lite/src/ops/log.cc b/mindspore/lite/src/ops/log.cc new file mode 100644 index 0000000000..f35ec426a4 --- /dev/null +++ b/mindspore/lite/src/ops/log.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/log.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Log::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateLog(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Log, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/log.h b/mindspore/lite/src/ops/log.h index 9a3b2f1599..1bbac2eba5 100644 --- a/mindspore/lite/src/ops/log.h +++ b/mindspore/lite/src/ops/log.h @@ -32,27 +32,9 @@ class Log : public ArithmeticSelf { Log() = default; explicit Log(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Log(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateLog(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Log, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Log() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/logical_and.cc b/mindspore/lite/src/ops/logical_and.cc new file mode 100644 index 0000000000..8cc73dfe50 --- /dev/null +++ b/mindspore/lite/src/ops/logical_and.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/logical_and.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int LogicalAnd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateLogicalAnd(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LogicalAnd, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_and.h b/mindspore/lite/src/ops/logical_and.h index 10e1ef4817..e323b7e9e8 100644 --- a/mindspore/lite/src/ops/logical_and.h +++ b/mindspore/lite/src/ops/logical_and.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class LogicalAnd : public Arithmetic { LogicalAnd() = default; explicit LogicalAnd(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateLogicalAnd(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_LogicalAnd, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + LogicalAnd() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/logical_not.cc b/mindspore/lite/src/ops/logical_not.cc new file mode 100644 index 0000000000..c67869f932 --- /dev/null +++ b/mindspore/lite/src/ops/logical_not.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/logical_not.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int LogicalNot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateLogicalNot(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LogicalNot, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_not.h b/mindspore/lite/src/ops/logical_not.h index 4196ba99c8..69555551b8 100644 --- a/mindspore/lite/src/ops/logical_not.h +++ b/mindspore/lite/src/ops/logical_not.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class LogicalNot : public ArithmeticSelf { LogicalNot() = default; explicit LogicalNot(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit LogicalNot(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateLogicalNot(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_LogicalNot, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + LogicalNot() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/logical_or.cc b/mindspore/lite/src/ops/logical_or.cc new file mode 100644 index 0000000000..2d8f73f040 --- /dev/null +++ b/mindspore/lite/src/ops/logical_or.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/logical_or.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int LogicalOr::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateLogicalOr(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LogicalOr, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_or.h b/mindspore/lite/src/ops/logical_or.h index a9e0045e30..5afc583e48 100644 --- a/mindspore/lite/src/ops/logical_or.h +++ b/mindspore/lite/src/ops/logical_or.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class LogicalOr : public Arithmetic { LogicalOr() = default; explicit LogicalOr(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateLogicalOr(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_LogicalOr, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + LogicalOr() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/lrn.cc b/mindspore/lite/src/ops/lrn.cc index 58ab9115e5..55d7745c9f 100644 --- a/mindspore/lite/src/ops/lrn.cc +++ b/mindspore/lite/src/ops/lrn.cc @@ -36,6 +36,19 @@ float Lrn::GetBeta() const { return this->primitive_->value_as_Lrn()->beta(); } float Lrn::GetBias() const { return this->primitive_->value_as_Lrn()->bias(); } int Lrn::GetSize() const { return this->primitive_->value_as_Lrn()->size(); } +int Lrn::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Lrn(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Lrn return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateLrn(*fbb, attr->alpha(), attr->beta(), attr->bias(), attr->size()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lrn, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/lrn.h b/mindspore/lite/src/ops/lrn.h index 047971045d..d1ccb69cb4 100644 --- a/mindspore/lite/src/ops/lrn.h +++ b/mindspore/lite/src/ops/lrn.h @@ -36,30 +36,9 @@ class Lrn : public PrimitiveC { void SetBias(float bias); void SetSize(int size); #else - explicit Lrn(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Lrn(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateLrn(fbb, attr->alpha(), attr->beta(), attr->bias(), attr->size()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Lrn, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Lrn() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetAlpha() const; float GetBeta() const; diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index bf95c1b8c9..8548bc846e 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -26,7 +26,19 @@ void Lstm::SetBidirection(bool bidirection) { this->primitive_->value.AsLstm()-> #else bool Lstm::GetBidirection() const { return this->primitive_->value_as_Lstm()->bidirection(); } - +int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Lstm(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Lstm return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateLstm(*fbb, attr->bidirection()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Lstm, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif const int kLstmInputNum = 6; diff --git a/mindspore/lite/src/ops/lstm.h b/mindspore/lite/src/ops/lstm.h index a5e99c89b4..f30dbd0aed 100644 --- a/mindspore/lite/src/ops/lstm.h +++ b/mindspore/lite/src/ops/lstm.h @@ -34,30 +34,9 @@ class Lstm : public PrimitiveC { void SetBidirection(bool bidirection); #else - explicit Lstm(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Lstm(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateLstm(fbb, attr->bidirection()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Lstm, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Lstm() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; bool GetBidirection() const; diff --git a/mindspore/lite/src/ops/make_tuple.cc b/mindspore/lite/src/ops/make_tuple.cc index bfa56377c4..78ca0b1084 100644 --- a/mindspore/lite/src/ops/make_tuple.cc +++ b/mindspore/lite/src/ops/make_tuple.cc @@ -48,6 +48,15 @@ int MakeTuple::UnPackAttr(const Primitive &prim, const std::vector & } return RET_OK; } +#else +int MakeTuple::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateMakeTuple(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MakeTuple, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/make_tuple.h b/mindspore/lite/src/ops/make_tuple.h index 71fd062ad7..04c621b587 100644 --- a/mindspore/lite/src/ops/make_tuple.h +++ b/mindspore/lite/src/ops/make_tuple.h @@ -27,29 +27,11 @@ class MakeTuple : public PrimitiveC { MS_DECLARE_PARENT(MakeTuple, PrimitiveC); MakeTuple() = default; explicit MakeTuple(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit MakeTuple(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateMakeTuple(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MakeTuple, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + MakeTuple() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 7f9227e63d..815eaa24a5 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -149,6 +149,20 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector &inp bool MatMul::GetTransposeA() const { return this->primitive_->value_as_MatMul()->transposeA(); } bool MatMul::GetTransposeB() const { return this->primitive_->value_as_MatMul()->transposeB(); } +int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_MatMul(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_MatMul return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateMatMul(*fbb, attr->transposeA(), attr->transposeB()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MatMul, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif int MatMul::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/matmul.h b/mindspore/lite/src/ops/matmul.h index e2a863e15e..edbd61b252 100644 --- a/mindspore/lite/src/ops/matmul.h +++ b/mindspore/lite/src/ops/matmul.h @@ -32,7 +32,7 @@ class MatMul : public PrimitiveC { public: MatMul() = default; explicit MatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetTransposeA(bool transpose_a); void SetTransposeB(bool transpose_b); @@ -43,30 +43,9 @@ class MatMul : public PrimitiveC { #else public: - explicit MatMul(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_MatMul(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateMatMul(fbb, attr->transposeA(), attr->transposeB()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatMul, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + MatMul() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif public: diff --git a/mindspore/lite/src/ops/matrix_diag.cc b/mindspore/lite/src/ops/matrix_diag.cc index 4a0835ed5a..b92094c85c 100644 --- a/mindspore/lite/src/ops/matrix_diag.cc +++ b/mindspore/lite/src/ops/matrix_diag.cc @@ -38,6 +38,20 @@ int MatrixDiag::GetNumRows() const { return this->primitive_->value_as_MatrixDia int MatrixDiag::GetNumCols() const { return this->primitive_->value_as_MatrixDiag()->numCols(); } float MatrixDiag::GetPaddingValue() const { return this->primitive_->value_as_MatrixDiag()->paddingValue(); } +int MatrixDiag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_MatrixDiag(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_MatrixDiag return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateMatrixDiag(*fbb, attr->k(), attr->numRows(), attr->numCols(), attr->paddingValue()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MatrixDiag, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/matrix_diag.h b/mindspore/lite/src/ops/matrix_diag.h index efb56a4e3a..fe36b2558d 100644 --- a/mindspore/lite/src/ops/matrix_diag.h +++ b/mindspore/lite/src/ops/matrix_diag.h @@ -36,31 +36,9 @@ class MatrixDiag : public PrimitiveC { void SetNumCols(int num_cols); void SetPaddingValue(float padding_value); #else - explicit MatrixDiag(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_MatrixDiag(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateMatrixDiag(fbb, attr->k(), attr->numRows(), - attr->numCols(), attr->paddingValue()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatrixDiag, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + MatrixDiag() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetK() const; int GetNumRows() const; diff --git a/mindspore/lite/src/ops/maximum.cc b/mindspore/lite/src/ops/maximum.cc new file mode 100644 index 0000000000..39223ee4b5 --- /dev/null +++ b/mindspore/lite/src/ops/maximum.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/maximum.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateMaximum(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Maximum, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/maximum.h b/mindspore/lite/src/ops/maximum.h index e0c1ab17cb..88fa03aa82 100644 --- a/mindspore/lite/src/ops/maximum.h +++ b/mindspore/lite/src/ops/maximum.h @@ -21,38 +21,20 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic.h" namespace mindspore { namespace lite { class Maximum : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(Activation, Arithmetic); + MS_DECLARE_PARENT(Arithmetic, Arithmetic); Maximum() = default; explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateMaximum(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Maximum, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Maximum() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/mean.cc b/mindspore/lite/src/ops/mean.cc index cd1ae019fe..2e7324796b 100644 --- a/mindspore/lite/src/ops/mean.cc +++ b/mindspore/lite/src/ops/mean.cc @@ -33,6 +33,26 @@ std::vector Mean::GetAxis() const { } bool Mean::GetKeepDims() const { return this->primitive_->value_as_Mean()->keepDims(); } +int Mean::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Mean(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Mean return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateMeanDirect(*fbb, &axis, attr->keepDims()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mean, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif namespace { diff --git a/mindspore/lite/src/ops/mean.h b/mindspore/lite/src/ops/mean.h index 9c9aec4ede..fd9a42e8a3 100644 --- a/mindspore/lite/src/ops/mean.h +++ b/mindspore/lite/src/ops/mean.h @@ -35,35 +35,9 @@ class Mean : public PrimitiveC { void SetAxis(const std::vector &axis); void SetKeepDims(bool keep_dims); #else - explicit Mean(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Mean(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateMeanDirect(fbb, axis.release(), attr->keepDims()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Mean, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Mean() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; diff --git a/mindspore/lite/src/ops/minimum.cc b/mindspore/lite/src/ops/minimum.cc new file mode 100644 index 0000000000..c2c8c8fd56 --- /dev/null +++ b/mindspore/lite/src/ops/minimum.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/minimum.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Minimum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateMinimum(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Minimum, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/minimum.h b/mindspore/lite/src/ops/minimum.h index d0901701b5..6c4a097952 100644 --- a/mindspore/lite/src/ops/minimum.h +++ b/mindspore/lite/src/ops/minimum.h @@ -21,38 +21,20 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic.h" namespace mindspore { namespace lite { class Minimum : public Arithmetic { public: #ifdef PRIMITIVE_WRITEABLE - MS_DECLARE_PARENT(Activation, Arithmetic); + MS_DECLARE_PARENT(Arithmetic, Arithmetic); Minimum() = default; explicit Minimum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateMinimum(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Minimum, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Minimum() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/mul.cc b/mindspore/lite/src/ops/mul.cc index e2b4a865d4..f6d2fab12a 100644 --- a/mindspore/lite/src/ops/mul.cc +++ b/mindspore/lite/src/ops/mul.cc @@ -58,6 +58,20 @@ int Mul::UnPackAttr(const Primitive &prim, const std::vector &inputs int Mul::GetActivationType() const { return this->primitive_->value_as_Mul()->activationType(); } +int Mul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Mul(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Mul return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateMul(*fbb, attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mul, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mul.h b/mindspore/lite/src/ops/mul.h index aaad8775ae..ba582db115 100644 --- a/mindspore/lite/src/ops/mul.h +++ b/mindspore/lite/src/ops/mul.h @@ -32,36 +32,13 @@ class Mul : public Arithmetic { Mul() = default; explicit Mul(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} void SetActivationType(int activation_type); - + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit Mul(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Mul(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateMul(fbb, attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Mul, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Mul() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetActivationType() const; - - int UnPackAttr(const Primitive &prim, const std::vector &inputs); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/nchw2nhwc.cc b/mindspore/lite/src/ops/nchw2nhwc.cc index 170f04020f..0ac5c25639 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.cc +++ b/mindspore/lite/src/ops/nchw2nhwc.cc @@ -19,6 +19,18 @@ namespace mindspore { namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Nchw2Nhwc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateNchw2Nhwc(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Nchw2Nhwc, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif + int Nchw2Nhwc::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/nchw2nhwc.h b/mindspore/lite/src/ops/nchw2nhwc.h index 9eb342c6a0..8f7ddd0ef1 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.h +++ b/mindspore/lite/src/ops/nchw2nhwc.h @@ -32,27 +32,9 @@ class Nchw2Nhwc : public PrimitiveC { Nchw2Nhwc() = default; explicit Nchw2Nhwc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit Nchw2Nhwc(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateNchw2Nhwc(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Nchw2Nhwc, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Nchw2Nhwc() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/nhwc2nchw.cc b/mindspore/lite/src/ops/nhwc2nchw.cc index 9bff02ab2e..a5f73bcfe0 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.cc +++ b/mindspore/lite/src/ops/nhwc2nchw.cc @@ -19,6 +19,19 @@ namespace mindspore { namespace lite { + +#ifdef PRIMITIVE_WRITEABLE +#else +int Nhwc2Nchw::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateNhwc2Nchw(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Nhwc2Nchw, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif + int Nhwc2Nchw::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/nhwc2nchw.h b/mindspore/lite/src/ops/nhwc2nchw.h index 50bce7c668..479769cc19 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.h +++ b/mindspore/lite/src/ops/nhwc2nchw.h @@ -32,27 +32,9 @@ class Nhwc2Nchw : public PrimitiveC { Nhwc2Nchw() = default; explicit Nhwc2Nchw(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit Nhwc2Nchw(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateNhwc2Nchw(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Nhwc2Nchw, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Nhwc2Nchw() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/not_equal.cc b/mindspore/lite/src/ops/not_equal.cc new file mode 100644 index 0000000000..b768d74293 --- /dev/null +++ b/mindspore/lite/src/ops/not_equal.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/not_equal.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateSin(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Sin, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/not_equal.h b/mindspore/lite/src/ops/not_equal.h index 67b6afb047..2c73caaa5f 100644 --- a/mindspore/lite/src/ops/not_equal.h +++ b/mindspore/lite/src/ops/not_equal.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class NotEqual : public Arithmetic { NotEqual() = default; explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateNotEqual(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_NotEqual, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + NotEqual() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index 5a40165bea..41b2040088 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -27,6 +27,19 @@ void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); } +int OneHot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_OneHot(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_OneHot return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateOneHot(*fbb, attr->axis()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_OneHot, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { diff --git a/mindspore/lite/src/ops/one_hot.h b/mindspore/lite/src/ops/one_hot.h index ab32af0d57..deaa9ab1f1 100644 --- a/mindspore/lite/src/ops/one_hot.h +++ b/mindspore/lite/src/ops/one_hot.h @@ -34,30 +34,9 @@ class OneHot : public PrimitiveC { void SetAxis(int axis); #else - explicit OneHot(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_OneHot(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateOneHot(fbb, attr->axis()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_OneHot, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + OneHot() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/p_relu.cc b/mindspore/lite/src/ops/p_relu.cc index b2f1bdadf2..2174e80baa 100644 --- a/mindspore/lite/src/ops/p_relu.cc +++ b/mindspore/lite/src/ops/p_relu.cc @@ -21,14 +21,31 @@ namespace lite { #ifdef PRIMITIVE_WRITEABLE bool PReLU::GetChannelShared() const { return this->primitive_->value.AsPReLU()->channelShared; } -void PReLU::SetChannelShared(bool channel_shared) { - this->primitive_->value.AsPReLU()->channelShared = channel_shared; -} +void PReLU::SetChannelShared(bool channel_shared) { this->primitive_->value.AsPReLU()->channelShared = channel_shared; } #else - bool PReLU::GetChannelShared() const { return this->primitive_->value_as_PReLU()->channelShared(); } +int PReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_PReLU(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_PReLU return nullptr"; + return RET_ERROR; + } + std::vector slope; + if (attr->slope() != nullptr) { + for (int i = 0; i < static_cast(attr->slope()->size()); i++) { + slope.push_back(attr->slope()->data()[i]); + } + } + auto val_offset = schema::CreatePReLUDirect(*fbb, attr->channelShared(), &slope); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_PReLU, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/p_relu.h b/mindspore/lite/src/ops/p_relu.h index 1f2c04a41c..f18f168154 100644 --- a/mindspore/lite/src/ops/p_relu.h +++ b/mindspore/lite/src/ops/p_relu.h @@ -35,35 +35,9 @@ class PReLU : public Activation { void SetChannelShared(bool channel_shared); #else - explicit PReLU(schema::Primitive *primitive) : Activation(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_PReLU(); - MS_ASSERT(attr != nullptr); - - auto slope = std::make_unique>(); - for (int i = 0; i < static_cast(attr->slope()->size()); i++) { - slope->push_back(attr->slope()->data()[i]); - } - - auto val_offset = schema::CreatePReLUDirect(fbb, attr->channelShared(), slope.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PReLU, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + PReLU() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif bool GetChannelShared() const; }; diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 6cc5cb321c..852686972c 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -38,6 +38,25 @@ std::vector Pad::GetPaddings() const { int Pad::GetPaddingMode() const { return this->primitive_->value_as_Pad()->paddingMode(); } float Pad::GetConstantValue() const { return this->primitive_->value_as_Pad()->constantValue(); } +int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Pad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Pad return nullptr"; + return RET_ERROR; + } + std::vector paddings; + if (attr->paddings() != nullptr) { + for (int i = 0; i < static_cast(attr->paddings()->size()); i++) { + paddings.push_back(attr->paddings()->data()[i]); + } + } + auto val_offset = schema::CreatePadDirect(*fbb, &paddings, attr->paddingMode(), attr->constantValue()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Pad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { const size_t kInputRank = 4; diff --git a/mindspore/lite/src/ops/pad.h b/mindspore/lite/src/ops/pad.h index d7f5797d29..695a1c1742 100644 --- a/mindspore/lite/src/ops/pad.h +++ b/mindspore/lite/src/ops/pad.h @@ -36,35 +36,9 @@ class Pad : public PrimitiveC { void SetPaddingMode(int padding_mode); void SetConstantValue(float constant_value); #else - explicit Pad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Pad(); - MS_ASSERT(attr != nullptr); - - auto paddings = std::make_unique>(); - for (int i = 0; i < static_cast(attr->paddings()->size()); i++) { - paddings->push_back(attr->paddings()->data()[i]); - } - - auto val_offset = schema::CreatePadDirect(fbb, paddings.release(), attr->paddingMode(), attr->constantValue()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Pad, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Pad() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetPaddings() const; diff --git a/mindspore/lite/src/ops/permute.cc b/mindspore/lite/src/ops/permute.cc index 74c49b53ba..d51c99ebe8 100644 --- a/mindspore/lite/src/ops/permute.cc +++ b/mindspore/lite/src/ops/permute.cc @@ -31,6 +31,25 @@ std::vector Permute::GetOrder() const { } void Permute::SetOrder(const std::vector &order) {} +int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Permute(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Permute return nullptr"; + return RET_ERROR; + } + std::vector order; + if (attr->order() != nullptr) { + for (int i = 0; i < static_cast(attr->order()->size()); i++) { + order.push_back(attr->order()->data()[i]); + } + } + auto val_offset = schema::CreatePermuteDirect(*fbb, &order); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Permute, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/permute.h b/mindspore/lite/src/ops/permute.h index b7433bd08c..f2c082ab30 100644 --- a/mindspore/lite/src/ops/permute.h +++ b/mindspore/lite/src/ops/permute.h @@ -32,35 +32,9 @@ class Permute : public PrimitiveC { MS_DECLARE_PARENT(Permute, PrimitiveC); explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {} + Permute() = default; - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Permute(); - MS_ASSERT(attr != nullptr); - - auto order = std::make_unique>(); - for (int i = 0; i < static_cast(attr->order()->size()); i++) { - order->push_back(attr->order()->data()[i]); - } - - auto val_offset = schema::CreatePermuteDirect(fbb, order.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Permute, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); - - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetOrder() const; void SetOrder(const std::vector &order); diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index 80a3d728af..2b409622ca 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -136,6 +136,23 @@ int Pooling::GetPadRight() const { return this->primitive_->value_as_Pooling()-> int Pooling::GetRoundMode() const { return this->primitive_->value_as_Pooling()->roundMode(); } int Pooling::GetActivationType() const { return this->primitive_->value_as_Pooling()->activationType(); } +int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Pooling(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Pooling return nullptr"; + return RET_ERROR; + } + auto val_offset = + schema::CreatePooling(*fbb, attr->format(), attr->poolingMode(), attr->global(), attr->windowW(), attr->windowH(), + attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), + attr->padLeft(), attr->padRight(), attr->roundMode()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Pooling, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif int Pooling::PadUp() const { return this->pad_u_; } diff --git a/mindspore/lite/src/ops/pooling.h b/mindspore/lite/src/ops/pooling.h index d4f13fd553..6892d5b95e 100644 --- a/mindspore/lite/src/ops/pooling.h +++ b/mindspore/lite/src/ops/pooling.h @@ -45,34 +45,11 @@ class Pooling : public PrimitiveC { void SetPadRight(int pad_right); void SetRoundMode(int round_mode); void SetActivationType(int activation_type); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit Pooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Pooling(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreatePooling(fbb, attr->format(), attr->poolingMode(), attr->global(), - attr->windowW(), attr->windowH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), - attr->padLeft(), attr->padRight(), attr->roundMode()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Pooling, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Pooling() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; @@ -95,8 +72,6 @@ class Pooling : public PrimitiveC { int PadLeft() const; int PadRight() const; - int UnPackAttr(const Primitive &prim, const std::vector &inputs); - protected: int pad_u_ = 0; int pad_d_ = 0; diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index dcbd587058..654a3cb047 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -69,6 +69,22 @@ int PoolingGrad::GetPadLeft() const { return this->primitive_->value_as_PoolingG int PoolingGrad::GetPadRight() const { return this->primitive_->value_as_PoolingGrad()->padRight(); } int PoolingGrad::GetRoundMode() const { return this->primitive_->value_as_PoolingGrad()->roundMode(); } +int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_PoolingGrad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_PoolingGrad return nullptr"; + return RET_ERROR; + } + auto val_offset = + schema::CreatePoolingGrad(*fbb, attr->format(), attr->poolingMode(), attr->global(), attr->windowW(), + attr->windowH(), attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), + attr->padDown(), attr->padLeft(), attr->padRight(), attr->roundMode()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_PoolingGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/pooling_grad.h b/mindspore/lite/src/ops/pooling_grad.h index ac0e58ff9f..c42c5d72d6 100644 --- a/mindspore/lite/src/ops/pooling_grad.h +++ b/mindspore/lite/src/ops/pooling_grad.h @@ -45,33 +45,9 @@ class PoolingGrad : public PrimitiveC { void SetPadRight(int pad_right); void SetRoundMode(int round_mode); #else - explicit PoolingGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_PoolingGrad(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreatePoolingGrad(fbb, attr->format(), attr->poolingMode(), attr->global(), - attr->windowW(), attr->windowH(), attr->strideW(), attr->strideH(), - attr->padMode(), attr->padUp(), attr->padDown(), - attr->padLeft(), attr->padRight(), attr->roundMode()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PoolingGrad, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + PoolingGrad() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetFormat() const; int GetPoolingMode() const; diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc index 047d95d545..d812ed1d1f 100644 --- a/mindspore/lite/src/ops/power.cc +++ b/mindspore/lite/src/ops/power.cc @@ -32,7 +32,19 @@ void Power::SetShift(float shift) { this->primitive_->value.AsPower()->shift = s float Power::GetPower() const { return this->primitive_->value_as_Power()->power(); } float Power::GetScale() const { return this->primitive_->value_as_Power()->scale(); } float Power::GetShift() const { return this->primitive_->value_as_Power()->shift(); } - +int Power::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Power(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Power return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreatePower(*fbb, attr->power(), attr->scale(), attr->shift()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Power, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Power::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/power.h b/mindspore/lite/src/ops/power.h index d607c8b889..b38dce1bdd 100644 --- a/mindspore/lite/src/ops/power.h +++ b/mindspore/lite/src/ops/power.h @@ -35,30 +35,9 @@ class Power : public PrimitiveC { void SetScale(float scale); void SetShift(float shift); #else - explicit Power(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Power(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreatePower(fbb, attr->power(), attr->scale(), attr->shift()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Power, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Power() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; float GetPower() const; diff --git a/mindspore/lite/src/ops/power_grad.cc b/mindspore/lite/src/ops/power_grad.cc index c1df70ea9f..ba10623dec 100644 --- a/mindspore/lite/src/ops/power_grad.cc +++ b/mindspore/lite/src/ops/power_grad.cc @@ -33,6 +33,21 @@ float PowerGrad::GetPower() const { return this->primitive_->value_as_PowerGrad( float PowerGrad::GetScale() const { return this->primitive_->value_as_PowerGrad()->scale(); } float PowerGrad::GetShift() const { return this->primitive_->value_as_PowerGrad()->shift(); } +int PowerGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto attr = primitive->value_as_PowerGrad(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_PowerGrad return nullptr"; + return RET_ERROR; + } + + auto val_offset = schema::CreatePowerGrad(*fbb, attr->power(), attr->scale(), attr->shift()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_PowerGrad, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/power_grad.h b/mindspore/lite/src/ops/power_grad.h index 8fe000586e..9cb95d696f 100644 --- a/mindspore/lite/src/ops/power_grad.h +++ b/mindspore/lite/src/ops/power_grad.h @@ -35,30 +35,9 @@ class PowerGrad : public PrimitiveC { void SetScale(float scale); void SetShift(float shift); #else - explicit PowerGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_PowerGrad(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreatePowerGrad(fbb, attr->power(), attr->scale(), attr->shift()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PowerGrad, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + PowerGrad() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif float GetPower() const; float GetScale() const; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 13348e1f28..7bc61f531a 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -448,210 +448,210 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT return nullptr; } #else -PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) { +PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primitive) { MS_ASSERT(primitive); auto op_type = primitive->value_type(); switch (op_type) { case schema::PrimitiveType_SoftMax: - return new SoftMax(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Activation: - return new Activation(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Conv2D: - return new Conv2D(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_DeConv2D: - return new DeConv2D(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Reduce: - return new Reduce(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Pooling: - return new Pooling(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ROIPooling: - return new ROIPooling(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_DepthwiseConv2D: - return new DepthwiseConv2D(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_FusedBatchNorm: - return new FusedBatchNorm(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_BatchNorm: - return new BatchNorm(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_FullConnection: - return new FullConnection(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Power: - return new Power(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Pad: - return new Pad(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Range: - return new Range(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Mul: - return new Mul(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Add: - return new Add(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Sub: - return new Sub(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Div: - return new Div(const_cast(primitive)); + return NewPrimitiveC
(primitive); case schema::PrimitiveType_BiasAdd: - return new BiasAdd(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ExpandDims: - return new ExpandDims(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ArgMax: - return new ArgMax(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ArgMin: - return new ArgMin(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Cast: - return new Cast(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Reshape: - return new Reshape(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Scale: - return new Scale(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Eltwise: - return new Eltwise(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Ceil: - return new Ceil(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Concat: - return new Concat(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Fill: - return new Fill(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Nhwc2Nchw: - return new Nhwc2Nchw(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Nchw2Nhwc: - return new Nchw2Nhwc(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Transpose: - return new Transpose(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Slice: - return new Slice(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Squeeze: - return new Squeeze(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Flatten: - return new Flatten(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Mean: - return new Mean(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Stack: - return new Stack(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Crop: - return new Crop(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_SquaredDifference: - return new SquaredDifference(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_AddN: - return new AddN(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Abs: - return new Abs(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Sin: - return new Sin(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Cos: - return new Cos(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Log: - return new Log(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Sqrt: - return new Sqrt(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Rsqrt: - return new Rsqrt(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Square: - return new Square(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Exp: - return new Exp(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Gather: - return new Gather(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_GatherNd: - return new GatherNd(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_LocalResponseNormalization: - return new LocalResponseNormalization(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Maximum: - return new Maximum(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Minimum: - return new Minimum(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_StridedSlice: - return new StridedSlice(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_LeakyReLU: - return new (std::nothrow) LeakyReLU(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_PReLU: - return new (std::nothrow) PReLU(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Round: - return new Round(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Reverse: - return new Reverse(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ReverseSequence: - return new ReverseSequence(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_LogicalAnd: - return new LogicalAnd(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_LogicalOr: - return new LogicalOr(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_LogicalNot: - return new LogicalNot(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_FloorDiv: - return new FloorDiv(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_FloorMod: - return new FloorMod(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Equal: - return new Equal(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_NotEqual: - return new NotEqual(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Less: - return new Less(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_LessEqual: - return new LessEqual(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Greater: - return new Greater(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_GreaterEqual: - return new GreaterEqual(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Floor: - return new Floor(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Split: - return new Split(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_OneHot: - return new OneHot(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_PriorBox: - return new PriorBox(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_SpaceToDepth: - return new SpaceToDepth(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Tile: - return new Tile(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Resize: - return new Resize(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Unstack: - return new Unstack(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Unique: - return new Unique(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_TopK: - return new TopK(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_MatMul: - return new MatMul(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_QuantDTypeCast: - return new QuantDTypeCast(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_EmbeddingLookup: - return new EmbeddingLookup(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Elu: - return new Elu(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_DeDepthwiseConv2D: - return new DeDepthwiseConv2D(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Shape: - return new Shape(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Unsqueeze: - return new Unsqueeze(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_BatchToSpace: - return new BatchToSpace(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_SpaceToBatch: - return new SpaceToBatch(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_SpaceToBatchND: - return new SpaceToBatchND(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_BroadcastTo: - return new BroadcastTo(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_DepthToSpace: - return new DepthToSpace(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Lstm: - return new Lstm(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ZerosLike: - return new ZerosLike(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_MakeTuple: - return new MakeTuple(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_Where: - return new Where(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ScatterND: - return new ScatterND(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_ConstantOfShape: - return new ConstantOfShape(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_L2Norm: - return new L2Norm(const_cast(primitive)); + return NewPrimitiveC(primitive); case schema::PrimitiveType_SparseToDense: - return new SparseToDense(const_cast(primitive)); + return NewPrimitiveC(primitive); default: MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : " << schema::EnumNamePrimitiveType(op_type); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index f7a97f92ff..d13f5c31a2 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -46,14 +46,14 @@ constexpr int kAnfPopulaterTwo = 2; constexpr int kAnfPopulaterThree = 3; class PrimitiveC : public mindspore::Primitive { public: - // Argument primitive is delived into PrimitiveC and will be deleted in ~PrimitiveC(). Caller should not delete - // primitive + // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). + // Caller should not delete primitive. explicit PrimitiveC(schema::PrimitiveT *primitive) : Primitive(""), primitive_(primitive) {} explicit PrimitiveC(const Primitive &prim) : Primitive(prim) {} - // Argument primitive is delived into PrimitiveC and will be deleted in ~PrimitiveC(). Caller should not delete - // primitive + // Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC(). + // Caller should not delete primitive. explicit PrimitiveC(const std::string &name, schema::PrimitiveT *primitive) : Primitive(name), primitive_(primitive) {} @@ -113,7 +113,7 @@ class PrimitiveC : public mindspore::Primitive { static std::shared_ptr UnPackFromPrimitive(const Primitive &prim, const std::vector &inputs); protected: - virtual int UnPackAttr(const Primitive &prim) { return RET_ERROR; } + virtual int UnPackAttr(const Primitive &prim, const std::vector &inputs) { return RET_ERROR; } protected: schema::PrimitiveT *primitive_ = nullptr; @@ -133,15 +133,9 @@ class PrimitiveC { public: PrimitiveC() = default; - // Argument primitive is delived into PrimitiveC and will be deleted in ~PrimitiveC(). Caller should not delete - // primitive - explicit PrimitiveC(schema::Primitive *primitive) : primitive_(primitive) {} + virtual ~PrimitiveC() { free(this->primitive_buf_); } - virtual ~PrimitiveC() { - // delete this->primitive_; - } - - static PrimitiveC *UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive); + static PrimitiveC *UnPackFromSchemaPrimitive(const schema::Primitive *primitive); bool GetInferFlag() const; @@ -152,7 +146,53 @@ class PrimitiveC { int Type() const; protected: - schema::Primitive *primitive_ = nullptr; + template ::value>> + static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) { + auto primc = new T(); + if (primc == nullptr) { + MS_LOG(ERROR) << "new PrimitiveC failed"; + return nullptr; + } + auto ret = primc->UnPackSchemaPrimitive(primitive); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UnPackSchemaPrimitive failed"; + return nullptr; + } + return primc; + } + + int UnPackSchemaPrimitive(const schema::Primitive *primitive) { + flatbuffers::FlatBufferBuilder fbb(1024); + if (UnPackToFlatBuilder(primitive, &fbb) != RET_OK) { + MS_LOG(ERROR) << "UnPackToFlatBuilder failde"; + fbb.Clear(); + return RET_ERROR; + } + auto buf = fbb.GetBufferPointer(); + if (buf == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer return nullptr"; + fbb.Clear(); + return RET_ERROR; + } + primitive_buf_ = reinterpret_cast(malloc(fbb.GetSize())); + if (primitive_buf_ == nullptr) { + MS_LOG(ERROR) << "malloc primitive_buf_ failed"; + fbb.Clear(); + return RET_ERROR; + } + memcpy(primitive_buf_, buf, fbb.GetSize()); + this->primitive_ = flatbuffers::GetRoot(primitive_buf_); + fbb.Clear(); + return RET_OK; + } + + virtual int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + return RET_ERROR; + } + + protected: + const schema::Primitive *primitive_ = nullptr; + char *primitive_buf_ = nullptr; bool infer_flag_ = true; }; #endif diff --git a/mindspore/lite/src/ops/prior_box.cc b/mindspore/lite/src/ops/prior_box.cc index 0d57353ae8..1bd60c151e 100644 --- a/mindspore/lite/src/ops/prior_box.cc +++ b/mindspore/lite/src/ops/prior_box.cc @@ -77,6 +77,43 @@ bool PriorBox::GetClip() const { return this->primitive_->value_as_PriorBox()->c bool PriorBox::GetFlip() const { return this->primitive_->value_as_PriorBox()->flip(); } float PriorBox::GetOffset() const { return this->primitive_->value_as_PriorBox()->offset(); } +int PriorBox::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_PriorBox(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_PriorBox return nullptr"; + return RET_ERROR; + } + std::vector min_sizes; + if (attr->min_sizes() != nullptr) { + for (int i = 0; i < static_cast(attr->min_sizes()->size()); i++) { + min_sizes.push_back(attr->min_sizes()->data()[i]); + } + } + std::vector max_sizes; + if (attr->max_sizes() != nullptr) { + for (int i = 0; i < static_cast(attr->max_sizes()->size()); i++) { + max_sizes.push_back(attr->max_sizes()->data()[i]); + } + } + std::vector aspect_ratios; + if (attr->aspect_ratios() != nullptr) { + for (int i = 0; i < static_cast(attr->aspect_ratios()->size()); i++) { + aspect_ratios.push_back(attr->aspect_ratios()->data()[i]); + } + } + std::vector variances; + if (attr->variances() != nullptr) { + for (int i = 0; i < static_cast(attr->variances()->size()); i++) { + variances.push_back(attr->variances()->data()[i]); + } + } + auto val_offset = schema::CreatePriorBoxDirect(*fbb, &min_sizes, &max_sizes, &aspect_ratios, &variances); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_PriorBox, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { diff --git a/mindspore/lite/src/ops/prior_box.h b/mindspore/lite/src/ops/prior_box.h index 8012b406d5..d6f105a31c 100644 --- a/mindspore/lite/src/ops/prior_box.h +++ b/mindspore/lite/src/ops/prior_box.h @@ -44,48 +44,9 @@ class PriorBox : public PrimitiveC { void SetFlip(bool flip); void SetOffset(float offset); #else - explicit PriorBox(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_PriorBox(); - MS_ASSERT(attr != nullptr); - - auto min_sizes = std::make_unique>(); - for (int i = 0; i < static_cast(attr->min_sizes()->size()); i++) { - min_sizes->push_back(attr->min_sizes()->data()[i]); - } - auto max_sizes = std::make_unique>(); - for (int i = 0; i < static_cast(attr->max_sizes()->size()); i++) { - max_sizes->push_back(attr->max_sizes()->data()[i]); - } - auto aspect_ratios = std::make_unique>(); - for (int i = 0; i < static_cast(attr->aspect_ratios()->size()); i++) { - aspect_ratios->push_back(attr->aspect_ratios()->data()[i]); - } - auto variances = std::make_unique>(); - for (int i = 0; i < static_cast(attr->variances()->size()); i++) { - variances->push_back(attr->variances()->data()[i]); - } - - auto val_offset = schema::CreatePriorBoxDirect(fbb, min_sizes.release(), max_sizes.release(), - aspect_ratios.release(), variances.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PriorBox, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + PriorBox() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetMinSizes() const; diff --git a/mindspore/lite/src/ops/quant.h b/mindspore/lite/src/ops/quant.h index 80e4567d92..8717585aab 100644 --- a/mindspore/lite/src/ops/quant.h +++ b/mindspore/lite/src/ops/quant.h @@ -27,9 +27,9 @@ class Quant : public PrimitiveC { MS_DECLARE_PARENT(Quant, PrimitiveC); Quant() = default; explicit Quant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit Quant(schema::Primitive *primitive) : PrimitiveC(primitive) {} + Quant() = default; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc index 9ce46a4438..a3adb9a5c4 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.cc +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -29,7 +29,19 @@ void QuantDTypeCast::SetDstT(int dst_t) { this->primitive_->value.AsQuantDTypeCa int QuantDTypeCast::GetSrcT() const { return this->primitive_->value_as_QuantDTypeCast()->srcT(); } int QuantDTypeCast::GetDstT() const { return this->primitive_->value_as_QuantDTypeCast()->dstT(); } - +int QuantDTypeCast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_QuantDTypeCast(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_QuantDTypeCast return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateQuantDTypeCast(*fbb, attr->srcT(), attr->dstT()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_QuantDTypeCast, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int QuantDTypeCast::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/quant_dtype_cast.h b/mindspore/lite/src/ops/quant_dtype_cast.h index 7357461cf1..0523272982 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.h +++ b/mindspore/lite/src/ops/quant_dtype_cast.h @@ -34,30 +34,9 @@ class QuantDTypeCast : public PrimitiveC { void SetSrcT(int src_t); void SetDstT(int dst_t); #else - explicit QuantDTypeCast(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_QuantDTypeCast(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateQuantDTypeCast(fbb, attr->srcT(), attr->dstT()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_QuantDTypeCast, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + QuantDTypeCast() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSrcT() const; diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index bad3941de7..75afe4efb2 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -35,7 +35,19 @@ int Range::GetDType() const { return this->primitive_->value_as_Range()->dType() int Range::GetStart() const { return this->primitive_->value_as_Range()->start(); } int Range::GetLimit() const { return this->primitive_->value_as_Range()->limit(); } int Range::GetDelta() const { return this->primitive_->value_as_Range()->delta(); } - +int Range::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Range(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Range return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateRange(*fbb, attr->dType(), attr->start(), attr->limit(), attr->delta()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Range, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Range::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/range.h b/mindspore/lite/src/ops/range.h index 6b1ac596e4..4f0a432462 100644 --- a/mindspore/lite/src/ops/range.h +++ b/mindspore/lite/src/ops/range.h @@ -36,30 +36,9 @@ class Range : public PrimitiveC { void SetLimit(int limit); void SetDelta(int delta); #else - explicit Range(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Range(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateRange(fbb, attr->dType(), attr->start(), attr->limit(), attr->delta()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Range, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Range() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetDType() const; diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc index 5a89c68178..1c95012d95 100644 --- a/mindspore/lite/src/ops/rank.cc +++ b/mindspore/lite/src/ops/rank.cc @@ -18,7 +18,17 @@ namespace mindspore { namespace lite { - +#ifdef PRIMITIVE_WRITEABLE +#else +int Rank::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateRank(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rank, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif int Rank::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/rank.h b/mindspore/lite/src/ops/rank.h index 2e5f7c7895..5251247e1f 100644 --- a/mindspore/lite/src/ops/rank.h +++ b/mindspore/lite/src/ops/rank.h @@ -32,27 +32,9 @@ class Rank : public PrimitiveC { Rank() = default; explicit Rank(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit Rank(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateRank(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Rank, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Rank() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index bc2c0753dc..7ec73d67f1 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -86,7 +86,25 @@ std::vector Reduce::GetAxes() const { } int Reduce::GetKeepDims() const { return this->primitive_->value_as_Reduce()->keepDims(); } int Reduce::GetMode() const { return this->primitive_->value_as_Reduce()->mode(); } - +int Reduce::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Reduce(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Reduce return nullptr"; + return RET_ERROR; + } + std::vector axes; + if (attr->axes() != nullptr) { + for (int i = 0; i < static_cast(attr->axes()->size()); i++) { + axes.push_back(attr->axes()->data()[i]); + } + } + auto val_offset = schema::CreateReduceDirect(*fbb, &axes, attr->keepDims(), attr->mode()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reduce, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { diff --git a/mindspore/lite/src/ops/reduce.h b/mindspore/lite/src/ops/reduce.h index 3bea756cc5..bb8458d09a 100644 --- a/mindspore/lite/src/ops/reduce.h +++ b/mindspore/lite/src/ops/reduce.h @@ -32,40 +32,14 @@ class Reduce : public PrimitiveC { MS_DECLARE_PARENT(Reduce, PrimitiveC); Reduce() = default; explicit Reduce(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetAxes(const std::vector &axes); void SetKeepDims(int keep_dims); void SetMode(int mode); #else - explicit Reduce(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Reduce(); - MS_ASSERT(attr != nullptr); - - auto axes = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axes()->size()); i++) { - axes->push_back(attr->axes()->data()[i]); - } - - auto val_offset = schema::CreateReduceDirect(fbb, axes.release(), attr->keepDims(), attr->mode()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Reduce, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Reduce() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxes() const; diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index eb35e05ed7..8928a89b94 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -81,7 +81,25 @@ std::vector Reshape::GetShape() const { auto fb_vector = this->primitive_->value_as_Reshape()->shape(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Reshape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Reshape(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Reshape return nullptr"; + return RET_ERROR; + } + std::vector shape; + if (attr->shape() != nullptr) { + for (int i = 0; i < static_cast(attr->shape()->size()); i++) { + shape.push_back(attr->shape()->data()[i]); + } + } + auto val_offset = schema::CreateReshapeDirect(*fbb, attr->format(), &shape); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reshape, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector *out_shape) const { diff --git a/mindspore/lite/src/ops/reshape.h b/mindspore/lite/src/ops/reshape.h index 76eccc9421..c81187636b 100644 --- a/mindspore/lite/src/ops/reshape.h +++ b/mindspore/lite/src/ops/reshape.h @@ -32,39 +32,13 @@ class Reshape : public PrimitiveC { MS_DECLARE_PARENT(Reshape, PrimitiveC); Reshape() = default; explicit Reshape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetFormat(int format); void SetShape(const std::vector &shape); #else - explicit Reshape(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Reshape(); - MS_ASSERT(attr != nullptr); - - auto shape = std::make_unique>(); - for (int i = 0; i < static_cast(attr->shape()->size()); i++) { - shape->push_back(attr->shape()->data()[i]); - } - - auto val_offset = schema::CreateReshapeDirect(fbb, attr->format(), shape.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Reshape, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Reshape() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index d9973a3bd8..aa0dd10648 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -43,7 +43,20 @@ int64_t Resize::GetNewHeight() const { return this->primitive_->value_as_Resize( int64_t Resize::GetNewWidth() const { return this->primitive_->value_as_Resize()->newWidth(); } bool Resize::GetAlignCorners() const { return this->primitive_->value_as_Resize()->alignCorners(); } bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value_as_Resize()->preserveAspectRatio(); } - +int Resize::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Resize(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Resize return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateResize(*fbb, attr->format(), attr->method(), attr->newHeight(), attr->newWidth(), + attr->alignCorners(), attr->preserveAspectRatio()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Resize, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { constexpr int kInputRank = 4; diff --git a/mindspore/lite/src/ops/resize.h b/mindspore/lite/src/ops/resize.h index 98a52fbfcd..3e1d71a484 100644 --- a/mindspore/lite/src/ops/resize.h +++ b/mindspore/lite/src/ops/resize.h @@ -38,32 +38,9 @@ class Resize : public PrimitiveC { void SetAlignCorners(bool align_corners); void SetPreserveAspectRatio(bool preserve_aspect_ratio); #else - explicit Resize(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Resize(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateResize(fbb, attr->format(), attr->method(), - attr->newHeight(), attr->newWidth(), - attr->alignCorners(), attr->preserveAspectRatio()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Resize, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Resize() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; diff --git a/mindspore/lite/src/ops/return.h b/mindspore/lite/src/ops/return.h index 44431af919..ec6af73b67 100644 --- a/mindspore/lite/src/ops/return.h +++ b/mindspore/lite/src/ops/return.h @@ -31,9 +31,9 @@ class Return : public PrimitiveC { MS_DECLARE_PARENT(Return, PrimitiveC); Return() = default; explicit Return(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit Return(schema::Primitive *primitive) : PrimitiveC(primitive) {} + Return() = default; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/reverse.cc b/mindspore/lite/src/ops/reverse.cc index 0f4ce81b2b..11bb4388d5 100644 --- a/mindspore/lite/src/ops/reverse.cc +++ b/mindspore/lite/src/ops/reverse.cc @@ -29,7 +29,25 @@ std::vector Reverse::GetAxis() const { auto fb_vector = this->primitive_->value_as_Reverse()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Reverse::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Reverse(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Reverse return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateReverseDirect(*fbb, &axis); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Reverse, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/reverse.h b/mindspore/lite/src/ops/reverse.h index 7b895b394e..0a95a18413 100644 --- a/mindspore/lite/src/ops/reverse.h +++ b/mindspore/lite/src/ops/reverse.h @@ -35,35 +35,9 @@ class Reverse : public PrimitiveC { void SetAxis(const std::vector &axis); #else - explicit Reverse(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Reverse(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateReverseDirect(fbb, axis.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Reverse, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Reverse() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetAxis() const; }; diff --git a/mindspore/lite/src/ops/reverse_sequence.cc b/mindspore/lite/src/ops/reverse_sequence.cc index b620e1d06e..c89477832c 100644 --- a/mindspore/lite/src/ops/reverse_sequence.cc +++ b/mindspore/lite/src/ops/reverse_sequence.cc @@ -41,7 +41,26 @@ std::vector ReverseSequence::GetSeqLengths() const { auto fb_vector = this->primitive_->value_as_ReverseSequence()->seqLengths(); return std::vector(fb_vector->begin(), fb_vector->end()); } +int ReverseSequence::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ReverseSequence(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ReverseSequence return nullptr"; + return RET_ERROR; + } + std::vector seqLengths; + if (attr->seqLengths() != nullptr) { + for (int i = 0; i < static_cast(attr->seqLengths()->size()); i++) { + seqLengths.push_back(attr->seqLengths()->data()[i]); + } + } + auto val_offset = schema::CreateReverseSequenceDirect(*fbb, attr->seqAxis(), attr->batchAxis(), &seqLengths); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ReverseSequence, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int ReverseSequence::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/reverse_sequence.h b/mindspore/lite/src/ops/reverse_sequence.h index b197e546a2..6b0c59d384 100644 --- a/mindspore/lite/src/ops/reverse_sequence.h +++ b/mindspore/lite/src/ops/reverse_sequence.h @@ -36,36 +36,9 @@ class ReverseSequence : public PrimitiveC { void SetBatchAxis(int batch_axis); void SetSeqLengths(const std::vector &seq_lengths); #else - explicit ReverseSequence(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ReverseSequence(); - MS_ASSERT(attr != nullptr); - - auto seqLengths = std::make_unique>(); - for (int i = 0; i < static_cast(attr->seqLengths()->size()); i++) { - seqLengths->push_back(attr->seqLengths()->data()[i]); - } - - auto val_offset = schema::CreateReverseSequenceDirect(fbb, attr->seqAxis(), - attr->batchAxis(), seqLengths.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ReverseSequence, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ReverseSequence() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetSeqAxis() const; diff --git a/mindspore/lite/src/ops/roi_pooling.cc b/mindspore/lite/src/ops/roi_pooling.cc index 1c6f1b919d..4d1270e35a 100644 --- a/mindspore/lite/src/ops/roi_pooling.cc +++ b/mindspore/lite/src/ops/roi_pooling.cc @@ -32,7 +32,21 @@ void ROIPooling::SetScale(float scale) { this->primitive_->value.AsROIPooling()- int ROIPooling::GetPooledH() const { return this->primitive_->value_as_ROIPooling()->pooledH(); } int ROIPooling::GetPooledW() const { return this->primitive_->value_as_ROIPooling()->pooledW(); } float ROIPooling::GetScale() const { return this->primitive_->value_as_ROIPooling()->scale(); } +int ROIPooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_ROIPooling(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_ROIPooling return nullptr"; + return RET_ERROR; + } + + auto val_offset = schema::CreateROIPooling(*fbb, attr->pooledH(), attr->pooledW(), attr->scale()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ROIPooling, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int ROIPooling::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/roi_pooling.h b/mindspore/lite/src/ops/roi_pooling.h index 876850b504..4dabb95f90 100644 --- a/mindspore/lite/src/ops/roi_pooling.h +++ b/mindspore/lite/src/ops/roi_pooling.h @@ -35,30 +35,9 @@ class ROIPooling : public PrimitiveC { void SetPooledW(int pooled_w); void SetScale(float scale); #else - explicit ROIPooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_ROIPooling(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateROIPooling(fbb, attr->pooledH(), attr->pooledW(), attr->scale()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ROIPooling, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ROIPooling() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetPooledH() const; diff --git a/mindspore/lite/src/ops/round.cc b/mindspore/lite/src/ops/round.cc new file mode 100644 index 0000000000..ae3167597c --- /dev/null +++ b/mindspore/lite/src/ops/round.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/round.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Round::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateRound(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Round, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/round.h b/mindspore/lite/src/ops/round.h index b253c0728e..606324243c 100644 --- a/mindspore/lite/src/ops/round.h +++ b/mindspore/lite/src/ops/round.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class Round : public ArithmeticSelf { Round() = default; explicit Round(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Round(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateRound(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Round, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Round() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/rsqrt.cc b/mindspore/lite/src/ops/rsqrt.cc new file mode 100644 index 0000000000..742aed2953 --- /dev/null +++ b/mindspore/lite/src/ops/rsqrt.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/rsqrt.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Rsqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateRsqrt(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rsqrt, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/rsqrt.h b/mindspore/lite/src/ops/rsqrt.h index 77ae116dcc..17c48c4413 100644 --- a/mindspore/lite/src/ops/rsqrt.h +++ b/mindspore/lite/src/ops/rsqrt.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class Rsqrt : public ArithmeticSelf { Rsqrt() = default; explicit Rsqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Rsqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateRsqrt(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Rsqrt, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Rsqrt() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/scale.cc b/mindspore/lite/src/ops/scale.cc index 52265173e4..53ed368783 100644 --- a/mindspore/lite/src/ops/scale.cc +++ b/mindspore/lite/src/ops/scale.cc @@ -26,7 +26,19 @@ void Scale::SetAxis(int axis) { this->primitive_->value.AsScale()->axis = axis; #else int Scale::GetAxis() const { return this->primitive_->value_as_Scale()->axis(); } - +int Scale::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Scale(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Scale return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateScale(*fbb, attr->axis()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Scale, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/scale.h b/mindspore/lite/src/ops/scale.h index 1a22998583..c9f3d653b6 100644 --- a/mindspore/lite/src/ops/scale.h +++ b/mindspore/lite/src/ops/scale.h @@ -34,30 +34,9 @@ class Scale : public PrimitiveC { void SetAxis(int axis); #else - explicit Scale(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Scale(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateScale(fbb, attr->axis()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Scale, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Scale() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetAxis() const; }; diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc index a033192708..6cd425ac75 100644 --- a/mindspore/lite/src/ops/scatter_nd.cc +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -61,5 +61,17 @@ int ScatterND::InferShape(std::vector inputs_, std::vectorset_shape(out_shape); return RET_OK; } +#ifdef PRIMITIVE_WRITEABLE +#else +int ScatterND::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateScatterND(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ScatterND, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/scatter_nd.h b/mindspore/lite/src/ops/scatter_nd.h index a5afb3471e..ad7bc2c887 100644 --- a/mindspore/lite/src/ops/scatter_nd.h +++ b/mindspore/lite/src/ops/scatter_nd.h @@ -32,27 +32,9 @@ class ScatterND : public PrimitiveC { ScatterND() = default; explicit ScatterND(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit ScatterND(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateScatterND(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ScatterND, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + ScatterND() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/shape.cc b/mindspore/lite/src/ops/shape.cc index a6b8098b2c..349fb8cb0d 100644 --- a/mindspore/lite/src/ops/shape.cc +++ b/mindspore/lite/src/ops/shape.cc @@ -51,5 +51,17 @@ int Shape::InferShape(std::vector inputs_, std::vectorFinish(prim_offset); + return RET_OK; +} +#endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/shape.h b/mindspore/lite/src/ops/shape.h index eb6e8218ed..7dc856eca5 100644 --- a/mindspore/lite/src/ops/shape.h +++ b/mindspore/lite/src/ops/shape.h @@ -32,27 +32,9 @@ class Shape : public PrimitiveC { Shape() = default; explicit Shape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit Shape(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateShape(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Shape, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Shape() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/src/ops/sin.cc b/mindspore/lite/src/ops/sin.cc new file mode 100644 index 0000000000..eb363cf531 --- /dev/null +++ b/mindspore/lite/src/ops/sin.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/sin.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Sin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateSin(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Sin, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/sin.h b/mindspore/lite/src/ops/sin.h index db1076a786..82383c166d 100644 --- a/mindspore/lite/src/ops/sin.h +++ b/mindspore/lite/src/ops/sin.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class Sin : public ArithmeticSelf { Sin() = default; explicit Sin(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Sin(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateSin(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Sin, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Sin() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 3f90bcd147..bfdd2ac039 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -50,6 +50,34 @@ std::vector Slice::GetAxes() const { auto fb_vector = this->primitive_->value_as_Slice()->axes(); return std::vector(fb_vector->begin(), fb_vector->end()); } +int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto attr = primitive->value_as_Slice(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Slice return nullptr"; + return RET_ERROR; + } + + std::vector begin; + if (attr->begin() != nullptr) { + for (int i = 0; i < static_cast(attr->begin()->size()); i++) { + begin.push_back(attr->begin()->data()[i]); + } + } + std::vector size; + if (attr->size() != nullptr) { + for (int i = 0; i < static_cast(attr->size()->size()); i++) { + size.push_back(attr->size()->data()[i]); + } + } + + auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &begin, &size); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif std::vector Slice::GetPostProcessBegin() const { return this->begin; } @@ -92,8 +120,7 @@ int Slice::InferShape(std::vector inputs, std::vector
  • (input_shape[i] - begin[i])) { - MS_LOG(ERROR) << "Invalid size input " << size[i] - << " which should be <= " << input_shape[i] - begin[i]; + MS_LOG(ERROR) << "Invalid size input " << size[i] << " which should be <= " << input_shape[i] - begin[i]; return RET_PARAM_INVALID; } diff --git a/mindspore/lite/src/ops/slice.h b/mindspore/lite/src/ops/slice.h index 8c7995ee80..b5fa281e4b 100644 --- a/mindspore/lite/src/ops/slice.h +++ b/mindspore/lite/src/ops/slice.h @@ -36,39 +36,9 @@ class Slice : public PrimitiveC { void SetBegin(const std::vector &begin); void SetSize(const std::vector &size); #else - explicit Slice(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Slice(); - MS_ASSERT(attr != nullptr); - - auto begin = std::make_unique>(); - for (int i = 0; i < static_cast(attr->begin()->size()); i++) { - begin->push_back(attr->begin()->data()[i]); - } - auto size = std::make_unique>(); - for (int i = 0; i < static_cast(attr->size()->size()); i++) { - size->push_back(attr->size()->data()[i]); - } - - auto val_offset = schema::CreateSliceDirect(fbb, attr->format(), begin.release(), size.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Slice, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Slice() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetFormat() const; @@ -81,8 +51,8 @@ class Slice : public PrimitiveC { std::vector GetPostProcessSize() const; protected: - std::vector begin = {0}; - std::vector size = {-1}; + std::vector begin = {0}; + std::vector size = {-1}; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc index 48f0bfc8dc..2640d3edbc 100644 --- a/mindspore/lite/src/ops/softmax.cc +++ b/mindspore/lite/src/ops/softmax.cc @@ -26,7 +26,19 @@ void SoftMax::SetAxis(int axis) { this->primitive_->value.AsSoftMax()->axis = ax #else int SoftMax::GetAxis() const { return this->primitive_->value_as_SoftMax()->axis(); } - +int SoftMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SoftMax(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SoftMax return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateSoftMax(*fbb, attr->axis()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SoftMax, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int SoftMax::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/softmax.h b/mindspore/lite/src/ops/softmax.h index bd39b4de0c..aa7dc5db88 100644 --- a/mindspore/lite/src/ops/softmax.h +++ b/mindspore/lite/src/ops/softmax.h @@ -34,30 +34,9 @@ class SoftMax : public PrimitiveC { void SetAxis(int axis); #else - explicit SoftMax(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_SoftMax(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateSoftMax(fbb, attr->axis()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SoftMax, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SoftMax() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.cc b/mindspore/lite/src/ops/softmax_cross_entropy.cc index c1a8ea0469..8be8ca1d88 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.cc +++ b/mindspore/lite/src/ops/softmax_cross_entropy.cc @@ -31,7 +31,25 @@ std::vector SoftmaxCrossEntropy::GetAxis() const { auto fb_vector = this->primitive_->value_as_SoftmaxCrossEntropy()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int SoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SoftmaxCrossEntropy(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SoftmaxCrossEntropy return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateSoftmaxCrossEntropyDirect(*fbb, &axis); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SoftmaxCrossEntropy, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.h b/mindspore/lite/src/ops/softmax_cross_entropy.h index a054235c61..44449d0bfd 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.h +++ b/mindspore/lite/src/ops/softmax_cross_entropy.h @@ -35,35 +35,9 @@ class SoftmaxCrossEntropy : public PrimitiveC { void SetAxis(const std::vector &axis); #else - explicit SoftmaxCrossEntropy(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_SoftmaxCrossEntropy(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateSoftmaxCrossEntropyDirect(fbb, axis.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SoftmaxCrossEntropy, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SoftmaxCrossEntropy() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetAxis() const; }; diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc index 1942455380..ac2902a307 100644 --- a/mindspore/lite/src/ops/space_to_batch.cc +++ b/mindspore/lite/src/ops/space_to_batch.cc @@ -40,7 +40,31 @@ std::vector SpaceToBatch::GetPaddings() const { auto fb_vector = this->primitive_->value_as_SpaceToBatch()->paddings(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int SpaceToBatch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SpaceToBatch(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SpaceToBatch return nullptr"; + return RET_ERROR; + } + std::vector blockShape; + if (attr->blockShape() != nullptr) { + for (int i = 0; i < static_cast(attr->blockShape()->size()); i++) { + blockShape.push_back(attr->blockShape()->data()[i]); + } + } + std::vector paddings; + if (attr->paddings() != nullptr) { + for (int i = 0; i < static_cast(attr->paddings()->size()); i++) { + paddings.push_back(attr->paddings()->data()[i]); + } + } + auto val_offset = schema::CreateSpaceToBatchDirect(*fbb, &blockShape, &paddings); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SpaceToBatch, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { constexpr int kSpaceToBatchNDOutputNum = 1; diff --git a/mindspore/lite/src/ops/space_to_batch.h b/mindspore/lite/src/ops/space_to_batch.h index f1e0b3811b..3c3888bcb7 100644 --- a/mindspore/lite/src/ops/space_to_batch.h +++ b/mindspore/lite/src/ops/space_to_batch.h @@ -35,39 +35,9 @@ class SpaceToBatch : public PrimitiveC { void SetBlockShape(const std::vector &block_shape); void SetPaddings(const std::vector &paddings); #else - explicit SpaceToBatch(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_SpaceToBatch(); - MS_ASSERT(attr != nullptr); - - auto blockShape = std::make_unique>(); - for (int i = 0; i < static_cast(attr->blockShape()->size()); i++) { - blockShape->push_back(attr->blockShape()->data()[i]); - } - auto paddings = std::make_unique>(); - for (int i = 0; i < static_cast(attr->paddings()->size()); i++) { - paddings->push_back(attr->paddings()->data()[i]); - } - - auto val_offset = schema::CreateSpaceToBatchDirect(fbb, blockShape.release(), paddings.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SpaceToBatch, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SpaceToBatch() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs, std::vector outputs) override; diff --git a/mindspore/lite/src/ops/space_to_batch_nd.cc b/mindspore/lite/src/ops/space_to_batch_nd.cc index b716a6cfe2..23d5bfd18e 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.cc +++ b/mindspore/lite/src/ops/space_to_batch_nd.cc @@ -50,6 +50,32 @@ std::vector SpaceToBatchND::GetPaddings() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +int SpaceToBatchND::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SpaceToBatch(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SpaceToBatch return nullptr"; + return RET_ERROR; + } + std::vector blockShape; + if (attr->blockShape() != nullptr) { + for (int i = 0; i < static_cast(attr->blockShape()->size()); i++) { + blockShape.push_back(attr->blockShape()->data()[i]); + } + } + std::vector paddings; + if (attr->paddings() != nullptr) { + for (int i = 0; i < static_cast(attr->paddings()->size()); i++) { + paddings.push_back(attr->paddings()->data()[i]); + } + } + auto val_offset = schema::CreateSpaceToBatchDirect(*fbb, &blockShape, &paddings); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SpaceToBatch, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + #endif // PRIMITIVE_WRITEABLE int SpaceToBatchND::InferShape(std::vector inputs, diff --git a/mindspore/lite/src/ops/space_to_batch_nd.h b/mindspore/lite/src/ops/space_to_batch_nd.h index 24ae7c4a2a..f308efd509 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.h +++ b/mindspore/lite/src/ops/space_to_batch_nd.h @@ -33,39 +33,9 @@ class SpaceToBatchND : public PrimitiveC { void SetBlockShape(const std::vector &block_shape); void SetPaddings(const std::vector &paddings); #else - explicit SpaceToBatchND(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_SpaceToBatchND(); - MS_ASSERT(attr != nullptr); - - auto blockShape = std::make_unique>(); - for (int i = 0; i < static_cast(attr->blockShape()->size()); i++) { - blockShape->push_back(attr->blockShape()->data()[i]); - } - auto paddings = std::make_unique>(); - for (int i = 0; i < static_cast(attr->paddings()->size()); i++) { - paddings->push_back(attr->paddings()->data()[i]); - } - - auto val_offset = schema::CreateSpaceToBatchNDDirect(fbb, blockShape.release(), paddings.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SpaceToBatchND, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SpaceToBatchND() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetBlockShape() const; std::vector GetPaddings() const; diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc index b287cbb2e1..f98956d089 100644 --- a/mindspore/lite/src/ops/space_to_depth.cc +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -30,7 +30,19 @@ void SpaceToDepth::SetFormat(int format) { this->primitive_->value.AsSpaceToDept int SpaceToDepth::GetBlockSize() const { return this->primitive_->value_as_SpaceToDepth()->blockSize(); } int SpaceToDepth::GetFormat() const { return this->primitive_->value_as_SpaceToDepth()->format(); } - +int SpaceToDepth::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SpaceToDepth(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SpaceToDepth return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateSpaceToDepth(*fbb, attr->blockSize(), attr->format()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SpaceToDepth, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { constexpr int kSpaceToDepthOutputNum = 1; diff --git a/mindspore/lite/src/ops/space_to_depth.h b/mindspore/lite/src/ops/space_to_depth.h index fde2639a9b..8edeb3ea0f 100644 --- a/mindspore/lite/src/ops/space_to_depth.h +++ b/mindspore/lite/src/ops/space_to_depth.h @@ -34,30 +34,9 @@ class SpaceToDepth : public PrimitiveC { void SetBlockSize(int block_size); void SetFormat(int format); #else - explicit SpaceToDepth(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_SpaceToDepth(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateSpaceToDepth(fbb, attr->blockSize(), attr->format()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SpaceToDepth, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SpaceToDepth() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBlockSize() const; diff --git a/mindspore/lite/src/ops/sparse_to_dense.cc b/mindspore/lite/src/ops/sparse_to_dense.cc index f6ed040cbb..c59ce46473 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.cc +++ b/mindspore/lite/src/ops/sparse_to_dense.cc @@ -58,7 +58,37 @@ std::vector SparseToDense::GetDefaultValue() const { return std::vector(fb_vector->begin(), fb_vector->end()); } bool SparseToDense::GetValidateIndices() const { return this->primitive_->value_as_SparseToDense()->validateIndices(); } - +int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_SparseToDense(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_SparseToDense return nullptr"; + return RET_ERROR; + } + std::vector outputShape; + if (attr->outputShape() != nullptr) { + for (int i = 0; i < static_cast(attr->outputShape()->size()); i++) { + outputShape.push_back(attr->outputShape()->data()[i]); + } + } + std::vector sparseValue; + if (attr->sparseValue() != nullptr) { + for (int i = 0; i < static_cast(attr->sparseValue()->size()); i++) { + sparseValue.push_back(attr->sparseValue()->data()[i]); + } + } + std::vector defaultValue; + if (attr->defaultValue() != nullptr) { + for (int i = 0; i < static_cast(attr->defaultValue()->size()); i++) { + defaultValue.push_back(attr->defaultValue()->data()[i]); + } + } + auto val_offset = schema::CreateSparseToDenseDirect(*fbb, &outputShape, &sparseValue, &defaultValue); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SparseToDense, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/sparse_to_dense.h b/mindspore/lite/src/ops/sparse_to_dense.h index 14acfce282..d98a843975 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.h +++ b/mindspore/lite/src/ops/sparse_to_dense.h @@ -37,44 +37,9 @@ class SparseToDense : public PrimitiveC { void SetDefaultValue(const std::vector &default_value); void SetValidateIndices(bool validate_indices); #else - explicit SparseToDense(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_SparseToDense(); - MS_ASSERT(attr != nullptr); - - auto outputShape = std::make_unique>(); - for (int i = 0; i < static_cast(attr->outputShape()->size()); i++) { - outputShape->push_back(attr->outputShape()->data()[i]); - } - auto sparseValue = std::make_unique>(); - for (int i = 0; i < static_cast(attr->sparseValue()->size()); i++) { - sparseValue->push_back(attr->sparseValue()->data()[i]); - } - auto defaultValue = std::make_unique>(); - for (int i = 0; i < static_cast(attr->defaultValue()->size()); i++) { - defaultValue->push_back(attr->defaultValue()->data()[i]); - } - - auto val_offset = schema::CreateSparseToDenseDirect(fbb, outputShape.release(), - sparseValue.release(), defaultValue.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SparseToDense, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SparseToDense() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif std::vector GetOutputShape() const; std::vector GetSparseValue() const; diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index 4c5634c51b..a7bde44996 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -38,6 +38,25 @@ std::vector Split::GetSizeSplits() const { } int Split::GetSplitDim() const { return this->primitive_->value_as_Split()->splitDim(); } +int Split::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Split(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Split return nullptr"; + return RET_ERROR; + } + std::vector sizeSplits; + if (attr->sizeSplits() != nullptr) { + for (int i = 0; i < static_cast(attr->sizeSplits()->size()); i++) { + sizeSplits.push_back(attr->sizeSplits()->data()[i]); + } + } + auto val_offset = schema::CreateSplitDirect(*fbb, attr->numberSplit(), &sizeSplits, attr->splitDim()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Split, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { diff --git a/mindspore/lite/src/ops/split.h b/mindspore/lite/src/ops/split.h index 4f94948b07..86c9fe3594 100644 --- a/mindspore/lite/src/ops/split.h +++ b/mindspore/lite/src/ops/split.h @@ -36,35 +36,9 @@ class Split : public PrimitiveC { void SetSizeSplits(const std::vector &size_splits); void SetSplitDim(int split_dim); #else - explicit Split(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Split(); - MS_ASSERT(attr != nullptr); - - auto sizeSplits = std::make_unique>(); - for (int i = 0; i < static_cast(attr->sizeSplits()->size()); i++) { - sizeSplits->push_back(attr->sizeSplits()->data()[i]); - } - - auto val_offset = schema::CreateSplitDirect(fbb, attr->numberSplit(), sizeSplits.release(), attr->splitDim()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Split, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Split() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNumberSplit() const; diff --git a/mindspore/lite/src/ops/sqrt.cc b/mindspore/lite/src/ops/sqrt.cc new file mode 100644 index 0000000000..6035b5db8c --- /dev/null +++ b/mindspore/lite/src/ops/sqrt.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/sqrt.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Sqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateSqrt(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Sqrt, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/sqrt.h b/mindspore/lite/src/ops/sqrt.h index 0121de2a49..68b82eee60 100644 --- a/mindspore/lite/src/ops/sqrt.h +++ b/mindspore/lite/src/ops/sqrt.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class Sqrt : public ArithmeticSelf { Sqrt() = default; explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Sqrt(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateSqrt(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Sqrt, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Sqrt() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/square.cc b/mindspore/lite/src/ops/square.cc new file mode 100644 index 0000000000..89f5ba8dbf --- /dev/null +++ b/mindspore/lite/src/ops/square.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/square.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int Square::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateSquare(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Square, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/square.h b/mindspore/lite/src/ops/square.h index 8762aff94e..5ab29590ff 100644 --- a/mindspore/lite/src/ops/square.h +++ b/mindspore/lite/src/ops/square.h @@ -20,7 +20,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic_self.h" namespace mindspore { namespace lite { @@ -31,27 +31,9 @@ class Square : public ArithmeticSelf { Square() = default; explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {} #else - explicit Square(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateSquare(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Square, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Square() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/squared_difference.cc b/mindspore/lite/src/ops/squared_difference.cc new file mode 100644 index 0000000000..b602b29600 --- /dev/null +++ b/mindspore/lite/src/ops/squared_difference.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/squared_difference.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +#else +int SquaredDifference::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateSquaredDifference(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_SquaredDifference, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/squared_difference.h b/mindspore/lite/src/ops/squared_difference.h index 4893289941..7b7a6a412f 100644 --- a/mindspore/lite/src/ops/squared_difference.h +++ b/mindspore/lite/src/ops/squared_difference.h @@ -21,7 +21,7 @@ #include #include #include "ir/dtype/type_id.h" -#include "src/ops/primitive_c.h" +#include "src/ops/arithmetic.h" namespace mindspore { namespace lite { @@ -32,27 +32,9 @@ class SquaredDifference : public Arithmetic { SquaredDifference() = default; explicit SquaredDifference(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} #else - explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateSquaredDifference(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_SquaredDifference, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + SquaredDifference() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index 1f21d39ceb..d9dbd6c734 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -29,7 +29,25 @@ std::vector Squeeze::GetAxis() const { auto fb_vector = this->primitive_->value_as_Squeeze()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Squeeze::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Squeeze(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Squeeze return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateSqueezeDirect(*fbb, &axis); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Squeeze, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { diff --git a/mindspore/lite/src/ops/squeeze.h b/mindspore/lite/src/ops/squeeze.h index 7da4a13044..aec38d7d68 100644 --- a/mindspore/lite/src/ops/squeeze.h +++ b/mindspore/lite/src/ops/squeeze.h @@ -35,35 +35,9 @@ class Squeeze : public PrimitiveC { void SetAxis(const std::vector &axis); #else - explicit Squeeze(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Squeeze(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateSqueezeDirect(fbb, axis.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Squeeze, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Squeeze() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index ef316bbd24..f5f0c56df6 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -35,7 +35,25 @@ std::vector Stack::GetIsScale() const { auto fb_vector = this->primitive_->value_as_Stack()->isScale(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Stack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Stack(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Stack return nullptr"; + return RET_ERROR; + } + std::vector isScale; + if (attr->isScale() != nullptr) { + for (int i = 0; i < static_cast(attr->isScale()->size()); i++) { + isScale.push_back(attr->isScale()->data()[i]); + } + } + auto val_offset = schema::CreateStackDirect(*fbb, attr->axis(), attr->n(), &isScale); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Stack, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { diff --git a/mindspore/lite/src/ops/stack.h b/mindspore/lite/src/ops/stack.h index ffd39877ee..b1e480349d 100644 --- a/mindspore/lite/src/ops/stack.h +++ b/mindspore/lite/src/ops/stack.h @@ -36,35 +36,9 @@ class Stack : public PrimitiveC { void SetN(int n); void SetIsScale(const std::vector &is_scale); #else - explicit Stack(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Stack(); - MS_ASSERT(attr != nullptr); - - auto isScale = std::make_unique>(); - for (int i = 0; i < static_cast(attr->isScale()->size()); i++) { - isScale->push_back(attr->isScale()->data()[i]); - } - - auto val_offset = schema::CreateStackDirect(fbb, attr->axis(), attr->n(), isScale.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Stack, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Stack() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetAxis() const; diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 57d8511e8b..892229b866 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -72,7 +72,45 @@ std::vector StridedSlice::GetIsScale() const { auto fb_vector = this->primitive_->value_as_StridedSlice()->isScale(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int StridedSlice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_StridedSlice(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_StridedSlice return nullptr"; + return RET_ERROR; + } + std::vector begin; + if (attr->begin() != nullptr) { + for (int i = 0; i < static_cast(attr->begin()->size()); i++) { + begin.push_back(attr->begin()->data()[i]); + } + } + std::vector end; + if (attr->end() != nullptr) { + for (int i = 0; i < static_cast(attr->end()->size()); i++) { + end.push_back(attr->end()->data()[i]); + } + } + std::vector stride; + if (attr->stride() != nullptr) { + for (int i = 0; i < static_cast(attr->stride()->size()); i++) { + stride.push_back(attr->stride()->data()[i]); + } + } + std::vector isScale; + if (attr->isScale() != nullptr) { + for (int i = 0; i < static_cast(attr->isScale()->size()); i++) { + isScale.push_back(attr->isScale()->data()[i]); + } + } + auto val_offset = + schema::CreateStridedSliceDirect(*fbb, attr->beginMask(), attr->endMask(), attr->ellipsisMask(), + attr->newAxisMask(), attr->shrinkAxisMask(), &begin, &end, &stride, &isScale); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_StridedSlice, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif namespace { constexpr int kStridedSliceOutputNum = 1; diff --git a/mindspore/lite/src/ops/strided_slice.h b/mindspore/lite/src/ops/strided_slice.h index 49f634fc81..13d8d9151a 100644 --- a/mindspore/lite/src/ops/strided_slice.h +++ b/mindspore/lite/src/ops/strided_slice.h @@ -42,49 +42,9 @@ class StridedSlice : public PrimitiveC { void SetStride(const std::vector &stride); void SetIsScale(const std::vector &is_scale); #else - explicit StridedSlice(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_StridedSlice(); - MS_ASSERT(attr != nullptr); - - auto begin = std::make_unique>(); - for (int i = 0; i < static_cast(attr->begin()->size()); i++) { - begin->push_back(attr->begin()->data()[i]); - } - auto end = std::make_unique>(); - for (int i = 0; i < static_cast(attr->end()->size()); i++) { - end->push_back(attr->end()->data()[i]); - } - auto stride = std::make_unique>(); - for (int i = 0; i < static_cast(attr->stride()->size()); i++) { - stride->push_back(attr->stride()->data()[i]); - } - auto isScale = std::make_unique>(); - for (int i = 0; i < static_cast(attr->isScale()->size()); i++) { - isScale->push_back(attr->isScale()->data()[i]); - } - - auto val_offset = schema::CreateStridedSliceDirect(fbb, attr->beginMask(), attr->endMask(), attr->ellipsisMask(), - attr->newAxisMask(), attr->shrinkAxisMask(), begin.release(), - end.release(), stride.release(), isScale.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_StridedSlice, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + StridedSlice() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetBeginMask() const; diff --git a/mindspore/lite/src/ops/sub.cc b/mindspore/lite/src/ops/sub.cc index 5024832657..bee2131df8 100644 --- a/mindspore/lite/src/ops/sub.cc +++ b/mindspore/lite/src/ops/sub.cc @@ -28,7 +28,19 @@ void Sub::SetActivationType(int activation_type) { #else int Sub::GetActivationType() const { return this->primitive_->value_as_Sub()->activationType(); } - +int Sub::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Sub(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Sub return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateSub(*fbb, attr->activationType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Sub, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/sub.h b/mindspore/lite/src/ops/sub.h index 8faf0331c9..6b4058c368 100644 --- a/mindspore/lite/src/ops/sub.h +++ b/mindspore/lite/src/ops/sub.h @@ -34,30 +34,9 @@ class Sub : public Arithmetic { void SetActivationType(int activation_type); #else - explicit Sub(schema::Primitive *primitive) : Arithmetic(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Sub(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateSub(fbb, attr->activationType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Sub, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Sub() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int GetActivationType() const; }; diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 4b4b39764b..cf7058ceaf 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -39,7 +39,31 @@ std::vector Tile::GetDims() const { auto fb_vector = this->primitive_->value_as_Tile()->dims(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Tile::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Tile(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Tile return nullptr"; + return RET_ERROR; + } + std::vector multiples; + if (attr->multiples() != nullptr) { + for (int i = 0; i < static_cast(attr->multiples()->size()); i++) { + multiples.push_back(attr->multiples()->data()[i]); + } + } + std::vector dims; + if (attr->dims() != nullptr) { + for (int i = 0; i < static_cast(attr->dims()->size()); i++) { + dims.push_back(attr->dims()->data()[i]); + } + } + auto val_offset = schema::CreateTileDirect(*fbb, &multiples, &dims); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Tile, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Tile::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/tile.h b/mindspore/lite/src/ops/tile.h index 960586e155..f46685a605 100644 --- a/mindspore/lite/src/ops/tile.h +++ b/mindspore/lite/src/ops/tile.h @@ -36,39 +36,9 @@ class Tile : public PrimitiveC { void SetDims(const std::vector &dims); #else - explicit Tile(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Tile(); - MS_ASSERT(attr != nullptr); - - auto multiples = std::make_unique>(); - for (int i = 0; i < static_cast(attr->multiples()->size()); i++) { - multiples->push_back(attr->multiples()->data()[i]); - } - auto dims = std::make_unique>(); - for (int i = 0; i < static_cast(attr->dims()->size()); i++) { - dims->push_back(attr->dims()->data()[i]); - } - - auto val_offset = schema::CreateTileDirect(fbb, multiples.release(), dims.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Tile, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Tile() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetMultiples() const; diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index b1533da95c..fb3a8a47e9 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -29,7 +29,19 @@ void TopK::SetSorted(bool sorted) { this->primitive_->value.AsTopK()->sorted = s int TopK::GetK() const { return this->primitive_->value_as_TopK()->k(); } bool TopK::GetSorted() const { return this->primitive_->value_as_TopK()->sorted(); } - +int TopK::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_TopK(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_TopK return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateTopK(*fbb, attr->k(), attr->sorted()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TopK, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int TopK::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/topk.h b/mindspore/lite/src/ops/topk.h index a7510d59dd..082f83fd7e 100644 --- a/mindspore/lite/src/ops/topk.h +++ b/mindspore/lite/src/ops/topk.h @@ -34,30 +34,9 @@ class TopK : public PrimitiveC { void SetK(int k); void SetSorted(bool sorted); #else - explicit TopK(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_TopK(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateTopK(fbb, attr->k(), attr->sorted()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_TopK, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + TopK() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetK() const; diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index 406aeceda6..11c057a6e1 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -77,6 +77,26 @@ std::vector Transpose::GetPerm() const { } bool Transpose::GetConjugate() const { return this->primitive_->value_as_Transpose()->conjugate(); } +int Transpose::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Transpose(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Transpose return nullptr"; + return RET_ERROR; + } + std::vector perm; + if (attr->perm() != nullptr) { + for (int i = 0; i < static_cast(attr->perm()->size()); i++) { + perm.push_back(attr->perm()->data()[i]); + } + } + + auto val_offset = schema::CreateTransposeDirect(*fbb, &perm, attr->conjugate()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Transpose, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Transpose::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/transpose.h b/mindspore/lite/src/ops/transpose.h index 013fcfb5df..b12507993d 100644 --- a/mindspore/lite/src/ops/transpose.h +++ b/mindspore/lite/src/ops/transpose.h @@ -32,39 +32,13 @@ class Transpose : public PrimitiveC { MS_DECLARE_PARENT(Transpose, PrimitiveC); Transpose() = default; explicit Transpose(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; void SetPerm(const std::vector &perm); void SetConjugate(bool conjugate); #else - explicit Transpose(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Transpose(); - MS_ASSERT(attr != nullptr); - - auto perm = std::make_unique>(); - for (int i = 0; i < static_cast(attr->perm()->size()); i++) { - perm->push_back(attr->perm()->data()[i]); - } - - auto val_offset = schema::CreateTransposeDirect(fbb, perm.release(), attr->conjugate()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Transpose, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Transpose() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetPerm() const; diff --git a/mindspore/lite/src/ops/tuple_get_item.cc b/mindspore/lite/src/ops/tuple_get_item.cc index 23a67e88f4..ac35bb6100 100644 --- a/mindspore/lite/src/ops/tuple_get_item.cc +++ b/mindspore/lite/src/ops/tuple_get_item.cc @@ -48,6 +48,15 @@ int TupleGetItem::UnPackAttr(const Primitive &prim, const std::vectorFinish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/tuple_get_item.h b/mindspore/lite/src/ops/tuple_get_item.h index 925c5f6b3e..e816c122b6 100644 --- a/mindspore/lite/src/ops/tuple_get_item.h +++ b/mindspore/lite/src/ops/tuple_get_item.h @@ -28,29 +28,10 @@ class TupleGetItem : public PrimitiveC { MS_DECLARE_PARENT(TupleGetItem, PrimitiveC); TupleGetItem() = default; explicit TupleGetItem(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} - int UnPackAttr(const Primitive &prim, const std::vector &inputs); + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else - explicit TupleGetItem(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateTupleGetItem(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_TupleGetItem, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); - - delete[] buf_bak; - fbb.Clear(); - return prim; - } + TupleGetItem() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif }; } // namespace lite diff --git a/mindspore/lite/src/ops/unique.cc b/mindspore/lite/src/ops/unique.cc index 7199a30aad..8f0a626bcc 100644 --- a/mindspore/lite/src/ops/unique.cc +++ b/mindspore/lite/src/ops/unique.cc @@ -26,7 +26,19 @@ void Unique::SetOutType(int out_type) { this->primitive_->value.AsUnique()->outT #else int Unique::GetOutType() const { return this->primitive_->value_as_Unique()->outType(); } - +int Unique::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Unique return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateUnique(*fbb, attr->outType()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Unique, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Unique::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/unique.h b/mindspore/lite/src/ops/unique.h index 46c75b5005..4904a1e813 100644 --- a/mindspore/lite/src/ops/unique.h +++ b/mindspore/lite/src/ops/unique.h @@ -34,30 +34,9 @@ class Unique : public PrimitiveC { void SetOutType(int out_type); #else - explicit Unique(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Unique(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateUnique(fbb, attr->outType()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Unique, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Unique() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; - delete[] buf_bak; - fbb.Clear(); - return prim; - } #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetOutType() const; diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc index bc1082b781..5d55cd19f2 100644 --- a/mindspore/lite/src/ops/unsqueeze.cc +++ b/mindspore/lite/src/ops/unsqueeze.cc @@ -32,7 +32,25 @@ std::vector Unsqueeze::GetAxis() const { auto fb_vector = this->primitive_->value_as_Unsqueeze()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Unsqueeze::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Unsqueeze(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Unsqueeze return nullptr"; + return RET_ERROR; + } + std::vector axis; + if (attr->axis() != nullptr) { + for (int i = 0; i < static_cast(attr->axis()->size()); i++) { + axis.push_back(attr->axis()->data()[i]); + } + } + auto val_offset = schema::CreateUnsqueezeDirect(*fbb, &axis); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Unsqueeze, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Unsqueeze::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/unsqueeze.h b/mindspore/lite/src/ops/unsqueeze.h index 5092496e42..36bc2b261c 100644 --- a/mindspore/lite/src/ops/unsqueeze.h +++ b/mindspore/lite/src/ops/unsqueeze.h @@ -35,35 +35,9 @@ class Unsqueeze : public PrimitiveC { void SetAxis(const std::vector &axis); #else - explicit Unsqueeze(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Unsqueeze(); - MS_ASSERT(attr != nullptr); - - auto axis = std::make_unique>(); - for (int i = 0; i < static_cast(attr->axis()->size()); i++) { - axis->push_back(attr->axis()->data()[i]); - } - - auto val_offset = schema::CreateUnsqueezeDirect(fbb, axis.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Unsqueeze, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Unsqueeze() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetAxis() const; diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc index c5bee2f4e5..24da3faab6 100644 --- a/mindspore/lite/src/ops/unstack.cc +++ b/mindspore/lite/src/ops/unstack.cc @@ -29,7 +29,19 @@ void Unstack::SetAxis(int axis) { this->primitive_->value.AsUnstack()->axis = ax int Unstack::GetNum() const { return this->primitive_->value_as_Unstack()->num(); } int Unstack::GetAxis() const { return this->primitive_->value_as_Unstack()->axis(); } - +int Unstack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Unstack(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Unstack return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateUnstack(*fbb, attr->num(), attr->axis()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Unstack, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Unstack::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/unstack.h b/mindspore/lite/src/ops/unstack.h index b4590ae0e8..b0a24b672e 100644 --- a/mindspore/lite/src/ops/unstack.h +++ b/mindspore/lite/src/ops/unstack.h @@ -34,30 +34,9 @@ class Unstack : public PrimitiveC { void SetNum(int num); void SetAxis(int axis); #else - explicit Unstack(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Unstack(); - MS_ASSERT(attr != nullptr); - - auto val_offset = schema::CreateUnstack(fbb, attr->num(), attr->axis()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Unstack, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Unstack() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; int GetNum() const; diff --git a/mindspore/lite/src/ops/upsample.cc b/mindspore/lite/src/ops/upsample.cc index 5196231eda..10c9af70d8 100644 --- a/mindspore/lite/src/ops/upsample.cc +++ b/mindspore/lite/src/ops/upsample.cc @@ -33,7 +33,25 @@ std::vector Upsample::GetScales() const { auto fb_vector = this->primitive_->value_as_Upsample()->scales(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Upsample::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Upsample(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Upsample return nullptr"; + return RET_ERROR; + } + std::vector scales; + if (attr->scales() != nullptr) { + for (int i = 0; i < static_cast(attr->scales()->size()); i++) { + scales.push_back(attr->scales()->data()[i]); + } + } + auto val_offset = schema::CreateUpsampleDirect(*fbb, attr->mode()->c_str(), &scales); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Upsample, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/upsample.h b/mindspore/lite/src/ops/upsample.h index bbf67e83fa..3d9ace6abc 100644 --- a/mindspore/lite/src/ops/upsample.h +++ b/mindspore/lite/src/ops/upsample.h @@ -36,35 +36,9 @@ class Upsample : public PrimitiveC { void SetMode(std::string mode); void SetScales(const std::vector &scales); #else - explicit Upsample(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Upsample(); - MS_ASSERT(attr != nullptr); - - auto scales = std::make_unique>(); - for (int i = 0; i < static_cast(attr->scales()->size()); i++) { - scales->push_back(attr->scales()->data()[i]); - } - - auto val_offset = schema::CreateUpsampleDirect(fbb, attr->mode()->c_str(), scales.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Upsample, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Upsample() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; - delete[] buf_bak; - fbb.Clear(); - return prim; - } #endif std::string GetMode() const; std::vector GetScales() const; diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc index 9ff6df4286..39ca552670 100644 --- a/mindspore/lite/src/ops/where.cc +++ b/mindspore/lite/src/ops/where.cc @@ -31,7 +31,25 @@ std::vector Where::GetCondition() const { auto fb_vector = this->primitive_->value_as_Where()->condition(); return std::vector(fb_vector->begin(), fb_vector->end()); } - +int Where::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Where(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Where return nullptr"; + return RET_ERROR; + } + std::vector condition; + if (attr->condition() != nullptr) { + for (int i = 0; i < static_cast(attr->condition()->size()); i++) { + condition.push_back(attr->condition()->data()[i]); + } + } + auto val_offset = schema::CreateWhereDirect(*fbb, &condition); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Where, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} #endif int Where::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/where.h b/mindspore/lite/src/ops/where.h index 7df2a83fac..9597c813e2 100644 --- a/mindspore/lite/src/ops/where.h +++ b/mindspore/lite/src/ops/where.h @@ -35,35 +35,9 @@ class Where : public PrimitiveC { void SetCondition(const std::vector &condition); #else - explicit Where(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto attr = primitive->value_as_Where(); - MS_ASSERT(attr != nullptr); - - auto condition = std::make_unique>(); - for (int i = 0; i < static_cast(attr->condition()->size()); i++) { - condition->push_back(attr->condition()->data()[i]); - } - - auto val_offset = schema::CreateWhereDirect(fbb, condition.release()); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Where, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); + Where() = default; - delete[] buf_bak; - fbb.Clear(); - return prim; - } + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; std::vector GetCondition() const; diff --git a/mindspore/lite/src/ops/zeros_like.cc b/mindspore/lite/src/ops/zeros_like.cc index 23e674617d..f4562e38fc 100644 --- a/mindspore/lite/src/ops/zeros_like.cc +++ b/mindspore/lite/src/ops/zeros_like.cc @@ -18,6 +18,20 @@ namespace mindspore { namespace lite { + +#ifdef PRIMITIVE_WRITEABLE +#else +int ZerosLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + + auto val_offset = schema::CreateZerosLike(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ZerosLike, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +#endif + int ZerosLike::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); @@ -37,5 +51,6 @@ int ZerosLike::InferShape(std::vector inputs_, std::vect output->set_shape(input->shape()); return RET_OK; } + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/zeros_like.h b/mindspore/lite/src/ops/zeros_like.h index 3a9d038923..08a0325d11 100644 --- a/mindspore/lite/src/ops/zeros_like.h +++ b/mindspore/lite/src/ops/zeros_like.h @@ -32,27 +32,8 @@ class ZerosLike : public PrimitiveC { ZerosLike() = default; explicit ZerosLike(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} #else - explicit ZerosLike(schema::Primitive *primitive) : PrimitiveC(primitive) {} - - schema::Primitive *Init(schema::Primitive *primitive) { - flatbuffers::FlatBufferBuilder fbb(1024); - - auto val_offset = schema::CreateZerosLike(fbb); - auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ZerosLike, val_offset.o); - fbb.Finish(prim_offset); - - auto buf = fbb.GetBufferPointer(); - MS_ASSERT(buf != nullptr); - auto buf_bak = new char[fbb.GetSize()]; - memcpy(buf_bak, buf, fbb.GetSize()); - - auto root = flatbuffers::GetRoot(buf_bak); - auto prim = const_cast(root); - - delete[] buf_bak; - fbb.Clear(); - return prim; - } + ZerosLike() = default; + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc index 8f968c443f..d9c2d011d6 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc @@ -459,8 +459,6 @@ TEST_F(TestMatMulFp32, batch) { -17.63555145263672, -8.490625381469727, 5.317771911621094, -14.561882019042969, -7.251564025878906, -2.508212089538574, 5.86458683013916, -3.466249465942383, 8.869029998779297, 25.034008026123047}; - - float *output = reinterpret_cast(outputs_[0]->Data()); CompareOutputData(reinterpret_cast(outputs_[0]->Data()), correct, total_size, 0.0001); delete mm; for (auto t : inputs_) delete t; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc index 9c6d3e9359..5af17aaa93 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc @@ -65,7 +65,7 @@ TEST_F(TestTransposeFp32, TransposeFp32_axes4) { } auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3); - MS_ASSERT(ret == 0); + ASSERT_EQ(ret, 0); delete param; CompareOutputData(out, correct, 24, 0.000001); } @@ -105,7 +105,7 @@ TEST_F(TestTransposeFp32, TransposeFp32_axes3) { } auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3); - MS_ASSERT(ret == 0); + ASSERT_EQ(ret, 0); delete param; CompareOutputData(out, correct, 24, 0.000001); } @@ -146,7 +146,7 @@ TEST_F(TestTransposeFp32, TransposeFp32_axes2) { } auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 6); - MS_ASSERT(ret == 0); + ASSERT_EQ(ret, 0); delete param; CompareOutputData(out, correct, 24, 0.000001); } diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index 333a20436a..1d2f13fda1 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -587,79 +587,5 @@ std::string GetModelName(const std::string &modelFile) { return modelName; } - -OpGraphT *OpGraphT::Build(const schema::MetaGraphT *subGraphDef) { - if (subGraphDef == nullptr) { - MS_LOG(ERROR) << "subGraphDef is nullptr"; - return nullptr; - } - auto graph = std::unique_ptr(new OpGraphT()); - if (graph == nullptr) { - MS_LOG(ERROR) << "malloc opgraph failed"; - return nullptr; - } - - auto &opDefs = subGraphDef->nodes; - - for (auto &opDef : opDefs) { - auto ret = graph->AddEdge(opDef.get(), &opDefs); - if (ret != RET_OK) { - MS_LOG(ERROR) << opDef->name.c_str() << " add edge failed. ret: " << ret; - return nullptr; - } - } - - return graph.release(); -} - -int OpGraphT::AddEdge(const schema::CNodeT *srcNodeDef, const std::vector> *nodeDefs) { - MS_ASSERT(srcNodeDef != nullptr); - MS_ASSERT(nodeDefs != nullptr); - NODE_ID srcId = std::string(srcNodeDef->name); - // for single op condition - AddNode(srcId); - for (auto index : srcNodeDef->outputIndex) { - for (auto &dstNodeDef : *nodeDefs) { - bool find = false; - auto inputIndex = dstNodeDef->inputIndex; - if (std::any_of(inputIndex.begin(), inputIndex.end(), [&index](size_t i) { return i == index; })) { - find = true; - } - - if (!find) { - continue; - } - NODE_ID dstId = std::string(dstNodeDef->name.c_str()); - auto ret = AddEdge(srcId, dstId); - if (ret != RET_OK) { - return ret; - } - } - } - return RET_OK; -} - -int OpGraphT::AddEdge(NODE_ID srcId, NODE_ID dstId) { - auto srcNode = AddNode(srcId); - if (srcNode == nullptr) { - MS_LOG(ERROR) << "add srcNode failed"; - return RET_ERROR; - } - srcNode->AddOutEdge(dstId); - auto dstNode = AddNode(dstId); - if (dstNode == nullptr) { - MS_LOG(ERROR) << "add dstNode failed"; - return RET_ERROR; - } - dstNode->AddInEdge(srcId); - return RET_OK; -} - -OpGraphT::~OpGraphT() { - for (auto iter : nodes) { - delete iter.second; - } - nodes.clear(); -} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h index ce40fd1ac7..2aa53b75ff 100644 --- a/mindspore/lite/tools/common/graph_util.h +++ b/mindspore/lite/tools/common/graph_util.h @@ -88,17 +88,6 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz STATUS ValidateFileStr(const std::string &modelFile, std::string fileType); std::string GetModelName(const std::string &modelFile); - -class OpGraphT : public OpGraph { - public: - OpGraphT() {} - ~OpGraphT(); - static OpGraphT *Build(const schema::MetaGraphT *subGraphDef); - - private: - int AddEdge(NODE_ID srcId, NODE_ID dstId); - int AddEdge(const schema::CNodeT *srcNodeDef, const std::vector> *nodeDefs); -}; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index a864518544..f9e5838a2b 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -75,7 +75,7 @@ size_t GetShapeSize(const TensorT &tensor) { std::unique_ptr CopyTensorDefT(const std::unique_ptr &oldTensor) { auto newTensor = std::unique_ptr(new (std::nothrow) TensorT); if (newTensor == nullptr) { - // MS_LOG(ERROR)("new TensorT failed"); + MS_LOG(ERROR) << "new TensorT failed"; return nullptr; } newTensor->dims = oldTensor->dims; @@ -85,7 +85,7 @@ std::unique_ptr CopyTensorDefT(const std::unique_ptr &oldTenso newTensor->nodeType = oldTensor->nodeType; newTensor->data = oldTensor->data; if (!oldTensor->quantParams.empty()) { - newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor))); + newTensor->quantParams.emplace_back(GetTensorQuantParam(oldTensor)); } return newTensor; } diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index bb1b9859e8..2ca1d8d4c6 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -164,7 +164,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An for (size_t j = 0; j < input_tensors.size(); ++j) { input_tensors[j]->SetFormat(schema::Format_NHWC); if (input_tensors[j]->shape().size() == 4) { - MS_LOG(WARNING) << "init input_tensor format to nhwc"; + MS_LOG(INFO) << "init input_tensor format to nhwc"; } } lite_primitive->InferShape(input_tensors, output_tensors);