|
|
|
@ -151,7 +151,6 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (data_type == kNumberTypeInt64) {
|
|
|
|
|
MS_LOG(ERROR) << "INT64" << proto.name();
|
|
|
|
|
tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32
|
|
|
|
|
}
|
|
|
|
|
*index = tensor_cache->AddTensor(name, tensor.release(), type);
|
|
|
|
@ -168,7 +167,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << "input_value name: " << input_value.name() << ", graph input index: " << index;
|
|
|
|
|
MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << index;
|
|
|
|
|
graph->inputIndex.emplace_back(static_cast<uint32_t>(index));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -184,7 +183,7 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
graph->outputIndex.emplace_back(index);
|
|
|
|
|
MS_LOG(ERROR) << "output_value name: " << output_value.name() << ", graph output index: " << index;
|
|
|
|
|
MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -399,10 +398,9 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
|
|
|
|
|
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
|
|
|
|
|
size_t data_count = 1;
|
|
|
|
|
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
|
|
|
|
|
MS_LOG(ERROR) << "const tensor dims " << tensor->dims.size();
|
|
|
|
|
size_t data_size = 0;
|
|
|
|
|
const void *tensor_data = nullptr;
|
|
|
|
|
int32_t *buffer = nullptr;
|
|
|
|
|
std::unique_ptr<int32_t[]> buffer;
|
|
|
|
|
switch (tensor->dataType) {
|
|
|
|
|
case kNumberTypeFloat32:
|
|
|
|
|
data_size = data_count * sizeof(float);
|
|
|
|
@ -422,7 +420,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeInt64:
|
|
|
|
|
data_size = data_count * sizeof(int32_t);
|
|
|
|
|
buffer = new int32_t[data_count];
|
|
|
|
|
buffer = std::make_unique<int32_t[]>(data_count);
|
|
|
|
|
const int64_t *in_data;
|
|
|
|
|
if (onnx_const_value.int64_data_size() == 0) {
|
|
|
|
|
in_data = reinterpret_cast<const int64_t *>(onnx_const_value.raw_data().data());
|
|
|
|
@ -437,7 +435,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|
|
|
|
buffer[i] = static_cast<int>(in_data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tensor_data = reinterpret_cast<void *>(buffer);
|
|
|
|
|
tensor_data = reinterpret_cast<void *>(buffer.get());
|
|
|
|
|
break;
|
|
|
|
|
case kNumberTypeUInt8:
|
|
|
|
|
case kNumberTypeInt8:
|
|
|
|
@ -453,9 +451,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (kNumberTypeInt64 == tensor->dataType) {
|
|
|
|
|
free(buffer);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|