mslite: fix onnx conv/reshape/dropout/matmul parsers, fix infershape_pass bug

pull/5843/head
xuanyue 4 years ago
parent 3c7b668d63
commit a4aa06db53

@@ -549,6 +549,7 @@ table NetOutput {
 }
 table MatMul {
+    broadcast : bool = false;
     transposeA : bool = false;
     transposeB : bool = false;
 }

@@ -80,7 +80,7 @@ int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
     MS_LOG(ERROR) << "value_as_MatMul return nullptr";
     return RET_ERROR;
   }
-  auto val_offset = schema::CreateMatMul(*fbb, attr->transposeA(), attr->transposeB());
+  auto val_offset = schema::CreateMatMul(*fbb, attr->broadcast(), attr->transposeA(), attr->transposeB());
   auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MatMul, val_offset.o);
   fbb->Finish(prim_offset);
   return RET_OK;
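
Note: FlatBuffers assigns field ids in declaration order, so inserting `broadcast` ahead of `transposeA` renumbers the existing fields and breaks binary compatibility with models serialized against the old schema; the regenerated `CreateMatMul` helper likewise gains a leading parameter, which is why the call above now passes `attr->broadcast()` first. A sketch of the assumed shape of the regenerated helper (parameters follow field declaration order):

// Assumed flatc output after regeneration; not copied from the repo.
flatbuffers::Offset<schema::MatMul> CreateMatMul(flatbuffers::FlatBufferBuilder &fbb,
                                                 bool broadcast = false,
                                                 bool transposeA = false,
                                                 bool transposeB = false);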

@@ -122,6 +122,7 @@
 #include "src/ops/l2_norm.h"
 #include "src/ops/sparse_to_dense.h"
 #include "src/ops/detection_post_process.h"
+#include "src/ops/dropout.h"
 #ifdef PRIMITIVE_WRITEABLE
 #include "tools/converter/quantizer/quantize_util.h"
 #endif
@@ -617,6 +618,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new SparseToDense(primitive);
     case schema::PrimitiveType_DetectionPostProcess:
       return new DetectionPostProcess(primitive);
+    case schema::PrimitiveType_Dropout:
+      return new Dropout(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
@@ -866,6 +869,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
       return NewPrimitiveC<SparseToDense>(primitive);
     case schema::PrimitiveType_DetectionPostProcess:
       return NewPrimitiveC<DetectionPostProcess>(primitive);
+    case schema::PrimitiveType_Dropout:
+      return NewPrimitiveC<Dropout>(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
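
Note: the converter-side `Create(schema::PrimitiveT *)` and runtime-side `Create(const schema::Primitive *)` factories each keep their own switch, so a new op must be registered in both. A toy model of the failure mode these two cases remove (simplified names, not the real class hierarchy):

#include <iostream>
#include <memory>

enum class PrimType { MatMul, Dropout };
struct PrimitiveC { virtual ~PrimitiveC() = default; };
struct Dropout : PrimitiveC {};

PrimitiveC *Create(PrimType t) {
  switch (t) {
    case PrimType::Dropout: return new Dropout;  // newly handled case
    default: return nullptr;                     // Dropout used to fall through here
  }
}

int main() {
  std::unique_ptr<PrimitiveC> p(Create(PrimType::Dropout));
  std::cout << (p ? "created" : "unsupported primitive") << "\n";
}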

@@ -38,26 +38,22 @@ std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve
       MS_LOG(ERROR) << "lite tensor is nullptr";
       return std::vector<Tensor *>();
     }
-    // reshape op must get tensor data to infershape
-    if (node_type == schema::PrimitiveType_Reshape && i == 1 && tensorT->nodeType == NodeType_ValueNode) {
-      auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
-      // when tensorT as param input
-      if (lite_tensor_size == 0) {
-        return std::vector<Tensor *>();
-      }
-      auto ret = lite_tensor->MallocData();
-      if (ret != 0) {
-        MS_LOG(ERROR) << "Malloc tensor data failed";
-        return std::vector<Tensor *>();
-      }
-      ret = memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size);
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return std::vector<Tensor *>();
-      }
-    }
+    auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t);
+    // when tensorT as param input
+    if (lite_tensor_size == 0) {
+      lite_tensors.emplace_back(lite_tensor.release());
+      continue;
+    }
+    auto ret = lite_tensor->MallocData();
+    if (ret != 0) {
+      MS_LOG(ERROR) << "Malloc tensor data failed";
+      return std::vector<Tensor *>();
+    }
+    ret = memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size);
+    if (ret != EOK) {
+      MS_LOG(ERROR) << "memcpy error: " << ret;
+      return std::vector<Tensor *>();
+    }
     lite_tensors.emplace_back(lite_tensor.release());
   }
   return lite_tensors;
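
Note: the old path copied constant data only for a Reshape node's second input and bailed out of the whole conversion when that tensor had no data yet. After the change, every value-node tensor gets its data materialized, so any op can consult constant inputs during infershape, and a data-less parameter tensor is kept as a shape-only tensor instead of aborting. A toy sketch of the guarded-copy control flow (simplified, not the real Tensor API):

#include <cstdint>
#include <cstring>
#include <vector>

struct LiteTensor { std::vector<uint8_t> data; };

// Leaves dst empty for a parameter input with no payload yet; otherwise
// copies the constant data into the tensor.
bool FillTensor(LiteTensor *dst, const std::vector<uint8_t> &src) {
  if (src.empty()) return true;  // keep shape-only tensor, do not abort
  dst->data.resize(src.size());
  std::memcpy(dst->data.data(), src.data(), src.size());
  return true;
}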

@@ -70,7 +70,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
+  // set default params
+  attr->strideH = 1;
+  attr->strideW = 1;
+  attr->dilateH = 1;
+  attr->dilateW = 1;
+  attr->group = 1;
+  attr->padMode = schema::PadMode_NOTSET;
   // set opdef each attr params
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     if (onnx_node_attr.name() == "group") {
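
Note: ONNX exporters routinely omit optional attributes that equal the spec defaults, so the parser must seed those defaults before the attribute loop; per the ONNX Conv spec, strides and dilations default to 1 along each spatial axis, group defaults to 1, and auto_pad defaults to NOTSET. A sketch of the seed-then-override pattern the hunk introduces (the "typically 0" claim about the generated attr struct is an assumption):

// Without seeding, a model that omits "strides" would leave strideH/strideW
// at the generated struct's default (typically 0), yielding a degenerate conv.
attr->strideH = attr->strideW = 1;  // ONNX default: strides = 1
attr->dilateH = attr->dilateW = 1;  // ONNX default: dilations = 1
attr->group = 1;                    // ONNX default: group = 1
for (const auto &a : onnx_node.attribute()) { /* explicit values override */ }
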
@@ -165,7 +171,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     attr->activationType = schema::ActivationType_NO_ACTIVATION;
   }
-  if (attr->group != 1) {
+  if (attr->group == attr->channelOut) {
     if (!ParseGroupConvolution(attr, op)) {
       MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
       return RET_ERROR;
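
Note: `group != 1` also matches ordinary grouped convolutions, which must not be rewritten as DepthwiseConv2D; a grouped conv is depthwise precisely when every output channel forms its own group. A minimal predicate capturing the tightened condition (sketch, hypothetical helper name):

// Depthwise: group == channelOut, i.e. one filter group per output channel.
// A conv with, say, group = 2 and channelOut = 64 is grouped but not
// depthwise and keeps the generic Conv2D path.
bool IsDepthwise(int group, int channel_out) {
  return group != 1 && group == channel_out;
}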

@@ -41,7 +41,7 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "ratio") {
-      attr->ratio = static_cast<int32_t>(onnx_node_attr.i());
+      attr->ratio = static_cast<float>(onnx_node_attr.f());
     }
   }
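
Note: ONNX declares Dropout's `ratio` as a FLOAT attribute (default 0.5), and onnx::AttributeProto stores each attribute type in its own field; reading a FLOAT attribute through the integer accessor `.i()` therefore returned the never-populated int64 field, i.e. 0, regardless of the model. The corrected read, in essence:

// f() is the float accessor on onnx::AttributeProto; i() would read the
// unset int64 slot and always yield 0 for this attribute.
if (attribute_name == "ratio") {
  attr->ratio = onnx_node_attr.f();
}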

@@ -42,6 +42,9 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
   float beta = 1.0f;
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
+    if (attribute_name == "broadcast") {
+      attr->broadcast = static_cast<bool>(onnx_node_attr.i());
+    }
     if (attribute_name == "transA") {
       attr->transposeA = static_cast<bool>(onnx_node_attr.i());
     } else if (attribute_name == "transB") {
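
Note: `broadcast` is the legacy Gemm flag from ONNX opset 6 and earlier (an INT attribute; nonzero means input C may be broadcast to M x N); opset 7 removed it in favour of implicit numpy-style broadcasting. Since the attribute names are mutually exclusive, the standalone `if` above behaves the same as extending the chain (sketch):

if (attribute_name == "broadcast") {
  attr->broadcast = onnx_node_attr.i() != 0;  // legacy Gemm flag, opset <= 6
} else if (attribute_name == "transA") {
  attr->transposeA = onnx_node_attr.i() != 0;
} else if (attribute_name == "transB") {
  attr->transposeB = onnx_node_attr.i() != 0;
}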

@@ -39,29 +39,18 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
     return RET_NULL_PTR;
   }
-  attr->format = schema::Format::Format_NCHW;
-  std::vector<onnx::TensorProto> params;
-  for (int i = 0; i < onnx_node.input_size(); ++i) {
-    const auto &input_name = onnx_node.input(i);
-    for (const auto &it : onnx_graph.initializer()) {
-      if (it.name() == input_name) {
-        params.emplace_back(it);
-        break;
-      }
-    }
-  }
-  if (params.empty()) {
-    MS_LOG(DEBUG) << "shape from another op other than const initializer";
-  } else {
-    if (params.size() != 1) {
-      MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1";
-      return RET_ERROR;
-    }
-    for (int i = 0; i < params[0].int64_data_size(); ++i) {
-      attr->shape.emplace_back(params[0].int64_data(i));
-    }
-  }
+  attr->format = schema::Format_NCHW;
+  std::vector<int64_t> shape;
+  shape.clear();
+  for (const auto &onnx_node_attr : onnx_node.attribute()) {
+    const auto &attribute_name = onnx_node_attr.name();
+    if (attribute_name == "shape") {
+      for (int i = 0; i < onnx_node_attr.ints_size(); ++i) {
+        shape.push_back(static_cast<int64_t>(onnx_node_attr.ints(i)));
+      }
+    }
+  }
+  attr->shape = shape;
   op->primitive->value.type = schema::PrimitiveType_Reshape;
   op->primitive->value.value = attr.release();
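
Note: Reshape-1 (opset < 5) carries the target shape in a `shape` attribute, which this rewrite reads directly; from opset 5 on the shape arrives as a second input tensor instead, in which case `attr->shape` stays empty here and the constant shape data is resolved later by the infershape pass patched above. A compact equivalent of the attribute read (sketch; `ints()` is the repeated-int64 accessor on onnx::AttributeProto):

std::vector<int64_t> shape;
for (const auto &a : onnx_node.attribute()) {
  if (a.name() == "shape") {
    shape.assign(a.ints().begin(), a.ints().end());
  }
}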

@@ -144,7 +144,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
-  attr->power = 0.0f;
+  attr->power = 1.0f;
   attr->scale = 1.0f;
   attr->shift = 0.0f;
   op->primitive->value.type = schema::PrimitiveType_Power;
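
Note: assuming the Power op computes pow(scale * x + shift, power) elementwise, the neutral default exponent is 1; with the old default of 0 every element collapsed to 1 whenever the model supplied no explicit exponent. A toy check:

#include <cmath>
#include <cstdio>

int main() {
  const float x = 3.0f, scale = 1.0f, shift = 0.0f;
  std::printf("power=0 -> %g\n", std::pow(scale * x + shift, 0.0f));  // 1: every input mapped to 1
  std::printf("power=1 -> %g\n", std::pow(scale * x + shift, 1.0f));  // 3: identity, as intended
}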
