!14209 [MS][LITE]optimize anf export

From: @mengyuanli
Reviewed-by: @zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
pull/14209/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit bca8389e31

@ -34,7 +34,6 @@
#include "src/tensor.h"
#include "src/param_value_lite.h"
#include "src/common/utils.h"
#include "ops/partial.h"
#include "tools/common/graph_util.h"
#include "src/ops/ops_utils.h"
@ -287,7 +286,6 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &m
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT,
schema::CNodeT *return_node) {
MS_ASSERT(nullptr != meta_graphT);
MS_ASSERT(nullptr != return_node);
@ -319,9 +317,15 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap
return RET_OK;
}
bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) {
if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) {
return true;
}
return false;
}
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive,
const std::unique_ptr<schema::SubGraphT> &sub_graphT) {
const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive) {
int ret = RET_OK;
auto cnodes = GetOrderedCNodes(func_graph);
for (const auto &cnode : cnodes) {
@ -334,19 +338,18 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
prim = GetValueNode<std::shared_ptr<Primitive>>(partial_cnode->input(0));
primT = GetPrimitiveT(partial_cnode->input(0));
MS_ASSERT(primT != nullptr);
auto pos = fg_subgraph_map.find(fg);
if (pos != fg_subgraph_map.end()) {
auto pos = fg_subgraph_map_.find(fg);
if (pos != fg_subgraph_map_.end()) {
MS_ASSERT(primT->value.AsPartialFusion() != nullptr);
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map.at(fg);
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map_.at(fg);
} else {
size_t next_subgraph_index = fg_subgraph_map.size() + 1;
fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index});
size_t next_subgraph_index = meta_graphT->subGraph.size();
MS_ASSERT(primT->value.AsPartialFusion() != nullptr);
primT->value.AsPartialFusion()->sub_graph_index = next_subgraph_index;
ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode);
ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ExportSubgraph failed";
break;
return ret;
}
}
} else {
@ -374,7 +377,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
node->name = mindspore::ops::kNameReturn;
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpOutputN failed";
break;
@ -398,26 +401,28 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
break;
}
meta_graphT->nodes.push_back(std::move(node));
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++);
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
}
return ret;
}
int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode) {
int ret = RET_OK;
bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode) {
if (HasExported(func_graph)) {
MS_LOG(INFO) << "Has been exported.";
return RET_OK;
}
meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>());
auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index);
auto subgraph_index = meta_graphT->subGraph.size() - 1;
fg_subgraph_map_[func_graph] = subgraph_index;
auto subgraph_name = func_graph->get_attr("graph_name");
MS_ASSERT(subgraph_name != nullptr);
sub_graphT->name = GetValue<std::string>(subgraph_name);
auto fmk = func_graph->get_attr("fmk");
MS_ASSERT(fmk != nullptr);
meta_graphT->fmkType = GetValue<int>(fmk);
meta_graphT->subGraph.back()->name = GetValue<std::string>(subgraph_name);
ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive, sub_graphT);
int ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Anf2Fb failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return ret;
}
@ -441,11 +446,14 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
bool train_flag) {
static int subgraph_index = 0;
this->train_flag = train_flag;
this->train_flag_ = train_flag;
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
auto fmk = func_graph->get_attr("fmk");
MS_ASSERT(fmk != nullptr);
meta_graphT->fmkType = GetValue<int>(fmk);
int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ExportSubgraph failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
return nullptr;
}
@ -455,7 +463,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) {
MS_ASSERT(input_anode != nullptr && output_cnode != nullptr);
auto input_name = input_anode->fullname_with_scope();
if (this->train_flag) {
if (this->train_flag_) {
bool found = false;
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
@ -618,7 +626,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
[](const int64_t &value) { return static_cast<int32_t>(value); });
(*paramTensor)->dims = dims;
if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
if (train_flag_ && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
(*paramTensor)->nodeType = NodeType_ValueNode;
auto data = value->cast<tensor::TensorPtr>();
(*paramTensor)->data.resize(data->Size());
@ -742,7 +750,7 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu
(*paramTensor)->dataType = valueLite->tensor_type();
(*paramTensor)->dims = valueLite->tensor_shape();
if (train_flag && (*paramTensor)->dims.empty()) {
if (train_flag_ && (*paramTensor)->dims.empty()) {
(*paramTensor)->dims = {1};
}
@ -765,7 +773,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
auto paramTensor = std::make_unique<schema::TensorT>();
auto value = valueNode->value();
int ret = RET_OK;
if (train_flag) {
if (train_flag_) {
paramTensor->name = valueNode->fullname_with_scope();
}
if (value->isa<tensor::Tensor>()) {
@ -861,7 +869,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
msTensor->nodeType = NodeType_CNode;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
if (train_flag) {
if (train_flag_) {
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(msTensor);

@ -30,7 +30,6 @@ using mindspore::ops::PrimitiveC;
namespace mindspore::lite {
constexpr const int kPartialMinSize = 3;
constexpr const int kMainGraphIndex = 0;
class AnfExporter {
@ -74,28 +73,27 @@ class AnfExporter {
const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const std::unique_ptr<schema::SubGraphT> &sub_graphT, schema::CNodeT *return_node);
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);
static int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<mindspore::Primitive> &primitive,
const std::unique_ptr<schema::CNodeT> &dst_node);
int Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive,
const std::unique_ptr<schema::SubGraphT> &sub_graphT);
const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive);
int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
const std::shared_ptr<AnfNode> &partial_anode = nullptr);
bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode = nullptr);
static ValueNodePtr GetPartialAnfPrim();
static CNodePtr CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr cnode);
static std::vector<schema::CNodeT *> GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index);
bool HasExported(const FuncGraphPtr &func_graph);
private:
std::map<std::string, int> node_id_map_;
std::vector<schema::CNodeT *> graph_input_nodes_;
std::map<FuncGraphPtr, int> fg_subgraph_map;
uint32_t node_idx = 0;
bool train_flag = false;
// The first item is FuncGraph which has been exported, the second item is the subgraph index in meta_graph
std::map<FuncGraphPtr, int> fg_subgraph_map_;
uint32_t node_idx_ = 0;
bool train_flag_ = false;
};
// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
// but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify

Loading…
Cancel
Save