|
|
|
@ -125,6 +125,74 @@ STATUS TransOpInsertPass::FindOutTransType() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
|
|
|
|
|
if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
|
|
|
|
|
MS_LOG(INFO) << "Attr data is from other nodes.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
|
std::vector<int> cur_attr;
|
|
|
|
|
for (int dim = 0; dim < 4; ++dim) {
|
|
|
|
|
for (int index = 0; index < element_size; ++index) {
|
|
|
|
|
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
|
|
|
|
|
if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
|
|
|
|
|
cur_attr.push_back(origin_attr[index]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int index = 0; index < element_size; ++index) {
|
|
|
|
|
origin_attr[index] = cur_attr[index];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
|
|
|
|
|
if (node == nullptr && node->primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node or primitive null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto type = node->primitive->value.type;
|
|
|
|
|
if (type == PrimitiveType_StridedSlice) {
|
|
|
|
|
// onnx input size is equal to 5 always.
|
|
|
|
|
if (node->inputIndex.size() == 5) {
|
|
|
|
|
for (int index = 1; index < 5; ++index) {
|
|
|
|
|
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Here don't consider input is from other nodes.";
|
|
|
|
|
return RET_NOT_SUPPORT;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
|
|
|
|
|
auto axes = graph->allTensors[node->inputIndex[3]]->data;
|
|
|
|
|
for (int index = 1; index < 5; ++index) {
|
|
|
|
|
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
|
|
|
|
|
reinterpret_cast<int *>(axes.data()), element_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Slice) {
|
|
|
|
|
auto attr = node->primitive->value.AsSlice();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
// transform attr
|
|
|
|
|
attr->format = schema::Format_NHWC;
|
|
|
|
|
if (attr->begin.empty() || attr->size.empty()) {
|
|
|
|
|
MS_LOG(INFO) << "Here don't consider these attr are from other nodes.";
|
|
|
|
|
return RET_NOT_SUPPORT;
|
|
|
|
|
}
|
|
|
|
|
int element_num = attr->begin.size();
|
|
|
|
|
if (attr->axes.empty()) {
|
|
|
|
|
for (int index = 0; index < element_num; ++index) {
|
|
|
|
|
attr->axes.push_back(index);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
|
|
|
|
|
TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
|
|
|
|
|
TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
|
|
|
|
|
if (node == nullptr && node->primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node or primitive null";
|
|
|
|
@ -153,19 +221,6 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
|
|
|
|
|
}
|
|
|
|
|
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_StridedSlice) {
|
|
|
|
|
auto attr = node->primitive->value.AsStridedSlice();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsStridedSlice() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto origin_begin = attr->begin;
|
|
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
|
|
|
|
|
auto origin_end = attr->end;
|
|
|
|
|
attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
|
|
|
|
|
auto origin_stride = attr->stride;
|
|
|
|
|
attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Split) {
|
|
|
|
|
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
@ -200,20 +255,8 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
|
|
|
|
|
}
|
|
|
|
|
node->primitive->value.AsCrop()->offsets = offsets;
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Slice) {
|
|
|
|
|
auto attr = node->primitive->value.AsSlice();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto origin_begin = attr->begin;
|
|
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
|
|
|
|
|
auto origin_end = attr->axes;
|
|
|
|
|
if (origin_end.size() >= 4) {
|
|
|
|
|
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
|
|
|
|
|
}
|
|
|
|
|
auto origin_stride = attr->size;
|
|
|
|
|
attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
|
|
|
|
|
if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) {
|
|
|
|
|
return ChangeOpAttrForSlice(graph, node);
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -246,7 +289,7 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|
|
|
|
}
|
|
|
|
|
ret = ChangeOpAxis(graph, node);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ChangeOpAxis error";
|
|
|
|
|
MS_LOG(INFO) << "no need to ChangeOpAxis";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
has_insert_nodes.push_back(node.get());
|
|
|
|
@ -258,6 +301,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
if ((*iter)->primitive->value.type == schema::PrimitiveType_StridedSlice ||
|
|
|
|
|
(*iter)->primitive->value.type == schema::PrimitiveType_Slice) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto output_tensor_size = (*iter)->outputIndex.size();
|
|
|
|
|
for (size_t i = 0; i < output_tensor_size; i++) {
|
|
|
|
|