|
|
|
@ -203,7 +203,9 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
|
|
|
|
|
|
|
|
|
|
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
|
|
|
|
|
const onnx::ValueInfoProto &value_proto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
if (!value_proto.has_type() || !value_proto.has_name()) {
|
|
|
|
|
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! ";
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
@ -236,12 +238,16 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|
|
|
|
|
|
|
|
|
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) {
|
|
|
|
|
Tensor *tensor_info = new Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_info);
|
|
|
|
|
if (tensor_info == nullptr) {
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
tensor_info->MallocData();
|
|
|
|
|
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()];
|
|
|
|
|
std::string initial_data = initialize_proto.raw_data();
|
|
|
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->MutableData());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
|
|
|
|
if (tensor_data_buf == nullptr) {
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
tensor_info->SetData(nullptr);
|
|
|
|
|
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size());
|
|
|
|
|
if (EOK != ret) {
|
|
|
|
@ -252,7 +258,9 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(param_value);
|
|
|
|
|
if (param_value == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
param_value->set_tensor_addr(tensor_data_buf);
|
|
|
|
|
param_value->set_tensor_size(tensor_info->Size());
|
|
|
|
|
param_value->set_tensor_type(tensor_info->data_type());
|
|
|
|
@ -266,7 +274,9 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|
|
|
|
|
|
|
|
|
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
|
|
|
|
const onnx::GraphProto &importProto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
|
|
|
|
if (outputFuncGraph == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size();
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < importProto.initializer_size(); ++i) {
|
|
|
|
@ -293,7 +303,9 @@ int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &output
|
|
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
|
|
|
|
|
const onnx::TensorProto &attr_tensor) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const int attr_tensor_type = attr_tensor.data_type();
|
|
|
|
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
|
|
|
|
|
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
|
|
|
|
@ -336,7 +348,9 @@ ValuePtr AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const onnx::Tensor
|
|
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
|
|
|
|
|
const onnx::TensorProto &attr_tensor) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const int attr_tensor_type = attr_tensor.data_type();
|
|
|
|
|
const std::string &tensor_buf = attr_tensor.raw_data();
|
|
|
|
|
std::vector<int> shape;
|
|
|
|
@ -371,7 +385,9 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const std::string &attr_name = attr_proto.name();
|
|
|
|
|
if (!attr_proto.has_ref_attr_name()) {
|
|
|
|
|
MS_LOG(ERROR) << "CNode parse attr type has no ref_attr_name";
|
|
|
|
@ -435,7 +451,9 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_value_node);
|
|
|
|
|
if (new_value_node == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
|
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
|
|
|
|
new_value_node->set_abstract(abstract_tensor);
|
|
|
|
@ -539,7 +557,10 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt
|
|
|
|
|
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
|
|
|
|
const onnx::NodeProto &node_proto,
|
|
|
|
|
const schema::QuantType &quantType) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
|
|
|
|
if (outputFuncGraph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "output funcgraph is nullptr";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (!node_proto.has_op_type()) {
|
|
|
|
|
MS_LOG(ERROR) << "Get CNode op_type failed!";
|
|
|
|
|
return nullptr;
|
|
|
|
@ -548,7 +569,10 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|
|
|
|
const std::string &fullname_with_scope = node_proto.domain();
|
|
|
|
|
const std::string &node_type = node_proto.op_type();
|
|
|
|
|
PrimitivePtr prim = std::make_shared<mindspore::Primitive>(node_type);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
if (prim == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new primitive failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
prim->set_instance_name(node_type);
|
|
|
|
|
std::unordered_map<std::string, abstract::AbstractTensorPtr> kv;
|
|
|
|
|
string shape_ref_attr_name;
|
|
|
|
@ -582,7 +606,10 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|
|
|
|
}
|
|
|
|
|
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
|
|
|
|
|
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
|
|
|
|
if (cnode_ptr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "funcgraph new cnode failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (0 == kv.size()) {
|
|
|
|
|
AbstractBasePtrList elem;
|
|
|
|
|
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
|
|
|
@ -604,8 +631,10 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
|
|
|
|
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
|
|
|
|
if (outputFuncGraph == nullptr || cnode_ptr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "output funcgraph or cnode is nullptr";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
if (importProto.output_size() > 1) {
|
|
|
|
|
inputs.clear();
|
|
|
|
@ -633,7 +662,10 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|
|
|
|
inputs.push_back(NewValueNode(primitive_return_value_ptr));
|
|
|
|
|
inputs.push_back(maketuple_ptr);
|
|
|
|
|
auto return_node = outputFuncGraph->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_node);
|
|
|
|
|
if (return_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "funcgraph new cnode failed";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
outputFuncGraph->set_return(return_node);
|
|
|
|
|
MS_LOG(INFO) << "Construct funcgraph finined, all success.";
|
|
|
|
|
} else {
|
|
|
|
@ -656,7 +688,10 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|
|
|
|
inputs.push_back(NewValueNode(primitiveTReturnValuePtr));
|
|
|
|
|
inputs.push_back(cnode_ptr);
|
|
|
|
|
auto return_node = outputFuncGraph->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_node);
|
|
|
|
|
if (return_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "funcgraph new cnode failed";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return_node->set_abstract(abstract_tensor);
|
|
|
|
|
outputFuncGraph->set_return(return_node);
|
|
|
|
|
MS_LOG(INFO) << "Construct funcgraph finined, all success!";
|
|
|
|
@ -667,7 +702,10 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|
|
|
|
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
|
|
|
|
|
const onnx::GraphProto &importProto,
|
|
|
|
|
const schema::QuantType &quantType) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
|
|
|
|
if (outputFuncGraph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "funcgraph is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
|
|
|
|
CNodePtr cnode_ptr = nullptr;
|
|
|
|
|
for (int i = 0; i < importProto.node_size(); ++i) {
|
|
|
|
@ -696,9 +734,15 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
|
|
|
|
|
|
|
|
|
|
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
|
|
|
|
|
const schema::QuantType &quantType) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
|
|
|
|
if (outputFuncGraph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "fundgraph is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(debug_info_ptr);
|
|
|
|
|
if (debug_info_ptr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "funcgraph's debug info is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
if (importProto.has_name()) {
|
|
|
|
|
debug_info_ptr->set_name(importProto.name());
|
|
|
|
|
} else {
|
|
|
|
@ -735,7 +779,10 @@ int AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &mod
|
|
|
|
|
|
|
|
|
|
int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
|
|
|
|
FuncGraphPtr dstGraph = std::make_shared<mindspore::FuncGraph>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dstGraph);
|
|
|
|
|
if (dstGraph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "funcgraph is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
int status = ParseModelConfigureInfo(*onnx_model_);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
|
|
|
|