From f00ecc3820bb80caab8fcc2a74cddbb83db0e7bc Mon Sep 17 00:00:00 2001 From: zhengjun10 Date: Fri, 4 Dec 2020 12:22:26 +0800 Subject: [PATCH] add tf subgraph parser and rm unused code --- .../fusion/batchnorm_fold_fusion_pass.cc | 509 ------------------ .../fusion/batchnorm_fold_fusion_pass.h | 86 --- .../converter/parser/tf/tf_logical_parser.cc | 63 +++ .../converter/parser/tf/tf_logical_parser.h | 37 ++ .../converter/parser/tf/tf_model_parser.cc | 408 ++++++++++---- .../converter/parser/tf/tf_model_parser.h | 53 +- .../converter/parser/tf/tf_while_parser.cc | 62 +++ .../converter/parser/tf/tf_while_parser.h | 37 ++ 8 files changed, 534 insertions(+), 721 deletions(-) delete mode 100644 mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc delete mode 100644 mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_logical_parser.h create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_while_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/tf/tf_while_parser.h diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc deleted file mode 100644 index 9016a3e134..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.cc +++ /dev/null @@ -1,509 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" -#include -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "tools/common/graph_util.h" -#include "tools/common/tensor_util.h" -#include "include/errorcode.h" -#include "schema/inner/model_generated.h" - -namespace mindspore { -namespace lite { -#define kBatchNormFoldFusionPathLen6 6 -#define kBatchNormFoldFusionPathLen7 7 - -STATUS BatchNormFoldFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } - -STATUS BatchNormFoldFusionPass::DefinePattern() { - // with preNode - { - auto inputOp = std::make_shared(); - inputOp->id = inputOpName; - inputOp->types = {schema::PrimitiveType_NONE}; - inputOp->isPlaceHold = true; - - auto convOp1 = std::make_shared(); - convOp1->id = convPatternOpName1; - convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; - convOp1->left = inputOp; - - auto bnFoldOp = std::make_shared(); - bnFoldOp->id = bnFoldOpName; - bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; - bnFoldOp->left = convOp1; - - auto mulFoldOp = std::make_shared(); - mulFoldOp->id = mulFoldOpName; - mulFoldOp->types = {schema::PrimitiveType_MulFold}; - mulFoldOp->left = bnFoldOp; - - auto fakeQuantOp = std::make_shared(); - fakeQuantOp->id = fakeQuantOpName; - fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; - fakeQuantOp->left = mulFoldOp; - - auto convOp2 = std::make_shared(); - convOp2->id = convPatternOpName2; - convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; - convOp2->left = fakeQuantOp; - convOp2->right = inputOp; - - auto addFoldOp = std::make_shared(); - addFoldOp->id = addFoldOpName; - addFoldOp->types = {schema::PrimitiveType_AddFold}; - addFoldOp->left = convOp2; - addFoldOp->right = bnFoldOp; - - std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(withPrePatternName)); - if (fusionPattern == nullptr) { - MS_LOG(ERROR) << "new fusionPattern failed"; - return RET_ERROR; - } - fusionPattern->AddPatternOp(inputOp); - fusionPattern->AddPatternOp(convOp1); - fusionPattern->AddPatternOp(bnFoldOp); - fusionPattern->AddPatternOp(mulFoldOp); - fusionPattern->AddPatternOp(fakeQuantOp); - fusionPattern->AddPatternOp(convOp2); - fusionPattern->AddPatternOp(addFoldOp); - fusionPattern->Finish(); - - this->patterns.emplace_back(fusionPattern.release()); - } - // no preNode - { - auto convOp1 = std::make_shared(); - convOp1->id = convPatternOpName1; - convOp1->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; - - auto bnFoldOp = std::make_shared(); - bnFoldOp->id = bnFoldOpName; - bnFoldOp->types = {schema::PrimitiveType_BatchNormFold}; - bnFoldOp->left = convOp1; - - auto mulFoldOp = std::make_shared(); - mulFoldOp->id = mulFoldOpName; - mulFoldOp->types = {schema::PrimitiveType_MulFold}; - mulFoldOp->left = bnFoldOp; - - auto fakeQuantOp = std::make_shared(); - fakeQuantOp->id = fakeQuantOpName; - fakeQuantOp->types = {schema::PrimitiveType_FakeQuantWithMinMax}; - fakeQuantOp->left = mulFoldOp; - - auto convOp2 = std::make_shared(); - convOp2->id = convPatternOpName2; - convOp2->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D}; - convOp2->left = fakeQuantOp; - - auto addFoldOp = std::make_shared(); - addFoldOp->id = addFoldOpName; - addFoldOp->types = {schema::PrimitiveType_AddFold}; - addFoldOp->left = convOp2; - addFoldOp->right = bnFoldOp; - - std::unique_ptr fusionPattern(new (std::nothrow) FusionPattern(noPrePatternName)); - if (fusionPattern == nullptr) { - MS_LOG(ERROR) << "new fusionPattern failed"; - return RET_ERROR; - } - fusionPattern->AddPatternOp(convOp1); - fusionPattern->AddPatternOp(bnFoldOp); - fusionPattern->AddPatternOp(mulFoldOp); - fusionPattern->AddPatternOp(fakeQuantOp); - fusionPattern->AddPatternOp(convOp2); - fusionPattern->AddPatternOp(addFoldOp); - fusionPattern->Finish(); - - this->patterns.emplace_back(fusionPattern.release()); - } - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) { - MS_ASSERT(graph != nullptr); - if (patternName == withPrePatternName) { - if (matchedPath.size() != kBatchNormFoldFusionPathLen7) { - MS_LOG(ERROR) << "BatchNormFold-Fusion should have seven NodeIndex in matchedPair"; - return RET_PARAM_INVALID; - } - } else if (patternName == noPrePatternName) { - if (matchedPath.size() != kBatchNormFoldFusionPathLen6) { - MS_LOG(ERROR) << "BatchNormFold-Fusion should have six NodeIndex in matchedPair"; - return RET_PARAM_INVALID; - } - } - - auto status = FindNodes(graph, matchedPath); - if (status != RET_OK) { - MS_LOG(ERROR) << "FindNodes failed: " << status; - return status; - } - status = CheckPath(graph, matchedPath); - if (status != RET_OK) { - MS_LOG(ERROR) << "CheckPath failed: " << status; - return status; - } - status = FindTensors(); - if (status != RET_OK) { - MS_LOG(ERROR) << "FindTensors failed: " << status; - return status; - } - status = GenNewWeightTensor(); - if (status != RET_OK) { - MS_LOG(ERROR) << "GenNewWeightTensor failed: " << status; - return status; - } - status = GenNewBiasTensor(); - if (status != RET_OK) { - MS_LOG(ERROR) << "GenNewBiasTensor failed: " << status; - return status; - } - status = IsolateNodes(graph, matchedPath); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateNodes failed: " << status; - return status; - } - UpdateConvWeights(); - status = DeleteConstTensors(); - if (status != RET_OK) { - MS_LOG(ERROR) << "DeleteConstTensors failed: " << status; - return status; - } - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::FindNodes(MetaGraphT *graph, - const std::unordered_map> &matchedPath) { - MS_ASSERT(graph != nullptr); - auto preConvPath = matchedPath.at(convPatternOpName1); - auto bnFoldPath = matchedPath.at(bnFoldOpName); - auto mulFoldPath = matchedPath.at(mulFoldOpName); - auto fakeQuantPath = matchedPath.at(fakeQuantOpName); - auto convPath = matchedPath.at(convPatternOpName2); - auto addFoldPath = matchedPath.at(addFoldOpName); - MS_ASSERT(preConvPath != nullptr); - MS_ASSERT(bnFoldPath != nullptr); - MS_ASSERT(mulFoldPath != nullptr); - MS_ASSERT(fakeQuantPath != nullptr); - MS_ASSERT(convPath != nullptr); - MS_ASSERT(addFoldPath != nullptr); - if (preConvPath->subGraphIdx != bnFoldPath->subGraphIdx || preConvPath->subGraphIdx != mulFoldPath->subGraphIdx || - preConvPath->subGraphIdx != fakeQuantPath->subGraphIdx || preConvPath->subGraphIdx != convPath->subGraphIdx || - preConvPath->subGraphIdx != addFoldPath->subGraphIdx) { - MS_LOG(ERROR) << "matched nodes should from same subGraph"; - return RET_ERROR; - } - MS_ASSERT(graph->nodes.size() > preConvPath->nodeIdx); - MS_ASSERT(graph->nodes.size() > bnFoldPath->nodeIdx); - MS_ASSERT(graph->nodes.size() > mulFoldPath->nodeIdx); - MS_ASSERT(graph->nodes.size() > fakeQuantPath->nodeIdx); - MS_ASSERT(graph->nodes.size() > convPath->nodeIdx); - MS_ASSERT(graph->nodes.size() > addFoldPath->nodeIdx); - preConv = graph->nodes.at(preConvPath->nodeIdx).get(); - bnFold = graph->nodes.at(bnFoldPath->nodeIdx).get(); - mulFold = graph->nodes.at(mulFoldPath->nodeIdx).get(); - fakeNode = graph->nodes.at(fakeQuantPath->nodeIdx).get(); - convNode = graph->nodes.at(convPath->nodeIdx).get(); - addFold = graph->nodes.at(addFoldPath->nodeIdx).get(); - MS_ASSERT(preConv != nullptr); - MS_ASSERT(bnFold != nullptr); - MS_ASSERT(mulFold != nullptr); - MS_ASSERT(fakeNode != nullptr); - MS_ASSERT(convNode != nullptr); - MS_ASSERT(addFold != nullptr); - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::FindTensors() { - MS_ASSERT(graph != nullptr); - MS_ASSERT(bnFold != nullptr); - MS_ASSERT(addFold != nullptr); - if (bnFold->inputIndex.size() != 4) { - MS_LOG(ERROR) << "BatchNormFold node should have 4 inputTensor, got " << bnFold->inputIndex.size() - << " input tensors"; - return RET_ERROR; - } - if (addFold->inputIndex.size() != 5) { - MS_LOG(ERROR) << "AddFold node should have 5 inputTensor, got " << addFold->inputIndex.size() << " input tensors"; - return RET_ERROR; - } - MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(1)); - muTensor = graph->allTensors.at(bnFold->inputIndex.at(1)).get(); - MS_ASSERT(muTensor != nullptr); - MS_ASSERT(graph->allTensors.size() > bnFold->inputIndex.at(2)); - sigmaTensor = graph->allTensors.at(bnFold->inputIndex.at(2)).get(); - MS_ASSERT(sigmaTensor != nullptr); - MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(1)); - betaTensor = graph->allTensors.at(addFold->inputIndex.at(1)).get(); - MS_ASSERT(betaTensor != nullptr); - MS_ASSERT(graph->allTensors.size() > addFold->inputIndex.at(2)); - gammaTensor = graph->allTensors.at(addFold->inputIndex.at(2)).get(); - MS_ASSERT(gammaTensor != nullptr); - - if (betaTensor->dims.size() != 1) { - MS_LOG(ERROR) << "ConstTensor should have only one dim, got " << betaTensor->dims.size(); - return RET_ERROR; - } - if (betaTensor->dims != gammaTensor->dims || betaTensor->dims != sigmaTensor->dims || - betaTensor->dims != muTensor->dims) { - MS_LOG(ERROR) << "All ConstTensor should have same dims"; - return RET_ERROR; - } - channelOut = betaTensor->dims.front(); - - MS_ASSERT(mulFold != nullptr); - if (mulFold->inputIndex.size() != 3) { - MS_LOG(ERROR) << "MulFold node should have 3 outputTensor, got " << addFold->inputIndex.size() << " output tensors"; - return RET_ERROR; - } - MS_ASSERT(graph->allTensors.size() > mulFold->inputIndex.front()); - oldWeightTensor = graph->allTensors.at(mulFold->inputIndex.front()).get(); - MS_ASSERT(oldWeightTensor != nullptr); - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::CheckPath(MetaGraphT *graph, - const std::unordered_map> &matchedPath) { - MS_ASSERT(preConv != nullptr); - MS_ASSERT(convNode != nullptr); - MS_ASSERT(mulFold != nullptr); - MS_ASSERT(preConv->inputIndex.size() == 2); - MS_ASSERT(convNode->inputIndex.size() == 2); - MS_ASSERT(mulFold->inputIndex.size() == 3); - MS_ASSERT(preConv->inputIndex.front() == convNode->inputIndex.front()); - MS_ASSERT(preConv->inputIndex.at(1) == mulFold->inputIndex.front()); - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::GenNewWeightTensor() { - MS_ASSERT(oldWeightTensor != nullptr); - MS_ASSERT(oldWeightTensor->dataType == DataType_DT_FLOAT); - MS_ASSERT(oldWeightTensor->refCount == schema::NodeType::NodeType_ValueNode); - auto weightShape = oldWeightTensor->dims; - if (weightShape.size() != 4) { - MS_LOG(ERROR) << "shape of weight should be 4 dims, got " << weightShape.size() << " dims"; - return RET_ERROR; - } - if (weightShape.front() != channelOut) { - MS_LOG(ERROR) << "weight should be in KCHW format, and outputChannel should be " << channelOut; - return RET_ERROR; - } - auto weightShapeSize = GetShapeSize(*oldWeightTensor); - newWeightTensor = std::unique_ptr(new (std::nothrow) TensorT); - if (newWeightTensor == nullptr) { - MS_LOG(ERROR) << "new weightTensor failed"; - return RET_ERROR; - } - newWeightTensor->dataType = oldWeightTensor->dataType; - newWeightTensor->format = oldWeightTensor->format; - newWeightTensor->refCount = schema::NodeType::NodeType_ValueNode; - newWeightTensor->dims = weightShape; - newWeightTensor->data.resize(weightShapeSize * sizeof(float)); - void *oldWeightData = oldWeightTensor->data.data(); - auto castedOldWeightData = static_cast(oldWeightData); - void *newWeightData = newWeightTensor->data.data(); - auto castedNewWeightData = static_cast(newWeightData); - MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); - void *gammaData = gammaTensor->data.data(); - auto *castedGammaData = static_cast(gammaData); - MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); - void *miData = muTensor->data.data(); - auto *castedMiData = static_cast(miData); - if (channelOut == 0) { - MS_LOG(ERROR) << "divisor 'channelOut' cannot be 0"; - return RET_ERROR; - } - size_t stride = weightShapeSize / channelOut; - for (int i = 0; i < channelOut; i++) { - for (size_t j = 0; j < stride; j++) { - if (fabs(castedMiData[i]) <= 0.0f) { - MS_LOG(ERROR) << "divisor 'castedMiData' cannot be 0"; - return RET_ERROR; - } - castedNewWeightData[i * stride + j] = castedOldWeightData[i * stride + j] * castedGammaData[i] / castedMiData[i]; - } - } - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::GenNewBiasTensor() { // bias has no quant - std::vector biasShape = {channelOut}; - newBiasTensor = std::unique_ptr(new (std::nothrow) TensorT); - if (newBiasTensor == nullptr) { - MS_LOG(ERROR) << "new BiasTensor failed"; - return RET_ERROR; - } - newBiasTensor->dataType = 0; - newBiasTensor->format = schema::Format::Format_NUM_OF_FORMAT; - newBiasTensor->refCount = schema::NodeType::NodeType_ValueNode; - newBiasTensor->dims = biasShape; - newBiasTensor->data.resize(channelOut * sizeof(float)); - void *newBiasData = newBiasTensor->data.data(); - auto castedNewBiasData = static_cast(newBiasData); - MS_ASSERT(betaTensor->dataType == DataType_DT_FLOAT); - void *betaData = betaTensor->data.data(); - auto *castedBetaData = static_cast(betaData); - MS_ASSERT(gammaTensor->dataType == DataType_DT_FLOAT); - void *gammaData = gammaTensor->data.data(); - auto *castedGammaData = static_cast(gammaData); - MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT); - void *miData = muTensor->data.data(); - auto *castedMiData = static_cast(miData); - MS_ASSERT(sigmaTensor->dataType == DataType_DT_FLOAT); - void *sigmaData = sigmaTensor->data.data(); - auto *castedSigmaData = static_cast(sigmaData); - for (int i = 0; i < channelOut; i++) { - if (fabs(castedSigmaData[i]) <= 0.0f) { - MS_LOG(ERROR) << "divisor 'castedSigmaData' cannot be 0"; - return RET_ERROR; - } - castedNewBiasData[i] = castedBetaData[i] - castedGammaData[i] * castedMiData[i] / castedSigmaData[i]; - } - return RET_OK; -} - -STATUS BatchNormFoldFusionPass::IsolateNodes( - MetaGraphT *graph, const std::unordered_map> &matchedPath) { - MS_ASSERT(graph != nullptr); - auto preConvPath = matchedPath.at(convPatternOpName1); - auto bnFoldPath = matchedPath.at(bnFoldOpName); - auto mulFoldPath = matchedPath.at(mulFoldOpName); - auto fakeQuantPath = matchedPath.at(fakeQuantOpName); - auto convPath = matchedPath.at(convPatternOpName2); - auto addFoldPath = matchedPath.at(addFoldOpName); - MS_ASSERT(preConvPath != nullptr); - MS_ASSERT(bnFoldPath != nullptr); - MS_ASSERT(mulFoldPath != nullptr); - MS_ASSERT(fakeQuantPath != nullptr); - MS_ASSERT(convPath != nullptr); - MS_ASSERT(addFoldPath != nullptr); - auto status = IsolateOneWayNode(graph, preConvPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode " << preConv->name.c_str() << " failed, error: " << status; - return status; - } - std::vector toDeleteTensorIdxes; - toDeleteTensorIdxes.emplace_back(bnFold->inputIndex.at(3)); - toDeleteTensorIdxes.insert(toDeleteTensorIdxes.end(), bnFold->outputIndex.begin(), bnFold->outputIndex.end()); - status = RemoveTensor(graph, toDeleteTensorIdxes, true); - if (status != RET_OK) { - MS_LOG(ERROR) << "Remove Tensors of BnFold " << bnFold->name.c_str() << " failed, error: " << status; - return RET_ERROR; - } - status = IsolateOneWayNode(graph, bnFoldPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode " << bnFold->name.c_str() << " failed, error: " << status; - return status; - } - status = IsolateOneWayNode(graph, mulFoldPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode " << mulFold->name.c_str() << " failed, error: " << status; - return status; - } - status = IsolateOneWayNode(graph, addFoldPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode " << addFold->name.c_str() << " failed, error: " << status; - return status; - } - return RET_OK; -} - -void BatchNormFoldFusionPass::UpdateConvWeights() { - MS_ASSERT(graph != nullptr); - MS_ASSERT(convNode != nullptr); - MS_ASSERT(newWeightTensor != nullptr); - MS_ASSERT(newBiasTensor != nullptr); - MS_ASSERT(graph->allTensors.size() > fakeNode->inputIndex.at(0)); - graph->allTensors.at(fakeNode->inputIndex.at(0)).reset(); - graph->allTensors.at(fakeNode->inputIndex.at(0)) = std::move(this->newWeightTensor); - graph->allTensors.emplace_back(std::move(this->newBiasTensor)); - convNode->inputIndex.emplace_back(graph->allTensors.size() - 1); - if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) { - convNode->primitive->value.AsConv2D()->hasBias = true; - } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { - convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true; - } else { - MS_ASSERT(false); - } - - this->oldWeightTensor = nullptr; - this->newWeightTensor = nullptr; - this->newBiasTensor = nullptr; -} - -STATUS BatchNormFoldFusionPass::DeleteConstTensors() { - MS_ASSERT(graph != nullptr); - bool muFind = false; - bool sigmaFind = false; - bool betaFind = false; - bool gammaFind = false; - std::vector toDeleteTensorIdxes; - for (size_t i = 0; i < graph->allTensors.size(); i++) { - auto &tensor = graph->allTensors.at(i); - if (tensor.get() == muTensor) { - toDeleteTensorIdxes.emplace_back(i); - muFind = true; - this->muTensor = nullptr; - } - if (tensor.get() == sigmaTensor) { - toDeleteTensorIdxes.emplace_back(i); - sigmaFind = true; - this->sigmaTensor = nullptr; - } - if (tensor.get() == gammaTensor) { - toDeleteTensorIdxes.emplace_back(i); - gammaFind = true; - this->gammaTensor = nullptr; - } - if (tensor.get() == betaTensor) { - toDeleteTensorIdxes.emplace_back(i); - betaFind = true; - this->betaTensor = nullptr; - } - } - if (!muFind || !sigmaFind || !betaFind || !gammaFind) { - MS_LOG(ERROR) << "Can not find muTensor or sigmaTensor or betaTensor or gammaTensor in graph"; - return RET_ERROR; - } - auto status = RemoveTensor(graph, toDeleteTensorIdxes); - if (status != RET_OK) { - MS_LOG(ERROR) << "Remove ConstTensors failed" << bnFold->name.c_str(); - return RET_ERROR; - } - return RET_OK; -} - -BatchNormFoldFusionPass::~BatchNormFoldFusionPass() { - if (newWeightTensor == nullptr) { - newWeightTensor.reset(); - newWeightTensor = nullptr; - } - if (newBiasTensor == nullptr) { - newBiasTensor.reset(); - newBiasTensor = nullptr; - } -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h deleted file mode 100644 index b0c4c4d604..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H -#define MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H - -#include -#include -#include -#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" - -namespace mindspore { -namespace lite { -// input = input -// weight = SimQuantPerChannel(weight * gamma / sigma) -// bias = beta - gamma * mi / sigma -// MulFold: gamma sigma -// BatchNormFold: mi sigma -// AddFold: gamma beta mi sigma -class BatchNormFoldFusionPass : public FusionPass { - public: - BatchNormFoldFusionPass() = default; - - ~BatchNormFoldFusionPass() override; - - STATUS DefinePattern() override; - - STATUS DoFusion(MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) override; - - STATUS Run(MetaGraphT *graph) override; - - protected: - STATUS FindNodes(MetaGraphT *graph, const std::unordered_map> &matchedPath); - STATUS CheckPath(MetaGraphT *graph, const std::unordered_map> &matchedPath); - STATUS FindTensors(); - STATUS GenNewWeightTensor(); - STATUS GenNewBiasTensor(); - STATUS IsolateNodes(MetaGraphT *graph, const std::unordered_map> &matchedPath); - void UpdateConvWeights(); - STATUS DeleteConstTensors(); - - protected: - MetaGraphT *graph = nullptr; - CNodeT *preConv = nullptr; - CNodeT *bnFold = nullptr; - CNodeT *mulFold = nullptr; - CNodeT *fakeNode = nullptr; - CNodeT *convNode = nullptr; - CNodeT *addFold = nullptr; - TensorT *muTensor = nullptr; - TensorT *sigmaTensor = nullptr; - TensorT *gammaTensor = nullptr; - TensorT *betaTensor = nullptr; - TensorT *oldWeightTensor = nullptr; - int32_t channelOut = 0; - - std::unique_ptr newWeightTensor = nullptr; - std::unique_ptr newBiasTensor = nullptr; - - std::string inputOpName = "Input"; - std::string convPatternOpName1 = "Convolution1"; - std::string bnFoldOpName = "BatchNormFold"; - std::string mulFoldOpName = "MulFold"; - std::string fakeQuantOpName = "FakeQuant"; - std::string convPatternOpName2 = "Convolution2"; - std::string addFoldOpName = "AddFold"; - std::string withPrePatternName = "BNFoldFusionWithPre"; - std::string noPrePatternName = "BNFoldFusionNoPre"; -}; -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc new file mode 100644 index 0000000000..b362bfa2a0 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_logical_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF LogicalParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + if (tf_op.op() == "LogicalAnd") { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + primitive->value.type = schema::PrimitiveType_LogicalAnd; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + } + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.h new file mode 100644 index 0000000000..c06893f9fd --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFLogicalParser : public TFNodeParser { + public: + TFLogicalParser() = default; + ~TFLogicalParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_LOGICAL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 64363dad53..baa688a486 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -17,37 +17,57 @@ #include "tools/converter/parser/tf/tf_model_parser.h" #include +#include #include -#include "src/common/utils.h" #include "src/common/log_adapter.h" -#include "tools/common/graph_util.h" -#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "src/common/utils.h" #include "src/param_value_lite.h" +#include "tools/common/graph_util.h" #include "tools/common/protobuf_utils.h" +#include "tools/converter/parser/tf/tf_node_parser_registry.h" namespace mindspore { namespace lite { - -AnfNodePtr TFModelParser::GetAnfNode(const std::string &name) { +namespace { +// subgraph node input may be a:output:0/a:z:0 +std::string GetFlattenNodeName(std::string input_name) { + std::regex re("\\:+"); + std::vector input_splits(std::sregex_token_iterator(input_name.begin(), input_name.end(), re, -1), + std::sregex_token_iterator()); + if (input_splits.size() == 3) { + if (input_splits[2] == "0") { + input_name = input_splits[0]; + } else { + input_name = input_splits[0] + input_splits[2]; // multi output node + } + } + return input_name; +} +AnfNodePtr GetAnfNode(const std::string &name, const std::unordered_map &anf_node_map) { AnfNodePtr ret = nullptr; if (anf_node_map.find(name) != anf_node_map.end()) { - ret = anf_node_map[name]; + ret = anf_node_map.at(name); } else if (anf_node_map.find(name + ":0") != anf_node_map.end()) { - ret = anf_node_map[name + ":0"]; + ret = anf_node_map.at(name + ":0"); } return ret; } -std::string TFModelParser::GetOriginInputName(const tensorflow::NodeDef &node) { +std::string GetOriginInputName(const tensorflow::NodeDef &node, + const std::map &tf_graph_nodes) { if (node.op() != "Identity" && node.op() != "StopGradient") { return node.name(); } auto tmp_node = &node; while (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient") { - tmp_node = tf_node_map[tmp_node->input(0)]; + if (tf_graph_nodes.find(tmp_node->input(0)) == tf_graph_nodes.end()) { + return tmp_node->input(0); + } + tmp_node = tf_graph_nodes.at(tmp_node->input(0)); } return tmp_node->name(); } +} // namespace STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, std::vector *shape_vector) { @@ -126,11 +146,11 @@ STATUS TFModelParser::ConvertConstTensor(const tensorflow::AttrValue &attr_value param_value->set_tensor_type(type); param_value->set_format(schema::Format::Format_NHWC); parameter->set_default_param(param_value); - parameter->set_name("const_" + std::to_string(anf_node_map.size()) + "_parameter"); return RET_OK; } -STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter) { +STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, + std::unordered_map *anf_node_map) { MS_ASSERT(node != nullptr); MS_ASSERT(parameter != nullptr); @@ -157,8 +177,7 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa return status; } } else { - parameter->set_name("placeholder_" + std::to_string(anf_node_map.size())); - graph_input_names.emplace_back(parameter->name()); + graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names } auto abstract_tensor = std::make_shared(type_ptr, shape_vector); @@ -166,14 +185,19 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa MS_LOG(ERROR) << "abstract_tensor is nullptr"; return RET_ERROR; } + parameter->set_name(node.name()); parameter->set_abstract(abstract_tensor); - anf_node_map[node.name()] = parameter; + (*anf_node_map)[node.name()] = parameter; + (*anf_node_map)[node.name() + ":0"] = parameter; + return RET_OK; } -STATUS TFModelParser::ConvertGraphInputsAndConsts() { - for (auto &pair : tf_node_map) { +STATUS TFModelParser::ConvertGraphInputsAndConsts( + const std::map &tf_graph_nodes, const FuncGraphPtr &anf_graph, + std::unordered_map *anf_node_map) { + for (auto &pair : tf_graph_nodes) { bool have_data_depend = false; for (int i = 0; i < pair.second->input_size(); ++i) { auto name = pair.second->input(i); @@ -183,8 +207,8 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts() { } } if (!have_data_depend) { - auto parameter = funcGraphPtr->add_parameter(); - if (ConvertParameter(*pair.second, parameter) != RET_OK) { + auto parameter = anf_graph->add_parameter(); + if (ConvertParameter(*pair.second, parameter, anf_node_map) != RET_OK) { MS_LOG(ERROR) << "convert Parameter Node failed"; return RET_ERROR; } @@ -192,7 +216,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts() { } return RET_OK; } - +FuncGraphPtr paserTfFuction() { return nullptr; } FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { auto status = ValidateFileStr(modelFile, ".pb"); @@ -201,51 +225,189 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - tf_graph_def = std::make_unique(); - if (tf_graph_def == nullptr) { - MS_LOG(ERROR) << "tf_graph_def is nullptr"; + tf_root_graph_ = std::make_unique(); + if (tf_root_graph_ == nullptr) { + MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_graph_def.get()); + status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); if (status != RET_OK) { MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - funcGraphPtr = std::make_shared(); - if (funcGraphPtr == nullptr) { + anf_root_graph_ = std::make_shared(); + if (anf_root_graph_ == nullptr) { MS_LOG(ERROR) << "funGraphPtr is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - for (int i = 0; i < tf_graph_def->node_size(); i++) { - auto &node_def = tf_graph_def->node(i); - tf_node_map[node_def.name()] = &node_def; + for (int i = 0; i < tf_root_graph_->node_size(); i++) { + auto &node_def = tf_root_graph_->node(i); + tf_root_graph_nodes_[node_def.name()] = &node_def; } - status = ConvertGraphInputsAndConsts(); + status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - - status = ConvertOps(); - if (status != RET_OK) { + bool success_flag = true; + for (int i = 0; i < tf_root_graph_->node_size(); i++) { + auto &node_def = tf_root_graph_->node(i); + status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); + if (status != RET_OK) { + success_flag = false; + } + } + if (!success_flag) { MS_LOG(ERROR) << "Convert ops failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - - status = ConvertGraphOutputs(); + status = ConvertRootGraphOutputs(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph outputs failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; } - return funcGraphPtr; + + status = ConvertSubgraph(); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert subgraph failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return nullptr; + } + + return anf_root_graph_; +} +STATUS TFModelParser::ConvertSubgraph() { + auto graph_def_liarary = tf_root_graph_->library(); + auto subgraph_size = graph_def_liarary.function_size(); + std::map while_cond_map; + std::map while_body_map; + std::vector sub_graph_inputs; + for (int i = 0; i < subgraph_size; i++) { + auto &tf_sub_fuction = graph_def_liarary.function(i); + auto &tf_sub_signature = tf_sub_fuction.signature(); + auto input_arg_size = tf_sub_signature.input_arg_size(); + + auto &sub_graph_name = tf_sub_signature.name(); + if (!function_while_map_.count(sub_graph_name)) { + MS_LOG(ERROR) << "function map not contains sub graph name." << sub_graph_name; + return RET_ERROR; + } + auto while_cnode = function_while_map_[sub_graph_name]->cast(); + if (while_cnode == nullptr || static_cast(while_cnode->inputs().size()) != input_arg_size + 1) { + MS_LOG(ERROR) << "while cnode not equal input arg size"; + return RET_ERROR; + } + + FuncGraphPtr sub_func_graph = std::make_shared(); + std::unordered_map anf_sub_node_map; + // convert sub graph inputs + for (int j = 0; j < input_arg_size; j++) { + auto &input_arg = tf_sub_signature.input_arg(j); + auto paramter = sub_func_graph->add_parameter(); + paramter->set_name(input_arg.name()); + anf_sub_node_map[input_arg.name()] = paramter; + sub_graph_inputs.emplace_back(paramter); + } + std::map tf_sub_node_map; + for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { + auto &node_def = tf_sub_fuction.node_def(j); + tf_sub_node_map[node_def.name()] = &node_def; + } + STATUS status = RET_OK; + status = ConvertGraphInputsAndConsts(tf_sub_node_map, sub_func_graph, &anf_sub_node_map); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert subgraph consts failed"; + return status; + } + // convert sub graph ops + for (int j = 0; j < tf_sub_fuction.node_def_size(); j++) { + auto &node_def = tf_sub_fuction.node_def(j); + status = ConvertOps(node_def, tf_sub_node_map, sub_func_graph, &anf_sub_node_map); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert subgraph ops failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return RET_ERROR; + } + } + + // convert subgraph outputs + std::vector sub_output_nodes; + auto &subgraph_ret = tf_sub_fuction.ret(); + for (auto &t : subgraph_ret) { + MS_LOG(INFO) << "subret " << t.first << " " << t.second; + auto tf_output_name = GetFlattenNodeName(t.second); + AnfNodePtr anf_node = nullptr; + if (tf_sub_node_map.find(tf_output_name) == tf_sub_node_map.end()) { + anf_node = GetAnfNode(tf_output_name, anf_sub_node_map); + } else { + auto tf_real_name = GetOriginInputName(*tf_sub_node_map[tf_output_name], tf_sub_node_map); + anf_node = GetAnfNode(tf_real_name, anf_sub_node_map); + } + if (anf_node == nullptr) { + MS_LOG(ERROR) << "can't find anf node,tf node flatten name" << tf_output_name; + return RET_ERROR; + } + sub_output_nodes.push_back(anf_node); + } + status = MakeAnfGraphOutputs(&sub_output_nodes, sub_func_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "cmake anf graph outputs node error"; + return status; + } + + // add while cond body function to while node input + if (sub_graph_name.find("cond") != std::string::npos) { + while_cond_map[while_cnode] = sub_func_graph; + } else { + while_body_map[while_cnode] = sub_func_graph; + } + // hardcode subgraph inputs name + for (size_t j = 0; j < sub_graph_inputs.size(); j++) { + sub_graph_inputs[j]->set_name("graph" + std::to_string(i) + "_input_" + std::to_string(j) + "parameter"); + } + MS_LOG(INFO) << "parse subgraph end:" << sub_graph_name; + } + auto status = WhileNodePostProcess(while_cond_map, while_body_map); + if (status != RET_OK) { + MS_LOG(ERROR) << "while node post process failed"; + return status; + } + return RET_OK; } +STATUS TFModelParser::WhileNodePostProcess(const std::map &while_cond_map, + const std::map &while_body_map) { + if (while_cond_map.size() != while_body_map.size()) { + MS_LOG(ERROR) << "while cond body size error"; + return RET_ERROR; + } + std::vector roots = {anf_root_graph_}; + auto root_func_manager = std::make_shared(roots); + anf_root_graph_->set_manager(root_func_manager); + for (auto &kv : while_cond_map) { + auto while_node = kv.first; + auto &cond_sub_graph = kv.second; + auto &body_sub_graph = while_body_map.at(while_node); + cond_sub_graph->set_manager(root_func_manager); + body_sub_graph->set_manager(root_func_manager); + auto cond_value_node = NewValueNode(cond_sub_graph); + auto body_value_node = NewValueNode(body_sub_graph); + auto new_while_inputs = while_node->cast()->inputs(); + new_while_inputs[0] = cond_value_node; + new_while_inputs.insert(new_while_inputs.begin() + 1, body_value_node); + auto new_while_node = anf_root_graph_->NewCNode(new_while_inputs); + new_while_node->set_abstract(while_node->abstract()); + root_func_manager->Replace(while_node, new_while_node); + } + return RET_OK; +} + schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { MS_LOG(ERROR) << "TF Model Parser not return MetaGraph, use TFModelParser::Parse instead"; @@ -253,15 +415,21 @@ schema::MetaGraphT *TFModelParser::ParseToFb(const std::string &modelFile, const } STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, - const std::vector &input_names, std::vector *inputs) { + const std::vector &input_names, + const std::map &tf_node_map, + const std::unordered_map &anf_node_map, + std::vector *inputs) { + MS_ASSERT(node_def != nullptr); // parse inputs for (size_t j = 0; j < input_names.size(); j++) { std::string input_name = input_names[j]; // input may be produced by multi-outputs node - if (tf_node_map.find(input_name) != tf_node_map.end()) { - auto input_node = tf_node_map[input_name]; - input_name = GetOriginInputName(*input_node); + // subgraph input name x:output:index,need flatten + auto flatten_input_name = GetFlattenNodeName(input_name); + if (tf_node_map.find(flatten_input_name) != tf_node_map.end()) { + auto input_node = tf_node_map.at(flatten_input_name); + flatten_input_name = GetOriginInputName(*input_node, tf_node_map); } - auto input = GetAnfNode(input_name); + auto input = GetAnfNode(flatten_input_name, anf_node_map); if (input == nullptr) { MS_LOG(ERROR) << node_def.name() << " input " << j << ": " << input_name << " can't find parsed in_nodes"; return RET_ERROR; @@ -271,11 +439,16 @@ STATUS TFModelParser::ConvertInputNodes(const tensorflow::NodeDef &node_def, return RET_OK; } -STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size) { +STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, + std::unordered_map *anf_node_map, + const FuncGraphPtr &anf_graph, int output_size) { + MS_ASSERT(op != nullptr); + MS_ASSERT(anf_node != nullptr); + MS_ASSERT(anf_graph != nullptr); if (output_size == 1) { std::vector shape_vector; anf_node->set_abstract(std::make_shared(kFloat32, shape_vector)); - anf_node_map.insert(std::pair(op.name(), anf_node)); + anf_node_map->insert(std::pair(op.name(), anf_node)); } else { AbstractBasePtrList abstractList; for (int output_idx = 0; output_idx < output_size; output_idx++) { @@ -289,104 +462,125 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); auto getItemValue = NewValueNode(MakeValue(output_idx)); std::vector inputs{tupleGetItemPrim, anf_node, getItemValue}; - CNodePtr getItemCNode = funcGraphPtr->NewCNode(inputs); + CNodePtr getItemCNode = anf_graph->NewCNode(inputs); std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); getItemCNode->set_fullname_with_scope(output_item_name); - anf_node_map.insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); + anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); } anf_node->set_abstract(std::make_shared(abstractList)); } return RET_OK; } -STATUS TFModelParser::ConvertOps() { +STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, + const std::map &tf_node_map, + const FuncGraphPtr &func_graph_ptr, + std::unordered_map *anf_node_map) { + MS_ASSERT(node_def != nullptr); + MS_ASSERT(func_graph_ptr != nullptr); NoSupportOp::GetInstance()->SetFmkType("TF"); STATUS status = RET_OK; - int op_idx = 0; - for (int i = 0; i < tf_graph_def->node_size(); i++) { - auto &node_def = tf_graph_def->node(i); - const auto &op_type = node_def.op(); - if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { - continue; - } - auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); - if (node_parser == nullptr) { - NoSupportOp::GetInstance()->InsertOp(op_type); - status = (status == RET_OK ? RET_NOT_FIND_OP : status); - MS_LOG(ERROR) << "cannot find node parser:" << op_type; - continue; - } - if (status != RET_OK) { - continue; - } - PrimitiveC *primitiveC = nullptr; - int output_size; - std::vector input_names; - status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); - if (status != RET_OK) { - MS_LOG(ERROR) << "node " << op_type << " parser failed"; - continue; - } + const auto &op_type = node_def.op(); + if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") { + return RET_OK; + } - auto value_node = NewValueNode(std::shared_ptr(primitiveC)); - if (value_node == nullptr) { - MS_LOG(ERROR) << "value_node is nullptr"; - status = RET_ERROR; - continue; - } - std::vector inputs = {value_node}; - status = ConvertInputNodes(node_def, input_names, &inputs); - if (status != RET_OK) { - continue; + auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type); + if (node_parser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(op_type); + MS_LOG(ERROR) << "cannot find node parser:" << op_type; + return RET_NOT_FIND_OP; + } + PrimitiveC *primitiveC = nullptr; + int output_size; + std::vector input_names; + status = node_parser->Parse(node_def, tf_node_map, &primitiveC, &input_names, &output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "node " << op_type << " parser failed"; + return RET_ERROR; + } + auto value_node = NewValueNode(std::shared_ptr(primitiveC)); + if (value_node == nullptr) { + MS_LOG(ERROR) << "value_node is nullptr"; + return RET_ERROR; + } + std::vector inputs = {value_node}; + status = ConvertInputNodes(node_def, input_names, tf_node_map, *anf_node_map, &inputs); + if (status != RET_OK) { + return status; + } + // control_depends are not processed currently + auto anf_node = func_graph_ptr->NewCNode(inputs); + anf_node->set_fullname_with_scope(node_def.name()); + if (op_type == "StatelessWhile" || op_type == "while") { + MS_LOG(INFO) << "find while node:" << node_def.name(); + tensorflow::AttrValue attr_value; + if (TensorFlowUtils::FindAttrValue(node_def, "body", &attr_value)) { + auto body_name = attr_value.func().name(); + function_while_map_[body_name] = anf_node; + MS_LOG(DEBUG) << "parse body name:" << body_name; } - // control_depends are not processed currently - auto anf_node = funcGraphPtr->NewCNode(inputs); - anf_node->set_fullname_with_scope(op_type + "-" + std::to_string(op_idx++)); - - status = ConvertOutputTensor(node_def, anf_node, output_size); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - continue; + if (TensorFlowUtils::FindAttrValue(node_def, "cond", &attr_value)) { + auto cond_name = attr_value.func().name(); + function_while_map_[cond_name] = anf_node; + MS_LOG(DEBUG) << "parse cond name:" << cond_name; } } + + status = ConvertOutputTensor(node_def, anf_node, anf_node_map, func_graph_ptr, output_size); + if (status != RET_OK) { + MS_LOG(ERROR) << "Convert output tensors for " << anf_node->fullname_with_scope() << " failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return RET_ERROR; + } return status; } -STATUS TFModelParser::ConvertGraphOutputs() { +STATUS TFModelParser::ConvertRootGraphOutputs() { // because output of intermediate node in anf graph may also be output tensors, we search output tensors in - // tf_node_map but not anf_node_map + // tf_root_graph_nodes_ but not anf_root_node_map_ std::set all_node_inputs; std::vector output_nodes; - for (auto &pair : tf_node_map) { + for (auto &pair : tf_root_graph_nodes_) { for (int i = 0; i < pair.second->input_size(); ++i) { all_node_inputs.insert(pair.second->input(i)); } } - for (auto &pair : tf_node_map) { + for (auto &pair : tf_root_graph_nodes_) { auto it = all_node_inputs.find(pair.first); if (it == all_node_inputs.end() && pair.second->input_size() > 0) { // output node not constraint to Identity - auto origin_name = GetOriginInputName(*(pair.second)); - auto anf_node = GetAnfNode(origin_name); + auto origin_name = GetOriginInputName(*(pair.second), tf_root_graph_nodes_); + auto anf_node = GetAnfNode(origin_name, anf_root_node_map_); if (anf_node == nullptr) { MS_LOG(ERROR) << "can't find anf node"; return RET_ERROR; } output_nodes.push_back(anf_node); - graph_output_names.push_back(anf_node->fullname_with_scope()); + graph_output_names_.push_back(anf_node->fullname_with_scope()); } } - - if (output_nodes.size() > 1) { - std::vector &make_tuple_inputs = output_nodes; + auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); + if (status != RET_OK) { + MS_LOG(ERROR) << "make anf graph outputs node error"; + return status; + } + return RET_OK; +} +STATUS TFModelParser::MakeAnfGraphOutputs(std::vector *output_nodes, const FuncGraphPtr &anf_graph) { + if (output_nodes->empty() || anf_graph == nullptr) { + MS_LOG(ERROR) << "anf output nodes empty or null anf graph"; + return RET_ERROR; + } + if (output_nodes->size() > 1) { + std::vector *make_tuple_inputs = output_nodes; auto make_tuple_prim_ptr = GetMakeTuplePrim(); if (make_tuple_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; return RET_NULL_PTR; } auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); - make_tuple_inputs.insert(output_nodes.begin(), make_tuple_prim); - auto make_tuple_cnode = funcGraphPtr->NewCNode(make_tuple_inputs); + make_tuple_inputs->insert(make_tuple_inputs->begin(), make_tuple_prim); + auto make_tuple_cnode = anf_graph->NewCNode(*make_tuple_inputs); make_tuple_cnode->set_fullname_with_scope("return tuple"); auto return_prim_ptr = GetReturnPrim(); @@ -396,20 +590,20 @@ STATUS TFModelParser::ConvertGraphOutputs() { } auto value_node = NewValueNode(return_prim_ptr); std::vector op_inputs = {value_node, make_tuple_cnode}; - auto cnode = funcGraphPtr->NewCNode(op_inputs); + auto cnode = anf_graph->NewCNode(op_inputs); cnode->set_fullname_with_scope("return"); - funcGraphPtr->set_return(cnode); - } else if (output_nodes.size() == 1) { + anf_graph->set_return(cnode); + } else { auto return_prim_ptr = GetReturnPrim(); if (return_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetReturnPrim return nullptr"; return RET_NULL_PTR; } auto value_node = NewValueNode(return_prim_ptr); - std::vector op_inputs{value_node, output_nodes.front()}; - auto return_cnode = funcGraphPtr->NewCNode(op_inputs); + std::vector op_inputs{value_node, output_nodes->front()}; + auto return_cnode = anf_graph->NewCNode(op_inputs); return_cnode->set_fullname_with_scope("return"); - funcGraphPtr->set_return(return_cnode); + anf_graph->set_return(return_cnode); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 67b3f50618..49336f8728 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -17,17 +17,17 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_MODEL_PARSER_H -#include -#include -#include #include +#include +#include #include +#include +#include "proto/graph.pb.h" +#include "proto/node_def.pb.h" +#include "schema/inner/model_generated.h" #include "securec/include/securec.h" #include "tools/common/tensor_util.h" #include "tools/converter/model_parser.h" -#include "schema/inner/model_generated.h" -#include "proto/node_def.pb.h" -#include "proto/graph.pb.h" namespace mindspore { namespace lite { @@ -43,24 +43,39 @@ class TFModelParser : public ModelParser { const QuantType &quantType = QuantType_QUANT_NONE) override; private: - AnfNodePtr GetAnfNode(const std::string &name); - std::string GetOriginInputName(const tensorflow::NodeDef &node); STATUS ConvertConstTensor(const tensorflow::AttrValue &attr_value, const TypeId &type, const ParameterPtr ¶meter, std::vector *shape_vector); - STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter); - STATUS ConvertGraphInputsAndConsts(); + STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter, + std::unordered_map *anf_node_map); + STATUS ConvertGraphInputsAndConsts(const std::map &tf_graph_nodes, + const FuncGraphPtr &anf_graph, + std::unordered_map *anf_node_map); STATUS ConvertInputNodes(const tensorflow::NodeDef &node_def, const std::vector &input_names, + const std::map &tf_node_map, + const std::unordered_map &anf_node_map, std::vector *inputs); - STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, int output_size); - STATUS ConvertOps(); - STATUS ConvertGraphOutputs(); + STATUS ConvertOutputTensor(const tensorflow::NodeDef &op, const CNodePtr &anf_node, + std::unordered_map *anf_node_map, const FuncGraphPtr &anf_graph, + int output_size); + STATUS ConvertOps(const tensorflow::NodeDef &node_def, + const std::map &tf_node_map, + const FuncGraphPtr &func_graph_ptr, std::unordered_map *anf_node_map); + STATUS ConvertRootGraphOutputs(); + + STATUS ConvertSubgraph(); + + STATUS WhileNodePostProcess(const std::map &while_cond_map, + const std::map &while_body_map); + + STATUS MakeAnfGraphOutputs(std::vector *output_nodes, const FuncGraphPtr &anf_graph); - FuncGraphPtr funcGraphPtr; - std::unique_ptr tf_graph_def; - std::map tf_node_map; - std::unordered_map anf_node_map; - std::vector graph_input_names; - std::vector graph_output_names; + FuncGraphPtr anf_root_graph_; + std::unique_ptr tf_root_graph_; // tf root graph def + std::map tf_root_graph_nodes_; // tf root graph node map + std::unordered_map anf_root_node_map_; + std::vector graph_input_names_; + std::vector graph_output_names_; + std::map function_while_map_; // tf function name->while_node_name }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_while_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_while_parser.cc new file mode 100644 index 0000000000..f4c8869bb5 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_while_parser.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_while_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFWhileParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, PrimitiveC **primitiveC, + std::vector *inputs, int *output_size) { + MS_LOG(INFO) << "TF WhileParser"; + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_While; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return RET_OK; +} +TFNodeRegistrar g_tfStatelessWhileParser("StatelessWhile", new TFWhileParser()); +TFNodeRegistrar g_tfWhileParser("While", new TFWhileParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_while_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_while_parser.h new file mode 100644 index 0000000000..287d5cb43b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_while_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFWhileParser : public TFNodeParser { + public: + TFWhileParser() = default; + ~TFWhileParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_WHILE_PARSER_H_