@@ -244,6 +244,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
           MS_LOG(ERROR) << "memcpy_s failed";
           return RET_ERROR;
         }
+        // set quantParams to Int8GivenTensor.
+        std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<schema::QuantParamT>();
+        for (const auto &onnx_node_attr : onnx_node.attribute()) {
+          if (onnx_node_attr.name() == "Y_scale") {
+            quant_param->scale = onnx_node_attr.f();
+          } else if (onnx_node_attr.name() == "Y_zero_point") {
+            quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
+          }
+        }
+        tensor->quantParams.emplace_back(std::move(quant_param));
       } else {
         MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
         return RET_ERROR;
@@ -256,9 +266,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
 }
 
 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                             schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache, const QuantType &quantType,
-                                             schema::MetaGraphT *dst_graph) {
+                                             schema::CNodeT *dst_op, TensorCache *tensor_cache,
+                                             const QuantType &quantType, schema::MetaGraphT *dst_graph) {
   // change op_type() to name(), that is unique
   static bool interrupt = false;
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -267,7 +276,6 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
   MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
                 << onnx_node.input_size();
   // get the real op type
-  SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
   if (onnx_node.op_type() == "Loop") {
     NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
     interrupt = true;
@@ -305,6 +313,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
     MS_LOG(ERROR) << "SetOpInputIndex failed";
     return RET_ERROR;
   }
+  if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) {
+    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+    weight_tensor->format = dst_op->primitive->value.AsConv2D()->format;
+  } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) {
+    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+    weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format;
+  }
   // set op output index
   std::vector<string> node_outputs;
   (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());
@ -314,6 +329,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
MS_LOG(ERROR) << "SetOpOutputIndex failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front());
|
|
|
|
|
if (output_tensor == nullptr) {
|
|
|
|
|
interrupt = true;
|
|
|
|
|
MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << "is nullptr.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -572,9 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
     }
 
     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
-    std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status_node =
-      ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
+    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph);
     if (status_node != RET_OK) {
       status = (status == RET_OK ? status_node : status);
       continue;