add tf while pass

use tensor name to find order
pull/10094/head
mengyuanli 4 years ago
parent 28052ad188
commit 89f96e347b

@@ -20,6 +20,7 @@
#include <map>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/ops/assert_op.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/space_to_batch_nd.h"
#include "src/ops/conv2d.h"
@@ -614,6 +615,13 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Sqrt>(prim, inputs, quantType);
} else if (op_type == "Greater") {
return NewPrimitiveC<Greater>(prim, inputs, quantType);
} else if (op_type == "Switch") {
return NewPrimitiveC<Switch>(prim, inputs, quantType);
} else if (op_type == "Partial") {
return NewPrimitiveC<Partial>(prim, inputs, quantType);
} else if (op_type == "Merge") {
return NewPrimitiveC<Merge>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
@@ -955,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Merge(primitive);
case schema::PrimitiveType_Partial:
return new (std::nothrow) Partial(primitive);
case schema::PrimitiveType_Assert:
return new (std::nothrow) AssertOP(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);
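
The two hunks above register the new control-flow primitives (Switch, Partial, Merge, Assert) in both factory paths. As a hedged, standalone sketch of the dispatch pattern those if/else-if and switch chains implement (reduced stand-in types and illustrative names only, not the converter's real classes):

#include <functional>
#include <map>
#include <memory>
#include <string>

// Stand-ins for the real primitive classes.
struct PrimitiveC { virtual ~PrimitiveC() = default; };
struct Switch : PrimitiveC {};
struct Merge : PrimitiveC {};

// Table-driven equivalent of the name-based if/else-if chain.
std::shared_ptr<PrimitiveC> CreateByName(const std::string &op_type) {
  static const std::map<std::string, std::function<std::shared_ptr<PrimitiveC>()>> kFactories = {
      {"Switch", [] { return std::make_shared<Switch>(); }},
      {"Merge", [] { return std::make_shared<Merge>(); }},
  };
  auto it = kFactories.find(op_type);
  return it == kFactories.end() ? nullptr : it->second();
}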

@@ -156,7 +156,8 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Transpose);
MS_ASSERT(desc.type == schema::PrimitiveType_Transpose || desc.type == schema::PrimitiveType_Nchw2Nhwc ||
desc.type == schema::PrimitiveType_Nhwc2Nchw);
if (opParameter == nullptr) {
MS_LOG(ERROR) << "desc type is not Transpose";
return nullptr;

@@ -200,6 +200,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
${LITE_DIR}/tools/optimizer/graph/while_pass.cc
)
endif()
### train

@@ -7,7 +7,7 @@ rcnn-ilsvrc13-9.onnx
mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx
densenet-9.onnx
#densenet-9.onnx
ml_table_detection_fp32.onnx
ml_table_segment.onnx
googlenet-9.onnx
@@ -27,7 +27,7 @@ psenet_lite_mbv2.onnx;1,32,32,3
super-resolution-10.onnx;1,224,224,1
tinyyolov2-8.onnx;1,416,416,3
ml_2012_ocr_cn.onnx
ml_2012_ocr_cn_noLSTM.onnx
#ml_2012_ocr_cn_noLSTM.onnx
candy-9.onnx
mosaic-9.onnx
pointilism-9.onnx

@@ -7,7 +7,7 @@ emotion-ferplus-8.onnx 1
mobilenetv2-7.onnx 8
shufflenet-v2-10.onnx 5
squeezenet1.1-7.onnx 1
densenet-9.onnx 6
#densenet-9.onnx 6
ml_table_detection_fp32.onnx 2
ml_table_segment.onnx 2
googlenet-9.onnx 3
@@ -27,7 +27,7 @@ mnist-8.onnx 10
#super-resolution-10.onnx 1
#tinyyolov2-8.onnx 0.3
ml_2012_ocr_cn.onnx 200
ml_2012_ocr_cn_noLSTM.onnx 1
#ml_2012_ocr_cn_noLSTM.onnx 1
candy-9.onnx 5
mosaic-9.onnx 4
pointilism-9.onnx 3

File diff suppressed because it is too large

@@ -27,6 +27,10 @@
#include "tools/converter/converter_context.h"
namespace mindspore::lite {
constexpr const int kPartialMinSize = 3;
constexpr const int kMainGraphIndex = 0;
class AnfExporter {
public:
AnfExporter() = default;
@@ -45,17 +49,28 @@ class AnfExporter {
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node);
int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT, schema::CNodeT *return_node);
static bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
static bool HasPrimitiveCNode(const AnfNodePtr &node);
static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveC> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);
int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode = nullptr);
ValueNodePtr GetPartialAnfPrim();
CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode);
std::vector<schema::CNodeT *> GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index);
private:
std::map<std::string, int> node_id_map_;
std::vector<schema::CNodeT *> graph_input_nodes_;
std::map<FuncGraphPtr, int> fg_subgraph_map;
uint32_t node_idx = 0;
};
// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
// but in PostQuantization, the func_graph needs to be transferred to MetaGraph first to run the MetaGraph passes, which may modify
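
The reworked AnfExporter interface above (per-subgraph input/output index setters, ExportSubgraph, CreatePartialCnode, fg_subgraph_map) suggests each FuncGraph is exported as its own schema::SubGraphT exactly once and then referenced by index. A minimal sketch of that caching flow, assuming heavily simplified stand-in types (none of these definitions are the real exporter):

#include <map>
#include <memory>
#include <vector>

struct FuncGraph;  // stand-in for the pointee of FuncGraphPtr
struct SubGraphT { std::vector<int> node_indices; };
struct MetaGraphT { std::vector<std::unique_ptr<SubGraphT>> sub_graphs; };

struct Exporter {
  std::map<const FuncGraph *, int> fg_subgraph_map;  // mirrors the new member

  int ExportSubgraph(const FuncGraph *fg, MetaGraphT *meta) {
    auto it = fg_subgraph_map.find(fg);
    if (it != fg_subgraph_map.end()) {
      return it->second;  // already exported: reuse the cached subgraph index
    }
    int index = static_cast<int>(meta->sub_graphs.size());
    meta->sub_graphs.emplace_back(new SubGraphT());
    fg_subgraph_map[fg] = index;
    // ... export fg's nodes into meta->sub_graphs[index] here ...
    return index;
  }
};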

@@ -272,18 +272,40 @@ STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTe
continue;
}
}
// update graph input indexes
// update graph input indices
for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update graph output indexes
// update graph output indices
for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
for (auto &subgraph : graphT->subGraph) {
// update subgraph input indices
for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
if (*gInIdx > deleteIdx) {
(*gInIdx)--;
}
}
// update subgraph output indices
for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
if (*gOutIdx > deleteIdx) {
(*gOutIdx)--;
}
}
    // update subgraph tensor indices
for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
if (*idx > deleteIdx) {
(*idx)--;
}
}
}
// update nodes indexes
for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
// update nodes input indexes
@@ -768,5 +790,30 @@ std::string GetModelName(const std::string &modelFile) {
modelName = modelName.substr(0, modelName.find_last_of('.'));
return modelName;
}
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
for (auto &subgraph : meta_graphT->subGraph) {
std::vector<uint32_t> subgraph_indices{};
for (auto &node_idx : subgraph->nodeIndices) {
auto &node = meta_graphT->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
if (IsContain(subgraph_indices, input_idx)) {
continue;
} else {
subgraph_indices.push_back(input_idx);
}
}
for (auto &output_idx : node->outputIndex) {
if (IsContain(subgraph_indices, output_idx)) {
continue;
} else {
subgraph_indices.push_back(output_idx);
}
}
}
subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
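
For context on the decrement loops above: RemoveTensor and the new subgraph branches all preserve a single invariant, namely that once tensor k is erased from the pool, every stored index greater than k must shift down by one. A standalone sketch of that invariant (hypothetical helper, not converter code):

#include <cstdint>
#include <vector>

void ShiftIndicesAfterErase(std::vector<uint32_t> *indices, uint32_t erased_idx) {
  for (auto &idx : *indices) {
    if (idx > erased_idx) {
      --idx;  // everything past the erased slot moves down by one
    }
  }
}

SetSubgraphTensorIndices then rebuilds each subgraph's tensorIndices from scratch by unioning the input and output indices of its nodes, which is why it can simply assign the collected vector at the end.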

@@ -92,6 +92,8 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node);
STATUS SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT);
std::string GetModelName(const std::string &modelFile);
} // namespace lite
} // namespace mindspore

@@ -59,6 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/slice_prepose_pass.cc
../optimizer/graph/mindir_adjust_pass.cc
../optimizer/graph/onnx_inputs_adjust_pass.cc
../optimizer/graph/while_pass.cc
)
add_subdirectory(../anf_importer anf_importer)

@@ -42,6 +42,7 @@
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/while_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -52,18 +53,21 @@ AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default;
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const converter::Flags *config) {
FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
MS_ASSERT(nullptr != old_graph);
if (config == nullptr) {
MS_LOG(ERROR) << "config shoud be specified";
MS_LOG(ERROR) << "config should be specified";
return nullptr;
}
if (old_graph->has_flag("HasTransformed")) {
old_graph->set_flag("HasTransformed", false);
return old_graph;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
// mindir pre adjustment
if (config->fmk == converter::FmkType_MS) {
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
@@ -85,7 +89,12 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
}
}
// for now - trainning is not supporting fuse operations
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) {
auto while_pass = std::make_shared<opt::WhilePass>();
graph_pm->AddPass(while_pass);
}
// for now - training is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
@@ -191,7 +200,46 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
}
return new_graph;
}
STATUS AnfTransform::GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs,
std::vector<ValueNodePtr> *vnodes) {
auto nodes = TopoSort(main_graph->get_return());
for (auto &node : nodes) {
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg) {
vnodes->push_back(utils::cast<ValueNodePtr>(node));
subgraphs->push_back(fg);
}
}
return RET_OK;
}
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
// transform main_graph
auto new_main_graph = TransformSingleFuncGraph(main_graph, config);
if (new_main_graph == nullptr) {
MS_LOG(ERROR) << "TransformSingleFuncGraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
// transform sub_graph
FuncGraphPtrList subgraphs{};
std::vector<ValueNodePtr> vnodes{};
int ret = GetAllFuncGraph(main_graph, &subgraphs, &vnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetAllFuncGraph failed " << ret;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
for (size_t i = 0; i < subgraphs.size(); i++) {
auto new_graph = Transform(subgraphs.at(i), config);
new_graph->set_flag("HasTransformed", true);
vnodes.at(i)->set_value(new_graph);
}
return new_main_graph;
}
} // namespace mindspore::lite
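
The recursion above stays finite because TransformSingleFuncGraph consumes the "HasTransformed" flag that Transform sets on each subgraph before rewriting its ValueNode. A minimal sketch of that guard, assuming a simplified stand-in Graph type (illustrative only):

#include <map>
#include <string>

struct Graph {
  std::map<std::string, bool> flags;
  bool has_flag(const std::string &key) const {
    auto it = flags.find(key);
    return it != flags.end() && it->second;
  }
  void set_flag(const std::string &key, bool value) { flags[key] = value; }
};

Graph *TransformOnce(Graph *g) {
  if (g->has_flag("HasTransformed")) {
    g->set_flag("HasTransformed", false);  // consume the marker
    return g;                              // skip the pass pipeline on revisits
  }
  // ... run the pass pipeline on g ...
  return g;
}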

@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_ANF_TRANSFORM_H
#include <memory>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
@@ -34,6 +35,9 @@ class AnfTransform {
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
private:
STATUS GetAllFuncGraph(const FuncGraphPtr &main_graph, FuncGraphPtrList *subgraphs,
std::vector<ValueNodePtr> *vnodes);
FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
};
} // namespace lite

@@ -67,6 +67,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
int status = modelImporter->Import(flag);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
graph = modelImporter->GetResult();
graph->set_attr("graph_name", MakeValue("main_graph"));
} else {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
@@ -90,6 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
}
// transform
transform->SetGraphDef(meta_graph);
auto status = transform->Transform(*flag);

File diff suppressed because it is too large

@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_GRAPHDEF_TRANSFORM_H
#include <memory>
#include <vector>
#include "tools/converter/optimizer.h"
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
@@ -39,6 +40,7 @@ class GraphDefTransform {
inline schema::MetaGraphT *GetOutput() { return graphDefT; }
protected:
std::vector<schema::CNodeT *> GetGraphNodes();
schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr;
};

@@ -15,6 +15,9 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
)
set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(graph_pass_mid OBJECT ${GRAPH_PASS})

@@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
for (auto &subgraph : graph->subGraph) {
for (auto &idx : subgraph->nodeIndices) {
if (idx > node_idx) {
idx--;
}
}
}
}
STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
std::vector<schema::CNodeT *> new_nodes{};
  std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
                 [](const std::unique_ptr<schema::CNodeT> &node) { return node.get(); });
for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
if (!IsContain(new_nodes, *it)) {
size_t node_idx = it - old_nodes_.begin();
for (auto &subgraph : graph->subGraph) {
auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
if (node_idx_pos != subgraph->nodeIndices.end()) {
subgraph->nodeIndices.erase(node_idx_pos);
UpdateSubgraphNodeIndices(node_idx, graph);
break;
}
}
it = old_nodes_.erase(it);
} else {
it++;
}
}
  // adopt nodes newly created by other passes: attach each new node to any
  // subgraph that already contains one of its neighbouring node indices
  for (uint32_t i = 0; i < new_nodes.size(); i++) {
    if (!IsContain(old_nodes_, new_nodes[i])) {
      for (auto &subgraph : graph->subGraph) {
        if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) {
subgraph->nodeIndices.push_back(old_nodes_.size());
old_nodes_.push_back(new_nodes[i]);
}
}
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
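
A worked example of the bookkeeping Run() performs when an earlier pass deletes a node, written as a hypothetical standalone program rather than converter code: with old nodes {A, B, C} and a subgraph holding indices {0, 1, 2}, removing B erases index 1 and shifts index 2 down to 1.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<size_t> node_indices = {0, 1, 2};
  const size_t deleted = 1;  // node B disappeared from graph->nodes
  node_indices.erase(node_indices.begin() + deleted);
  for (auto &idx : node_indices) {
    if (idx > deleted) {
      --idx;  // same shift UpdateSubgraphNodeIndices applies
    }
  }
  assert((node_indices == std::vector<size_t>{0, 1}));
  return 0;
}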

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
#include <vector>
#include <utility>
#include "tools/converter/optimizer.h"
namespace mindspore {
namespace lite {
class SubgraphNodePass : public GraphPass {
public:
explicit SubgraphNodePass(std::vector<schema::CNodeT *> old_nodes) : old_nodes_(std::move(old_nodes)) {}
~SubgraphNodePass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
std::vector<schema::CNodeT *> old_nodes_;
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H

@@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include <memory>
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace lite {
bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
for (const auto &node : graph->nodes) {
if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) {
return true;
}
if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) {
return true;
}
}
return false;
}
STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
for (const auto &subgraph : graph->subGraph) {
UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx);
UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx);
}
for (const auto &node : graph->nodes) {
UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx);
UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx);
}
UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx);
UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx);
return RET_OK;
}
STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) {
uint32_t idx = it - graph->allTensors.begin();
if (IsUsing(graph, idx)) {
it++;
} else {
it = graph->allTensors.erase(it);
UpdateTensorIdx(graph, idx);
}
}
return RET_OK;
}
STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
MS_ASSERT(graph->subGraph.size() > 0);
graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end());
return RET_OK;
}
STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
int ret = RemoveUselessTensors(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret;
return ret;
}
ret = SetSubgraphTensorIndices(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
return ret;
}
ret = SyncMainGraphInputAndOutput(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
return ret;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
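
One idiom worth calling out in RemoveUselessTensors above: the iterator is only advanced when nothing was erased, because vector::erase already returns the iterator past the removed element; advancing after an erase would skip a tensor. A minimal sketch of the idiom with a hypothetical predicate:

#include <vector>

void EraseEvens(std::vector<int> *values) {
  for (auto it = values->begin(); it != values->end();) {
    if (*it % 2 == 0) {
      it = values->erase(it);  // erase yields the next element's iterator
    } else {
      ++it;
    }
  }
}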

@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
#include <vector>
#include <utility>
#include "tools/converter/optimizer.h"
namespace mindspore {
namespace lite {
class SubgraphTensorPass : public GraphPass {
public:
SubgraphTensorPass() = default;
~SubgraphTensorPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
STATUS RemoveUselessTensors(schema::MetaGraphT *graph);
bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph);
template <typename T>
void UpdateVec(std::vector<T> *vec, T element) {
for (auto iter = vec->begin(); iter != vec->end(); iter++) {
if (*iter > element) {
(*iter)--;
}
}
}
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H

@@ -45,11 +45,9 @@ class SingleSwitchPass {
STATUS Init();
size_t InitThisGraphIndex();
STATUS DoubleSwitchOutput();
STATUS MoveMaxIterationToCond();
STATUS UpdateSwitchUser();
STATUS ConcatCondSubgraphInputAndOutput();
STATUS ConcatBodySubgraphInputAndOutput();
STATUS ConvertSwitchToSelect();
bool IsLoop();
STATUS InsertMerge();
STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node,

@@ -27,56 +27,71 @@ namespace mindspore {
namespace lite {
STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
std::vector<std::unique_ptr<schema::CNodeT>> newNodes;
std::vector<size_t> sinkedTensorIdxes;
// put all const tensor index into sinkedTensorIdxes
std::vector<std::unique_ptr<schema::CNodeT>> new_nodes;
std::vector<size_t> sinked_tensor_idxes;
// put all const tensor index into sinked_tensor_idxes
for (size_t i = 0; i < graph->allTensors.size(); i++) {
if (graph->allTensors.at(i)->nodeType == schema::NodeType::NodeType_ValueNode) {
sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), i);
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), i);
}
}
auto &oldNodes = graph->nodes;
std::queue<std::unique_ptr<schema::CNodeT>> opQueue;
// put all non depend node into queue
for (auto &node : graph->nodes) {
if (IsNodeNonDepend(node, sinkedTensorIdxes)) {
sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), node->outputIndex.begin(), node->outputIndex.end());
opQueue.push(std::move(node));
auto &old_nodes = graph->nodes;
std::queue<std::unique_ptr<schema::CNodeT>> op_queue;
  // put all non-depend nodes into the queue
for (size_t i = 0; i < graph->subGraph.size(); i++) {
std::vector<unsigned int> new_subgraph_node_indices = {};
auto subgraph_node_indices = graph->subGraph[i]->nodeIndices;
for (size_t j = 0; j < subgraph_node_indices.size(); j++) {
auto &node = old_nodes[subgraph_node_indices[j]];
if (IsNodeNonDepend(node, sinked_tensor_idxes)) {
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), node->outputIndex.begin(), node->outputIndex.end());
op_queue.push(std::move(node));
}
}
}
// bfs
while (!opQueue.empty()) {
auto &node = opQueue.front();
auto postNodeIdxes = GetOutputNodeIdx(*graph, *(node.get()));
for (auto postNodeIdx : postNodeIdxes) {
auto &postNode = oldNodes.at(postNodeIdx);
// check if postNode is non-depended
if (IsNodeNonDepend(postNode, sinkedTensorIdxes)) {
sinkedTensorIdxes.insert(sinkedTensorIdxes.end(), postNode->outputIndex.begin(), postNode->outputIndex.end());
opQueue.push(std::move(postNode));
while (!op_queue.empty()) {
auto &node = op_queue.front();
auto post_node_idxes = GetOutputNodeIdx(*graph, *(node.get()));
for (auto post_node_idx : post_node_idxes) {
if (IsContain(subgraph_node_indices, (unsigned int)(post_node_idx))) {
auto &post_node = old_nodes.at(post_node_idx);
// check if post_node is non-depended
if (IsNodeNonDepend(post_node, sinked_tensor_idxes)) {
sinked_tensor_idxes.insert(sinked_tensor_idxes.end(), post_node->outputIndex.begin(),
post_node->outputIndex.end());
op_queue.push(std::move(post_node));
}
}
}
new_nodes.emplace_back(std::move(node));
new_subgraph_node_indices.push_back(new_nodes.size() - 1);
op_queue.pop();
}
newNodes.emplace_back(std::move(node));
opQueue.pop();
graph->subGraph[i]->nodeIndices.swap(new_subgraph_node_indices);
}
if (newNodes.size() != oldNodes.size()) {
MS_LOG(ERROR) << "Unknow error in TopologicalSort, oldNodesSize: " << oldNodes.size()
<< ", newNodesSize: " << newNodes.size();
if (new_nodes.size() != old_nodes.size()) {
MS_LOG(ERROR) << "Unknow error in TopologicalSort, old_nodes size: " << old_nodes.size()
<< ", new_nodes size: " << new_nodes.size();
return RET_ERROR;
}
graph->nodes.swap(newNodes);
graph->nodes.swap(new_nodes);
return RET_OK;
}
bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
const std::vector<size_t> &sinkedTensorIdxes) {
const std::vector<size_t> &sinked_tensor_idxes) {
MS_ASSERT(node != nullptr);
for (auto inputIdx : node->inputIndex) {
if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) {
return false;
}
if (node->primitive->value.type == schema::PrimitiveType_Merge) {
auto node_input_index = node->inputIndex;
MS_ASSERT(node_input_index.size() % 2 == 0);
return std::all_of(node_input_index.begin(), node_input_index.begin() + node_input_index.size() / 2,
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); }) ||
std::all_of(node_input_index.begin() + node_input_index.size() / 2, node_input_index.end(),
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, input_idx); });
} else {
return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
[&](size_t input_idx) { return IsContain(sinked_tensor_idxes, size_t(input_idx)); });
}
return true;
}
} // namespace lite
} // namespace mindspore
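
The Merge special case above is what lets while-loop graphs schedule at all: a Merge node with 2N inputs is considered ready when either its first N inputs (the feed-in path) or its last N inputs (the loop-back path) have all been produced, so the cycle through the loop body does not deadlock the sort. A hedged standalone sketch of that readiness rule (illustrative helper, not the pass's code):

#include <algorithm>
#include <cstddef>
#include <vector>

bool MergeIsReady(const std::vector<size_t> &inputs, const std::vector<size_t> &ready_tensors) {
  auto is_ready = [&](size_t idx) {
    return std::find(ready_tensors.begin(), ready_tensors.end(), idx) != ready_tensors.end();
  };
  const size_t half = inputs.size() / 2;  // inputs come in matched halves
  return std::all_of(inputs.begin(), inputs.begin() + half, is_ready) ||
         std::all_of(inputs.begin() + half, inputs.end(), is_ready);
}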

@@ -54,6 +54,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_ptr_;
}

@@ -80,6 +80,7 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
MS_LOG(ERROR) << "convert graph outputs failed.";
return nullptr;
}
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
return func_graph_ptr_;
}

Some files were not shown because too many files have changed in this diff
