|
|
|
@ -43,7 +43,6 @@ using int64 = int64_t;
|
|
|
|
|
using uint64 = uint64_t;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
|
|
|
|
|
static constexpr char kConstantValueNode[] = "Constant";
|
|
|
|
|
|
|
|
|
|
enum ParseForm : int {
|
|
|
|
@ -212,7 +211,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|
|
|
|
node->set_name(value_proto.name());
|
|
|
|
|
const auto &type_proto = value_proto.type();
|
|
|
|
|
if (!type_proto.has_tensor_type()) {
|
|
|
|
|
MS_LOG(ERROR) << "onnx TypeProto has no tesor_type! ";
|
|
|
|
|
MS_LOG(ERROR) << "onnx TypeProto has no tensor_type! ";
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
const onnx::TypeProto_Tensor &tensor_typeproto = type_proto.tensor_type();
|
|
|
|
@ -248,6 +247,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|
|
|
|
std::string initial_data = initialize_proto.raw_data();
|
|
|
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->MutableData());
|
|
|
|
|
if (tensor_data_buf == nullptr) {
|
|
|
|
|
delete tensor_info;
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
tensor_info->set_data(nullptr);
|
|
|
|
@ -261,6 +261,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
|
|
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>();
|
|
|
|
|
if (param_value == nullptr) {
|
|
|
|
|
delete tensor_info;
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
param_value->set_tensor_addr(tensor_data_buf);
|
|
|
|
@ -367,22 +368,38 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &pr
|
|
|
|
|
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape_vector);
|
|
|
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
|
|
|
|
ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size());
|
|
|
|
|
if (EOK != ret) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
prim->set_attr(attr_name, MakeValue(tensor_info));
|
|
|
|
|
} else {
|
|
|
|
|
if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) {
|
|
|
|
|
size_t data_size = sizeof(double);
|
|
|
|
|
double attr_value = 0.0;
|
|
|
|
|
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size());
|
|
|
|
|
if (EOK != ret) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
prim->set_attr(attr_name, MakeValue<double>(attr_value));
|
|
|
|
|
} else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) {
|
|
|
|
|
size_t data_size = sizeof(int64_t);
|
|
|
|
|
int64_t attr_value = 0;
|
|
|
|
|
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size());
|
|
|
|
|
if (EOK != ret) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
prim->set_attr(attr_name, MakeValue<int64_t>(attr_value));
|
|
|
|
|
} else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) {
|
|
|
|
|
size_t data_size = sizeof(bool);
|
|
|
|
|
bool attr_value = false;
|
|
|
|
|
ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size());
|
|
|
|
|
if (EOK != ret) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
prim->set_attr(attr_name, MakeValue<bool>(attr_value));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -399,7 +416,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, con
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
|
|
|
|
string type;
|
|
|
|
|
string type = "";
|
|
|
|
|
std::size_t pos(0);
|
|
|
|
|
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
|
|
|
|
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
|
|
|
@ -503,7 +520,7 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &value_
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
|
|
|
|
string type;
|
|
|
|
|
string type = "";
|
|
|
|
|
std::size_t pos(0);
|
|
|
|
|
if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
|
|
|
|
|
type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
|
|
|
|
@ -682,9 +699,17 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
|
|
|
|
|
const onnx::ValueInfoProto &output_node = importProto.output(out_size);
|
|
|
|
|
const std::string &out_tuple = output_node.name();
|
|
|
|
|
inputs.push_back(anfnode_build_map_[out_tuple]);
|
|
|
|
|
if (anfnode_build_map_[out_tuple] == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "AnfNode is nullptr";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
elem.push_back(anfnode_build_map_[out_tuple]->abstract());
|
|
|
|
|
}
|
|
|
|
|
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs);
|
|
|
|
|
if (maketuple_ptr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "maketuple_ptr is nullptr";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
|
|
|
|
inputs.clear();
|
|
|
|
|
auto primReturn = std::make_unique<schema::PrimitiveT>();
|
|
|
|
@ -857,6 +882,10 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
|
|
|
|
MS_LOG(ERROR) << "Parse configuration info for pb file failed!";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
if (onnx_model_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "onnx_model_ is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
const onnx::GraphProto &graphBuild = onnx_model_->graph();
|
|
|
|
|
status = BuildFuncGraph(dstGraph, graphBuild, quantType);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
@ -871,6 +900,11 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) {
|
|
|
|
|
|
|
|
|
|
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) {
|
|
|
|
|
auto onnx_model = new (std::nothrow) onnx::ModelProto;
|
|
|
|
|
if (onnx_model == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "New onnx ModelProto failed!";
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (RET_OK != ValidateFileStr(model_path, ".mindir")) {
|
|
|
|
|
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.mindir";
|
|
|
|
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_INPUT_PARAM_INVALID);
|
|
|
|
|