|
|
|
@ -44,7 +44,6 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|
|
|
|
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
|
|
|
|
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
|
|
|
|
|
|
|
|
|
std::set<std::string> SPECIAL_NODE = {"Gemm"};
|
|
|
|
|
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
|
|
|
|
const QuantType &quant_type) {
|
|
|
|
|
NoSupportOp::GetInstance()->SetFmkType("ONNX");
|
|
|
|
@ -215,11 +214,6 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
|
|
|
|
|
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsSpecialOnnxNode(onnx_node)) {
|
|
|
|
|
auto status_node = ConvertSpecialOnnxNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
|
|
|
|
|
status = status == RET_OK ? status_node : status;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// build CNode
|
|
|
|
|
status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
@ -1023,117 +1017,6 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
|
|
|
|
|
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
|
|
|
|
ops::PrimitiveC *primitive_c) {
|
|
|
|
|
if (primitive_c == nullptr || anf_graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "imitive_c is nullptr.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
if (onnx_node.op_type() == "Gemm") {
|
|
|
|
|
status = ConvertOnnxGemmNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "the node is not special node.";
|
|
|
|
|
status = RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
delete primitive_c;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
|
|
|
|
|
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
|
|
|
|
ops::PrimitiveC *primitive_c) {
|
|
|
|
|
if (primitive_c == nullptr || anf_graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "parameter has nullptr.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
if (onnx_node.op_type() != "Gemm") {
|
|
|
|
|
MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_c is nullptr.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "MatMul");
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "convert gemm node failed.";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "BiasAdd");
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "convert gemm node failed.";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
|
|
|
|
|
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
|
|
|
|
ops::PrimitiveC *primitive_c, const std::string &name) {
|
|
|
|
|
if (primitive_c == nullptr || anf_graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "parameter has nullptr.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto value = primitive_c->GetAttr(name);
|
|
|
|
|
primitive_c->EraseAttr(name);
|
|
|
|
|
if (value == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "op parse failed.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto prim_ptr = value->cast<std::shared_ptr<ops::PrimitiveC>>();
|
|
|
|
|
if (prim_ptr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive parse failed.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto type_ptr = TypeIdToType(kTypeUnknown);
|
|
|
|
|
std::vector<int64_t> shape_vector;
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs;
|
|
|
|
|
auto quant_params_holder = std::make_shared<QuantParamHolder>();
|
|
|
|
|
auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
|
|
|
|
|
if (name == "MatMul") {
|
|
|
|
|
for (int i = 0; i < 2; ++i) {
|
|
|
|
|
if (anf_nodes_map->find(onnx_node.input(i)) == anf_nodes_map->end()) {
|
|
|
|
|
MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
} else {
|
|
|
|
|
op_inputs.push_back(anf_nodes_map->at(onnx_node.input(i)));
|
|
|
|
|
quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(i));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
|
|
|
|
|
auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
|
|
|
|
|
if (new_cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new cnode error";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0));
|
|
|
|
|
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
|
|
|
|
anf_nodes_map->emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode);
|
|
|
|
|
} else {
|
|
|
|
|
if (anf_nodes_map->find("Gemm_MatMul_" + onnx_node.output(0)) == anf_nodes_map->end() ||
|
|
|
|
|
anf_nodes_map->find(onnx_node.input(2)) == anf_nodes_map->end()) {
|
|
|
|
|
MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
op_inputs.push_back(anf_nodes_map->at("Gemm_MatMul_" + onnx_node.output(0)));
|
|
|
|
|
op_inputs.push_back(anf_nodes_map->at(onnx_node.input(2)));
|
|
|
|
|
quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
|
|
|
|
|
quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(2));
|
|
|
|
|
quant_params_holder->AddOutputQuantParam(quant_params_holder_origin->output_quant_params().front());
|
|
|
|
|
auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
|
|
|
|
|
if (new_cnode == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new cnode error";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0));
|
|
|
|
|
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
|
|
|
|
anf_nodes_map->emplace(onnx_node.output(0), new_cnode);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
|
|
|
|
|
if (data == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "value is nullptr.";
|
|
|
|
@ -1281,10 +1164,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) {
|
|
|
|
|
return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
|
|
|
|
|
auto iter = TYPE_MAP.find(onnx_type);
|
|
|
|
|
if (iter == TYPE_MAP.end()) {
|
|
|
|
|