|
|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
|
|
|
|
|
#include "tools/common/converter_op_utils.h"
|
|
|
|
|
@ -117,48 +118,86 @@ STATUS TransOpInsertPass::FindOutTransType() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
|
|
|
|
|
if (node == nullptr && node->primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node or primitive null";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto type = node->primitive->value.type;
|
|
|
|
|
if (graph->allTensors.at(node->inputIndex[0])->dims.size() != 4) {
|
|
|
|
|
MS_LOG(ERROR) << "change op axis only support 4 dims";
|
|
|
|
|
return RET_NOT_SUPPORT;
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Concat) {
|
|
|
|
|
auto origin_axis = node->primitive->value.AsConcat()->axis;
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
|
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_StridedSlice) {
|
|
|
|
|
auto attr = node->primitive->value.AsStridedSlice();
|
|
|
|
|
auto origin_begin = attr->begin;
|
|
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
|
|
|
|
|
auto origin_end = attr->end;
|
|
|
|
|
attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
|
|
|
|
|
auto origin_stride = attr->stride;
|
|
|
|
|
attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Split) {
|
|
|
|
|
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
|
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
auto type = node->primitive->value.type;
|
|
|
|
|
if (!IsContain(GetInsertOpList(), type)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto node_name = node->name;
|
|
|
|
|
if (!CanFusion(graph, node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto ret = FindOutTransType();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "FindOutTransType error";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
// 4 dims means infershape success,can delete
|
|
|
|
|
if (type == PrimitiveType_Concat) {
|
|
|
|
|
if (graph->allTensors.at(node->inputIndex[0])->dims.size() == 4) {
|
|
|
|
|
node->primitive->value.AsConcat()->axis = -1;
|
|
|
|
|
} else {
|
|
|
|
|
bool changed = true;
|
|
|
|
|
int run_counts = 0;
|
|
|
|
|
std::vector<CNodeT *> has_insert_nodes;
|
|
|
|
|
while (changed && run_counts < 10) {
|
|
|
|
|
changed = false;
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
auto type = node->primitive->value.type;
|
|
|
|
|
if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
auto input_tensor_size = (*iter)->inputIndex.size();
|
|
|
|
|
for (size_t i = 0; i < input_tensor_size; i++) {
|
|
|
|
|
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
|
|
|
|
|
return status;
|
|
|
|
|
auto node_name = node->name;
|
|
|
|
|
if (!CanFusion(graph, node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto output_tensor_size = (*iter)->outputIndex.size();
|
|
|
|
|
for (size_t i = 0; i < output_tensor_size; i++) {
|
|
|
|
|
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed";
|
|
|
|
|
return status;
|
|
|
|
|
auto ret = FindOutTransType();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "FindOutTransType error";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
ret = ChangeOpAxis(graph, node);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ChangeOpAxis error";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
has_insert_nodes.push_back(node.get());
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
auto input_tensor_size = (*iter)->inputIndex.size();
|
|
|
|
|
for (size_t i = 0; i < input_tensor_size; i++) {
|
|
|
|
|
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto output_tensor_size = (*iter)->outputIndex.size();
|
|
|
|
|
for (size_t i = 0; i < output_tensor_size; i++) {
|
|
|
|
|
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
changed = true;
|
|
|
|
|
}
|
|
|
|
|
run_counts++;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|