|
|
|
@ -15,8 +15,7 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "tools/common/graph_util.h"
|
|
|
|
|
#include <stdlib.h>
|
|
|
|
|
#include <time.h>
|
|
|
|
|
#include <ctime>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include "schema/inner/model_generated.h"
|
|
|
|
@ -29,7 +28,10 @@ namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
OpDefCopyer GetSimpleOpCopyer() {
|
|
|
|
|
return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
|
|
|
|
|
std::unique_ptr<CNodeT> newCNode(new CNodeT);
|
|
|
|
|
std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
|
|
|
|
|
if (newCNode == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
newCNode->name = inCNode->name;
|
|
|
|
|
newCNode->quantType = inCNode->quantType;
|
|
|
|
@ -163,8 +165,6 @@ STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// whether need to remove weightInputTensores
|
|
|
|
|
// remove all node's outputTensors
|
|
|
|
|
RemoveTensor(graphT, outputTensorIdxes);
|
|
|
|
|
node->inputIndex.clear();
|
|
|
|
|
node->outputIndex.clear();
|
|
|
|
@ -183,8 +183,11 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool remove
|
|
|
|
|
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodeT *node = graphT->nodes.at(nodeIdx).get();
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node is null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto inputTensorIdxes = node->inputIndex;
|
|
|
|
|
auto outputTensorIdxes = node->outputIndex;
|
|
|
|
|
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
|
|
|
|
@ -244,6 +247,7 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTe
|
|
|
|
|
size_t nodeIdx = 0;
|
|
|
|
|
for (size_t i = 0; i < graphT->nodes.size(); i++) {
|
|
|
|
|
auto &inNode = graphT->nodes.at(i);
|
|
|
|
|
MS_ASSERT(inNode != nullptr);
|
|
|
|
|
if (inNode->name == node->name) {
|
|
|
|
|
isSubNode = true;
|
|
|
|
|
nodeIdx = i;
|
|
|
|
@ -259,6 +263,7 @@ STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTe
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
|
|
|
|
|
MS_ASSERT(graphT != nullptr);
|
|
|
|
|
for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
|
|
|
|
|
uint32_t deleteIdx = *iter;
|
|
|
|
|
if (!forceDelete) {
|
|
|
|
@ -297,6 +302,7 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
|
|
|
|
|
MS_ASSERT(node != nullptr);
|
|
|
|
|
for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
|
|
|
|
|
if (*inIdxIt == deleteIdx) {
|
|
|
|
|
inIdxIt = node->inputIndex.erase(inIdxIt);
|
|
|
|
@ -330,6 +336,7 @@ STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_
|
|
|
|
|
graphT->allTensors.emplace_back(std::move(tensor));
|
|
|
|
|
uint32_t newTensorIdx = graphT->allTensors.size() - 1;
|
|
|
|
|
auto node = graphT->nodes.at(nodeIdx).get();
|
|
|
|
|
MS_ASSERT(node != nullptr);
|
|
|
|
|
if (place == kBefore) {
|
|
|
|
|
node->inputIndex.emplace_back(newTensorIdx);
|
|
|
|
|
} else {
|
|
|
|
@ -340,11 +347,13 @@ STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_
|
|
|
|
|
|
|
|
|
|
STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
|
|
|
|
|
std::unique_ptr<TensorT> tensor) {
|
|
|
|
|
MS_ASSERT(graphT != nullptr);
|
|
|
|
|
if (nodeIdx >= graphT->nodes.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
auto node = graphT->nodes.at(nodeIdx).get();
|
|
|
|
|
MS_ASSERT(node != nullptr);
|
|
|
|
|
if (inTensorIdx >= graphT->allTensors.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
@ -358,7 +367,9 @@ STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
|
|
|
|
|
std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
|
|
|
|
|
std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
|
|
|
|
|
MS_ASSERT(graphT != nullptr);
|
|
|
|
|
MS_ASSERT(errorCode != nullptr);
|
|
|
|
|
if (existNodeIdx >= graphT->nodes.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
|
|
|
|
|
return graphT->nodes.end();
|
|
|
|
@ -370,7 +381,9 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPla
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
|
|
|
|
|
std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
|
|
|
|
|
std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
|
|
|
|
|
MS_ASSERT(graphT != nullptr);
|
|
|
|
|
MS_ASSERT(errorCode != nullptr);
|
|
|
|
|
if (place == kBefore) {
|
|
|
|
|
return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
|
|
|
|
|
} else if (place == kAfter) {
|
|
|
|
@ -382,7 +395,9 @@ NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPl
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
|
|
|
|
|
std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
|
|
|
|
|
std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
|
|
|
|
|
MS_ASSERT(graphT != nullptr);
|
|
|
|
|
MS_ASSERT(errorCode != nullptr);
|
|
|
|
|
auto &existNode = *existNodeIter;
|
|
|
|
|
MS_ASSERT(existNode != nullptr);
|
|
|
|
|
MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
|
|
|
|
@ -390,7 +405,7 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
|
|
|
|
|
auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
|
|
|
|
|
MS_ASSERT(graphT->allTensors.size() > preTensorIdx);
|
|
|
|
|
|
|
|
|
|
auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx);
|
|
|
|
|
auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
|
|
|
|
|
if (preNodeIdxes.empty()) {
|
|
|
|
|
auto &preTensor = graphT->allTensors.at(preTensorIdx);
|
|
|
|
|
MS_ASSERT(preTensor != nullptr);
|
|
|
|
@ -402,9 +417,12 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
|
|
|
|
|
}
|
|
|
|
|
preTensor->refCount = 0;
|
|
|
|
|
preTensor->data.clear();
|
|
|
|
|
MS_ASSERT(toAddNodeIn->primitive != nullptr);
|
|
|
|
|
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
|
|
|
|
|
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
|
|
|
|
|
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
|
|
|
|
|
MS_ASSERT(prim != nullptr);
|
|
|
|
|
preTensor->dataType = prim->srcT;
|
|
|
|
|
toAddTensor->dataType = prim->dstT;
|
|
|
|
|
}
|
|
|
|
|
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
|
|
|
|
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
|
|
|
@ -438,9 +456,12 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
|
|
|
|
|
MS_LOG(ERROR) << "Copy TensorT failed";
|
|
|
|
|
return graphT->nodes.end();
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(toAddNodeIn->primitive != nullptr);
|
|
|
|
|
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
|
|
|
|
|
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
|
|
|
|
|
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
|
|
|
|
|
MS_ASSERT(prim != nullptr);
|
|
|
|
|
preTensor->dataType = prim->srcT;
|
|
|
|
|
toAddTensor->dataType = prim->dstT;
|
|
|
|
|
}
|
|
|
|
|
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
|
|
|
|
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
|
|
|
@ -473,7 +494,10 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
|
|
|
|
|
std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
|
|
|
|
|
std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode,
|
|
|
|
|
const OpDefCopyer &opDefCopyer) {
|
|
|
|
|
MS_ASSERT(graphT != nullptr);
|
|
|
|
|
MS_ASSERT(errorCode != nullptr);
|
|
|
|
|
auto &existNode = *existNodeIter;
|
|
|
|
|
MS_ASSERT(existNode != nullptr);
|
|
|
|
|
MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
|
|
|
|
@ -481,7 +505,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|
|
|
|
auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
|
|
|
|
|
MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
|
|
|
|
|
|
|
|
|
|
auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx);
|
|
|
|
|
auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
|
|
|
|
|
if (postNodeIdxes.empty()) {
|
|
|
|
|
auto &postTensor = graphT->allTensors.at(postTensorIdx);
|
|
|
|
|
MS_ASSERT(postTensor != nullptr);
|
|
|
|
@ -491,9 +515,12 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|
|
|
|
*errorCode = RET_NULL_PTR;
|
|
|
|
|
return graphT->nodes.end();
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(toAddNodeIn->primitive != nullptr);
|
|
|
|
|
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
|
|
|
|
|
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
|
|
|
|
|
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
|
|
|
|
|
MS_ASSERT(prim != nullptr);
|
|
|
|
|
postTensor->dataType = prim->srcT;
|
|
|
|
|
toAddTensor->dataType = prim->dstT;
|
|
|
|
|
}
|
|
|
|
|
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
|
|
|
|
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
|
|
|
@ -554,9 +581,12 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|
|
|
|
*errorCode = RET_NULL_PTR;
|
|
|
|
|
return graphT->nodes.end();
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(toAddNodeIn->primitive != nullptr);
|
|
|
|
|
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
|
|
|
|
|
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
|
|
|
|
|
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
|
|
|
|
|
MS_ASSERT(prim != nullptr);
|
|
|
|
|
postTensor->dataType = prim->srcT;
|
|
|
|
|
toAddTensor->dataType = prim->dstT;
|
|
|
|
|
}
|
|
|
|
|
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
|
|
|
|
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
|
|
|
@ -589,13 +619,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|
|
|
|
return existNodeIter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) {
|
|
|
|
|
if (modelFile.size() > fileType.size()) {
|
|
|
|
|
if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
} else {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
|
|
|
|
|
if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
} else {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|