|
|
@ -35,12 +35,12 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
|
|
|
|
|
|
|
|
|
|
|
|
inputs.emplace_back(cnode->input(0));
|
|
|
|
inputs.emplace_back(cnode->input(0));
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
AnfNodePtr inputNode = cnode->input(i);
|
|
|
|
AnfNodePtr input_node = cnode->input(i);
|
|
|
|
if (!inputNode->isa<CNode>()) {
|
|
|
|
if (!input_node->isa<CNode>()) {
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto make_tuple_node = utils::cast<CNodePtr>(inputNode);
|
|
|
|
auto make_tuple_node = utils::cast<CNodePtr>(input_node);
|
|
|
|
if (IsPrimitiveCNode(make_tuple_node, schema::PrimitiveType_MakeTuple)) {
|
|
|
|
if (IsPrimitiveCNode(make_tuple_node, schema::PrimitiveType_MakeTuple)) {
|
|
|
|
has_make_tuple = true;
|
|
|
|
has_make_tuple = true;
|
|
|
|
for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
|
|
|
|
for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
|
|
|
@ -62,12 +62,12 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
|
|
|
|
inputs.clear();
|
|
|
|
inputs.clear();
|
|
|
|
inputs.emplace_back(cnode->input(0));
|
|
|
|
inputs.emplace_back(cnode->input(0));
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
AnfNodePtr inputNode = cnode->input(i);
|
|
|
|
AnfNodePtr input_node = cnode->input(i);
|
|
|
|
if (!inputNode->isa<CNode>()) {
|
|
|
|
if (!input_node->isa<CNode>()) {
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto tuple_get_item_node = utils::cast<CNodePtr>(inputNode);
|
|
|
|
auto tuple_get_item_node = utils::cast<CNodePtr>(input_node);
|
|
|
|
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) {
|
|
|
|
if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) {
|
|
|
|
has_tuple_get_item = true;
|
|
|
|
has_tuple_get_item = true;
|
|
|
|
inputs.emplace_back(tuple_get_item_node->input(1));
|
|
|
|
inputs.emplace_back(tuple_get_item_node->input(1));
|
|
|
@ -76,7 +76,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) {
|
|
|
|
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
|
|
|
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ValueNodePtr value_node = utils::cast<ValueNodePtr>(indexNode);
|
|
|
|
auto value_node = utils::cast<ValueNodePtr>(indexNode);
|
|
|
|
map_remove_get_item_[tuple_get_item_node->input(1)->fullname_with_scope()] = GetValue<int>(value_node->value());
|
|
|
|
map_remove_get_item_[tuple_get_item_node->input(1)->fullname_with_scope()] = GetValue<int>(value_node->value());
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
|
inputs.emplace_back(cnode->input(i));
|
|
|
@ -92,15 +92,20 @@ bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &m
|
|
|
|
MS_ASSERT(meta_graphT != nullptr);
|
|
|
|
MS_ASSERT(meta_graphT != nullptr);
|
|
|
|
MS_ASSERT(cnode != nullptr);
|
|
|
|
MS_ASSERT(cnode != nullptr);
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
|
|
auto inputNode = cnode->input(i);
|
|
|
|
auto input_anode = cnode->input(i);
|
|
|
|
if (!inputNode->isa<CNode>()) {
|
|
|
|
if (!input_anode->isa<CNode>()) {
|
|
|
|
MS_LOG(ERROR) << "Node of Return's input is not CNode";
|
|
|
|
MS_LOG(ERROR) << "Node of Return's input is not CNode";
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto inputCNode = utils::cast<CNodePtr>(inputNode);
|
|
|
|
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
|
|
|
std::string inputName = inputNode->fullname_with_scope();
|
|
|
|
std::string input_name = input_anode->fullname_with_scope();
|
|
|
|
auto graphOutput = node_id_map_[inputName];
|
|
|
|
auto iter = node_id_map_.find(input_name);
|
|
|
|
meta_graphT->outputIndex.emplace_back(graphOutput);
|
|
|
|
if (iter == node_id_map_.end()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Could not find output node";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto graph_output = iter->second;
|
|
|
|
|
|
|
|
meta_graphT->outputIndex.emplace_back(graph_output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -198,10 +203,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
map_remove_get_item_.clear();
|
|
|
|
map_remove_get_item_.clear();
|
|
|
|
RemoveIfMakeTuple(cnode);
|
|
|
|
RemoveIfMakeTuple(cnode);
|
|
|
|
if (!RemoveIfTupleGetItem(cnode)) {
|
|
|
|
// if (!RemoveIfTupleGetItem(cnode)) {
|
|
|
|
MS_LOG(ERROR) << "RemoveIfTupleGetItem failed";
|
|
|
|
// MS_LOG(ERROR) << "RemoveIfTupleGetItem failed";
|
|
|
|
return nullptr;
|
|
|
|
// return nullptr;
|
|
|
|
}
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
if (primT->value.type == schema::PrimitiveType_Return) {
|
|
|
|
if (primT->value.type == schema::PrimitiveType_Return) {
|
|
|
|
AddOutPutIfReturn(meta_graphT, cnode);
|
|
|
|
AddOutPutIfReturn(meta_graphT, cnode);
|
|
|
@ -234,33 +239,54 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
|
|
|
|
return meta_graphT.release();
|
|
|
|
return meta_graphT.release();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode) {
|
|
|
|
int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode) {
|
|
|
|
std::string input_name = input_anode->fullname_with_scope();
|
|
|
|
std::string input_name = input_anode->fullname_with_scope();
|
|
|
|
if (!map_remove_get_item_.empty()) {
|
|
|
|
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
|
|
|
for (auto name : map_remove_get_item_) {
|
|
|
|
if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
|
|
|
|
if (name.first == input_name) {
|
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
|
|
|
input_name = input_name + "_o:" + std::to_string(name.second);
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
auto inputs = input_cnode->inputs();
|
|
|
|
|
|
|
|
if (inputs.size() != 3) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << inputs.size();
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto get_item_input_cnode = inputs.at(1);
|
|
|
|
|
|
|
|
auto index_vnode = inputs.at(2);
|
|
|
|
|
|
|
|
if (!utils::isa<ValueNode>(index_vnode)) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto value_node = utils::cast<ValueNodePtr>(index_vnode);
|
|
|
|
|
|
|
|
if (value_node == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "cast to ValueNode failed";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto input_index_key =
|
|
|
|
|
|
|
|
get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(GetValue<int>(value_node->value()));
|
|
|
|
|
|
|
|
auto iter = node_id_map_.find(input_index_key);
|
|
|
|
|
|
|
|
if (iter == node_id_map_.end()) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Can not find get_item output tensor";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
output_cnode->inputIndex.emplace_back(iter->second);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
|
|
|
return RET_OK;
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode, size_t anode_index,
|
|
|
|
int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode,
|
|
|
|
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
schema::CNodeT *output_cnode) {
|
|
|
|
schema::CNodeT *output_cnode) {
|
|
|
|
std::string input_name = input_anode->fullname_with_scope();
|
|
|
|
|
|
|
|
auto paramNode = input_anode->cast<ParameterPtr>();
|
|
|
|
auto paramNode = input_anode->cast<ParameterPtr>();
|
|
|
|
if (paramNode->name().empty()) {
|
|
|
|
std::string input_name = paramNode->fullname_with_scope();
|
|
|
|
paramNode->set_name(input_name + "_i:" + std::to_string(anode_index - 1));
|
|
|
|
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
|
|
|
}
|
|
|
|
|
|
|
|
if (node_id_map_.find(paramNode->name()) != node_id_map_.end()) {
|
|
|
|
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[paramNode->name()]);
|
|
|
|
output_cnode->inputIndex.emplace_back(node_id_map_[paramNode->name()]);
|
|
|
|
return RET_OK;
|
|
|
|
return RET_OK;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>();
|
|
|
|
auto paramTensor = std::make_unique<schema::TensorT>();
|
|
|
|
|
|
|
|
paramTensor->nodeType = schema::NodeType_ValueNode;
|
|
|
|
|
|
|
|
paramTensor->format = schema::Format_NHWC;
|
|
|
|
auto abstractBase = paramNode->abstract();
|
|
|
|
auto abstractBase = paramNode->abstract();
|
|
|
|
if (abstractBase == nullptr) {
|
|
|
|
if (abstractBase == nullptr) {
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name();
|
|
|
@ -274,7 +300,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
|
|
|
|
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
|
|
|
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
|
|
|
MS_ASSERT(typePtr != nullptr);
|
|
|
|
MS_ASSERT(typePtr != nullptr);
|
|
|
|
paramTensor->dataType = typePtr->type_id();
|
|
|
|
paramTensor->dataType = typePtr->type_id();
|
|
|
|
paramTensor->format = schema::Format(abstractTensor->format());
|
|
|
|
|
|
|
|
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
|
|
|
if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) {
|
|
|
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
|
|
|
|
MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name();
|
|
|
|
return RET_ERROR;
|
|
|
|
return RET_ERROR;
|
|
|
@ -282,11 +307,11 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
|
|
|
|
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
|
|
|
paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
|
|
|
auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
|
|
|
|
auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param());
|
|
|
|
if (paramValue != nullptr) {
|
|
|
|
if (paramValue != nullptr) {
|
|
|
|
paramTensor->nodeType = schema::NodeType_ValueNode;
|
|
|
|
|
|
|
|
paramTensor->data.resize(paramValue->tensor_size());
|
|
|
|
paramTensor->data.resize(paramValue->tensor_size());
|
|
|
|
|
|
|
|
paramTensor->format = schema::Format(paramValue->format());
|
|
|
|
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
|
|
|
|
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
node_id_map_[paramNode->fullname_with_scope()] = meta_graphT->allTensors.size();
|
|
|
|
node_id_map_[input_name] = meta_graphT->allTensors.size();
|
|
|
|
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
|
|
|
|
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
|
|
|
|
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
|
|
|
|
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
|
|
|
|
return RET_OK;
|
|
|
|
return RET_OK;
|
|
|
@ -345,9 +370,13 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
if (input_node->isa<CNode>()) {
|
|
|
|
if (input_node->isa<CNode>()) {
|
|
|
|
is_graph_input = false;
|
|
|
|
is_graph_input = false;
|
|
|
|
ConvertInputCNode(input_node, fb_node);
|
|
|
|
auto ret = ConvertInputCNode(input_node, fb_node);
|
|
|
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "ConvertInputCNode failed";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
} else if (input_node->isa<Parameter>()) {
|
|
|
|
} else if (input_node->isa<Parameter>()) {
|
|
|
|
auto ret = ConvertInputParameter(input_node, i, meta_graphT, fb_node);
|
|
|
|
auto ret = ConvertInputParameter(input_node, meta_graphT, fb_node);
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "ConvertInputParameter failed";
|
|
|
|
MS_LOG(ERROR) << "ConvertInputParameter failed";
|
|
|
|
return RET_ERROR;
|
|
|
|
return RET_ERROR;
|
|
|
|