!8407 [lite] fix bugs for onnx single operator.

From: @xu_anyue
Reviewed-by: @hangangqiang
Signed-off-by:
pull/8407/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 93907bd83d

@ -553,7 +553,6 @@ table Slice {
axes: [int];
begin: [int];
size: [int];
step: [int];
}
table Floor {

@ -137,6 +137,13 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor";
return RET_ERROR;
}
void *attr = node->primitive->value.value;
if (node->primitive->value.type == schema::PrimitiveType_SpaceToDepth) {
reinterpret_cast<schema::SpaceToDepthT *>(attr)->format = schema::Format_NHWC;
}
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
}
STATUS status = RET_OK;
#ifdef SUPPORT_TRAIN
if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) {

@ -125,6 +125,74 @@ STATUS TransOpInsertPass::FindOutTransType() {
return RET_OK;
}
void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
MS_LOG(INFO) << "Attr data is from other nodes.";
return;
}
auto axis_map = GetNc2NhAxisMap();
std::vector<int> cur_attr;
for (int dim = 0; dim < 4; ++dim) {
for (int index = 0; index < element_size; ++index) {
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
cur_attr.push_back(origin_attr[index]);
}
}
}
for (int index = 0; index < element_size; ++index) {
origin_attr[index] = cur_attr[index];
}
}
STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
if (node == nullptr && node->primitive == nullptr) {
MS_LOG(ERROR) << "node or primitive null";
return RET_NULL_PTR;
}
auto type = node->primitive->value.type;
if (type == PrimitiveType_StridedSlice) {
// onnx input size is equal to 5 always.
if (node->inputIndex.size() == 5) {
for (int index = 1; index < 5; ++index) {
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
MS_LOG(INFO) << "Here don't consider input is from other nodes.";
return RET_NOT_SUPPORT;
}
}
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
auto axes = graph->allTensors[node->inputIndex[3]]->data;
for (int index = 1; index < 5; ++index) {
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
reinterpret_cast<int *>(axes.data()), element_num);
}
}
}
if (type == PrimitiveType_Slice) {
auto attr = node->primitive->value.AsSlice();
if (attr == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
return RET_NULL_PTR;
}
// transform attr
attr->format = schema::Format_NHWC;
if (attr->begin.empty() || attr->size.empty()) {
MS_LOG(INFO) << "Here don't consider these attr are from other nodes.";
return RET_NOT_SUPPORT;
}
int element_num = attr->begin.size();
if (attr->axes.empty()) {
for (int index = 0; index < element_num; ++index) {
attr->axes.push_back(index);
}
}
TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
}
return RET_OK;
}
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
if (node == nullptr && node->primitive == nullptr) {
MS_LOG(ERROR) << "node or primitive null";
@ -153,19 +221,6 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
}
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
}
if (type == PrimitiveType_StridedSlice) {
auto attr = node->primitive->value.AsStridedSlice();
if (attr == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsStridedSlice() is nullptr";
return RET_NULL_PTR;
}
auto origin_begin = attr->begin;
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
auto origin_end = attr->end;
attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
auto origin_stride = attr->stride;
attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
}
if (type == PrimitiveType_Split) {
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
auto axis_map = GetNc2NhAxisMap();
@ -200,20 +255,8 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
}
node->primitive->value.AsCrop()->offsets = offsets;
}
if (type == PrimitiveType_Slice) {
auto attr = node->primitive->value.AsSlice();
if (attr == nullptr) {
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr";
return RET_NULL_PTR;
}
auto origin_begin = attr->begin;
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
auto origin_end = attr->axes;
if (origin_end.size() >= 4) {
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
}
auto origin_stride = attr->size;
attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) {
return ChangeOpAttrForSlice(graph, node);
}
return RET_OK;
}
@ -246,7 +289,7 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
}
ret = ChangeOpAxis(graph, node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ChangeOpAxis error";
MS_LOG(INFO) << "no need to ChangeOpAxis";
return ret;
}
has_insert_nodes.push_back(node.get());
@ -258,6 +301,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
return status;
}
if ((*iter)->primitive->value.type == schema::PrimitiveType_StridedSlice ||
(*iter)->primitive->value.type == schema::PrimitiveType_Slice) {
break;
}
}
auto output_tensor_size = (*iter)->outputIndex.size();
for (size_t i = 0; i < output_tensor_size; i++) {

@ -37,6 +37,10 @@ class TransOpInsertPass : public FormatTransPass {
STATUS FindOutTransType();
void TransformAttrByAxes(int *origin_attr, int *axes, int element_size);
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);
private:

Loading…
Cancel
Save