|
|
|
@ -34,7 +34,6 @@
|
|
|
|
|
#include "src/tensor.h"
|
|
|
|
|
#include "src/param_value_lite.h"
|
|
|
|
|
#include "src/common/utils.h"
|
|
|
|
|
#include "ops/partial.h"
|
|
|
|
|
#include "tools/common/graph_util.h"
|
|
|
|
|
#include "src/ops/ops_utils.h"
|
|
|
|
|
|
|
|
|
@ -287,7 +286,6 @@ int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &m
|
|
|
|
|
|
|
|
|
|
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
|
|
|
|
|
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
|
const std::unique_ptr<schema::SubGraphT> &sub_graphT,
|
|
|
|
|
schema::CNodeT *return_node) {
|
|
|
|
|
MS_ASSERT(nullptr != meta_graphT);
|
|
|
|
|
MS_ASSERT(nullptr != return_node);
|
|
|
|
@ -319,9 +317,15 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
|
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive,
|
|
|
|
|
const std::unique_ptr<schema::SubGraphT> &sub_graphT) {
|
|
|
|
|
const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive) {
|
|
|
|
|
int ret = RET_OK;
|
|
|
|
|
auto cnodes = GetOrderedCNodes(func_graph);
|
|
|
|
|
for (const auto &cnode : cnodes) {
|
|
|
|
@ -334,19 +338,18 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|
|
|
|
prim = GetValueNode<std::shared_ptr<Primitive>>(partial_cnode->input(0));
|
|
|
|
|
primT = GetPrimitiveT(partial_cnode->input(0));
|
|
|
|
|
MS_ASSERT(primT != nullptr);
|
|
|
|
|
auto pos = fg_subgraph_map.find(fg);
|
|
|
|
|
if (pos != fg_subgraph_map.end()) {
|
|
|
|
|
auto pos = fg_subgraph_map_.find(fg);
|
|
|
|
|
if (pos != fg_subgraph_map_.end()) {
|
|
|
|
|
MS_ASSERT(primT->value.AsPartialFusion() != nullptr);
|
|
|
|
|
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map.at(fg);
|
|
|
|
|
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map_.at(fg);
|
|
|
|
|
} else {
|
|
|
|
|
size_t next_subgraph_index = fg_subgraph_map.size() + 1;
|
|
|
|
|
fg_subgraph_map.insert(std::pair<FuncGraphPtr, int>{fg, next_subgraph_index});
|
|
|
|
|
size_t next_subgraph_index = meta_graphT->subGraph.size();
|
|
|
|
|
MS_ASSERT(primT->value.AsPartialFusion() != nullptr);
|
|
|
|
|
primT->value.AsPartialFusion()->sub_graph_index = next_subgraph_index;
|
|
|
|
|
ret = ExportSubgraph(fg, meta_graphT, next_subgraph_index, keep_graph, copy_primitive, cnode);
|
|
|
|
|
ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, cnode);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ExportSubgraph failed";
|
|
|
|
|
break;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
@ -374,7 +377,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|
|
|
|
}
|
|
|
|
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
|
|
|
|
node->name = mindspore::ops::kNameReturn;
|
|
|
|
|
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, sub_graphT, node.get());
|
|
|
|
|
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get());
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "SetOpOutputN failed";
|
|
|
|
|
break;
|
|
|
|
@ -398,26 +401,28 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
meta_graphT->nodes.push_back(std::move(node));
|
|
|
|
|
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++);
|
|
|
|
|
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
|
const size_t &subgraph_index, bool keep_graph, bool copy_primitive,
|
|
|
|
|
const std::shared_ptr<AnfNode> &partial_anode) {
|
|
|
|
|
int ret = RET_OK;
|
|
|
|
|
bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode) {
|
|
|
|
|
if (HasExported(func_graph)) {
|
|
|
|
|
MS_LOG(INFO) << "Has been exported.";
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
meta_graphT->subGraph.emplace_back(std::make_unique<schema::SubGraphT>());
|
|
|
|
|
auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index);
|
|
|
|
|
auto subgraph_index = meta_graphT->subGraph.size() - 1;
|
|
|
|
|
fg_subgraph_map_[func_graph] = subgraph_index;
|
|
|
|
|
auto subgraph_name = func_graph->get_attr("graph_name");
|
|
|
|
|
MS_ASSERT(subgraph_name != nullptr);
|
|
|
|
|
sub_graphT->name = GetValue<std::string>(subgraph_name);
|
|
|
|
|
auto fmk = func_graph->get_attr("fmk");
|
|
|
|
|
MS_ASSERT(fmk != nullptr);
|
|
|
|
|
meta_graphT->fmkType = GetValue<int>(fmk);
|
|
|
|
|
meta_graphT->subGraph.back()->name = GetValue<std::string>(subgraph_name);
|
|
|
|
|
|
|
|
|
|
ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive, sub_graphT);
|
|
|
|
|
int ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Anf2Fb failed";
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
@ -441,11 +446,14 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
|
|
|
|
|
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
|
|
|
|
|
bool train_flag) {
|
|
|
|
|
static int subgraph_index = 0;
|
|
|
|
|
this->train_flag = train_flag;
|
|
|
|
|
this->train_flag_ = train_flag;
|
|
|
|
|
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
|
|
|
|
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
|
|
|
|
|
auto fmk = func_graph->get_attr("fmk");
|
|
|
|
|
MS_ASSERT(fmk != nullptr);
|
|
|
|
|
meta_graphT->fmkType = GetValue<int>(fmk);
|
|
|
|
|
int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ExportSubgraph failed.";
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
@ -455,7 +463,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
|
|
|
|
|
int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) {
|
|
|
|
|
MS_ASSERT(input_anode != nullptr && output_cnode != nullptr);
|
|
|
|
|
auto input_name = input_anode->fullname_with_scope();
|
|
|
|
|
if (this->train_flag) {
|
|
|
|
|
if (this->train_flag_) {
|
|
|
|
|
bool found = false;
|
|
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
|
|
|
@ -618,7 +626,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc
|
|
|
|
|
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
(*paramTensor)->dims = dims;
|
|
|
|
|
if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
|
|
|
|
|
if (train_flag_ && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
|
|
|
|
|
(*paramTensor)->nodeType = NodeType_ValueNode;
|
|
|
|
|
auto data = value->cast<tensor::TensorPtr>();
|
|
|
|
|
(*paramTensor)->data.resize(data->Size());
|
|
|
|
@ -742,7 +750,7 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu
|
|
|
|
|
(*paramTensor)->dataType = valueLite->tensor_type();
|
|
|
|
|
(*paramTensor)->dims = valueLite->tensor_shape();
|
|
|
|
|
|
|
|
|
|
if (train_flag && (*paramTensor)->dims.empty()) {
|
|
|
|
|
if (train_flag_ && (*paramTensor)->dims.empty()) {
|
|
|
|
|
(*paramTensor)->dims = {1};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -765,7 +773,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>();
|
|
|
|
|
auto value = valueNode->value();
|
|
|
|
|
int ret = RET_OK;
|
|
|
|
|
if (train_flag) {
|
|
|
|
|
if (train_flag_) {
|
|
|
|
|
paramTensor->name = valueNode->fullname_with_scope();
|
|
|
|
|
}
|
|
|
|
|
if (value->isa<tensor::Tensor>()) {
|
|
|
|
@ -861,7 +869,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|
|
|
|
}
|
|
|
|
|
msTensor->nodeType = NodeType_CNode;
|
|
|
|
|
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
|
|
|
|
|
if (train_flag) {
|
|
|
|
|
if (train_flag_) {
|
|
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i);
|
|
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size();
|
|
|
|
|
meta_graphT->allTensors.emplace_back(msTensor);
|
|
|
|
|