|
|
|
@ -328,15 +328,15 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RemoveIfMakeTuple(cnode);
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
RemoveIfDepend(cnode);
|
|
|
|
|
#endif
|
|
|
|
|
if (train_flag) {
|
|
|
|
|
RemoveIfDepend(cnode);
|
|
|
|
|
if (primitive_c->Type() == schema::PrimitiveType_Depend ||
|
|
|
|
|
primitive_c->Type() == schema::PrimitiveType_ControlDepend) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
(primitive_c->Type() == schema::PrimitiveType_Depend) ||
|
|
|
|
|
(primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
|
|
|
|
|
#endif
|
|
|
|
|
(primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -424,8 +424,10 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
|
|
|
|
|
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
|
|
|
|
|
bool train_flag) {
|
|
|
|
|
static int subgraph_index = 0;
|
|
|
|
|
this->train_flag = train_flag;
|
|
|
|
|
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
|
|
|
|
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
@ -439,24 +441,18 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
|
|
|
|
|
std::string input_name = input_anode->fullname_with_scope();
|
|
|
|
|
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
|
|
|
|
if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
|
|
|
|
|
#ifndef SUPPORT_TRAIN
|
|
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
bool found = false;
|
|
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
|
|
|
|
found = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (found == false) {
|
|
|
|
|
if (!found) {
|
|
|
|
|
auto input_index_key = input_name + "_o:" + std::to_string(0);
|
|
|
|
|
if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
|
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
auto inputs = input_cnode->inputs();
|
|
|
|
|
|
|
|
|
@ -481,17 +477,12 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
|
|
|
|
|
: GetValue<int>(value_node->value()));
|
|
|
|
|
auto iter = node_id_map_.find(input_index_key);
|
|
|
|
|
if (iter == node_id_map_.end()) {
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0
|
|
|
|
|
iter = node_id_map_.find(input_index_key);
|
|
|
|
|
if (iter == node_id_map_.end()) {
|
|
|
|
|
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
output_cnode->inputIndex.emplace_back(iter->second);
|
|
|
|
|
}
|
|
|
|
@ -571,9 +562,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc
|
|
|
|
|
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
|
|
|
|
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
|
|
|
|
(*paramTensor)->dims = dims;
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1};
|
|
|
|
|
#endif
|
|
|
|
|
if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
|
|
|
|
|
(*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode;
|
|
|
|
|
auto data = value->cast<tensor::TensorPtr>();
|
|
|
|
|
(*paramTensor)->data.resize(data->Size());
|
|
|
|
@ -679,11 +668,11 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu
|
|
|
|
|
(*paramTensor)->format = schema::Format(valueLite->format());
|
|
|
|
|
(*paramTensor)->dataType = valueLite->tensor_type();
|
|
|
|
|
(*paramTensor)->dims = valueLite->tensor_shape();
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
if ((*paramTensor)->dims.size() == 0) {
|
|
|
|
|
|
|
|
|
|
if (train_flag && (*paramTensor)->dims.empty()) {
|
|
|
|
|
(*paramTensor)->dims = {1};
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(),
|
|
|
|
|
valueLite->tensor_size());
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
@ -703,9 +692,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>();
|
|
|
|
|
auto value = valueNode->value();
|
|
|
|
|
int ret = RET_OK;
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
paramTensor->name = valueNode->fullname_with_scope();
|
|
|
|
|
#endif
|
|
|
|
|
if (train_flag) {
|
|
|
|
|
paramTensor->name = valueNode->fullname_with_scope();
|
|
|
|
|
}
|
|
|
|
|
if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
ret = ProcessTensor(valueNode, ¶mTensor, value, output_cnode, meta_graphT);
|
|
|
|
|
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
|
|
|
|
@ -797,44 +786,44 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|
|
|
|
}
|
|
|
|
|
msTensor->nodeType = schema::NodeType_CNode;
|
|
|
|
|
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i);
|
|
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size();
|
|
|
|
|
meta_graphT->allTensors.emplace_back(msTensor);
|
|
|
|
|
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
|
|
|
|
|
break;
|
|
|
|
|
#else
|
|
|
|
|
if (elements.size() == 1) {
|
|
|
|
|
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
|
|
|
|
|
msTensor->name = cnode_name;
|
|
|
|
|
} else {
|
|
|
|
|
if (train_flag) {
|
|
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i);
|
|
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size();
|
|
|
|
|
msTensor->name = name;
|
|
|
|
|
}
|
|
|
|
|
meta_graphT->allTensors.emplace_back(msTensor);
|
|
|
|
|
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
if (elements.size() == 1) {
|
|
|
|
|
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
|
|
|
|
|
msTensor->name = cnode_name;
|
|
|
|
|
} else {
|
|
|
|
|
std::string name = cnode_name + "_o:" + std::to_string(i);
|
|
|
|
|
node_id_map_[name] = meta_graphT->allTensors.size();
|
|
|
|
|
msTensor->name = name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
|
|
|
|
|
MS_LOG(ERROR) << "abstract is not AbstractTensor";
|
|
|
|
|
delete (msTensor);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto type = kNumberTypeFloat32;
|
|
|
|
|
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
|
|
|
|
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
|
|
|
|
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
|
|
|
|
type = typePtr->type_id();
|
|
|
|
|
}
|
|
|
|
|
msTensor->dataType = type;
|
|
|
|
|
meta_graphT->allTensors.emplace_back(msTensor);
|
|
|
|
|
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
|
|
|
|
|
break;
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
|
|
|
|
|
MS_LOG(ERROR) << "abstract is not AbstractTensor";
|
|
|
|
|
delete (msTensor);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto type = kNumberTypeFloat32;
|
|
|
|
|
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
|
|
|
|
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
|
|
|
|
|
auto typePtr = abstract_tensor->element()->GetTypeTrack();
|
|
|
|
|
type = typePtr->type_id();
|
|
|
|
|
}
|
|
|
|
|
msTensor->dataType = type;
|
|
|
|
|
meta_graphT->allTensors.emplace_back(msTensor);
|
|
|
|
|
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
|
|
|
|
|
IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto ms_tensor = new (std::nothrow) schema::TensorT();
|
|
|
|
@ -927,8 +916,8 @@ CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
|
|
|
|
|
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
|
|
|
|
|
AnfExporter anf_exporter;
|
|
|
|
|
return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
|
|
|
|
|
return anf_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore::lite
|
|
|
|
|