|
|
|
@ -44,6 +44,44 @@ STATUS FormatTransPass::Run(schema::MetaGraphT *graph) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *beforeNodeType,
|
|
|
|
|
FormatTransNodeType *afterNodeType) {
|
|
|
|
|
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
|
|
|
|
|
return RET_NO_CHANGE;
|
|
|
|
|
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
|
|
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
|
|
|
|
return RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
*beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
*afterNodeType = kNHWC2NCHW;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
} else if (fmkType == converter::FmkType_MS) {
|
|
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
|
|
|
|
return RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
*beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
*afterNodeType = kNHWC2NCHW;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
} else if (fmkType == converter::FmkType_ONNX) {
|
|
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
|
|
|
|
return RET_NO_CHANGE;
|
|
|
|
|
}
|
|
|
|
|
*beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
*afterNodeType = kNHWC2NCHW;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
} else if (fmkType == converter::FmkType_TF) {
|
|
|
|
|
if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) {
|
|
|
|
|
*beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
*afterNodeType = kNHWC2NCHW;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
return RET_NO_CHANGE;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported fmk: " << fmkType;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) {
|
|
|
|
|
return RET_OK;
|
|
|
|
@ -53,6 +91,14 @@ STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
if (graph->nodes.empty()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
// onnx input format may be nhwc
|
|
|
|
|
if (fmkType == converter::FmkType_ONNX && graph->inputIndex.size() == 1) {
|
|
|
|
|
auto &input_tensor = graph->allTensors.at(graph->inputIndex[0]);
|
|
|
|
|
auto &input_dims = input_tensor->dims;
|
|
|
|
|
if (input_dims.size() == 4 && input_dims[3] != -1 && input_dims[1] == -1) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto graphInputIdxes = graph->inputIndex;
|
|
|
|
|
for (size_t i = 0; i < graphInputIdxes.size(); i++) {
|
|
|
|
|
bool transed = false;
|
|
|
|
@ -100,38 +146,15 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
|
// insert before and after the op cal by nchw/nc4hw4
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
|
|
|
FormatTransNodeType beforeNodeType, afterNodeType;
|
|
|
|
|
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
|
|
|
|
|
FormatTransNodeType beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
FormatTransNodeType afterNodeType = kNHWC2NCHW;
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
status = GetInsertFormatTrans(**iter, &beforeNodeType, &afterNodeType);
|
|
|
|
|
if (status == RET_NO_CHANGE) {
|
|
|
|
|
continue;
|
|
|
|
|
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
|
|
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
afterNodeType = kNHWC2NCHW;
|
|
|
|
|
} else if (fmkType == converter::FmkType_MS) {
|
|
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
afterNodeType = kNHWC2NCHW;
|
|
|
|
|
} else if (fmkType == converter::FmkType_ONNX) {
|
|
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
afterNodeType = kNHWC2NCHW;
|
|
|
|
|
} else if (fmkType == converter::FmkType_TF) {
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
if (IsContain(GetNhwcOpList(), GetCNodeTType(**iter)) && GetFormat(node) == schema::Format_NCHW) {
|
|
|
|
|
beforeNodeType = kNCHW2NHWC;
|
|
|
|
|
afterNodeType = kNHWC2NCHW;
|
|
|
|
|
} else {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported fmk: " << fmkType;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
auto &node = *iter;
|
|
|
|
|
auto nodeName = node->name;
|
|
|
|
@ -150,7 +173,6 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|
|
|
|
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
|
|
|
|
|
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
|
|
|
|
|
}
|
|
|
|
|
STATUS status = RET_OK;
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) {
|
|
|
|
|
int idx_num = node->inputIndex.size();
|
|
|
|
@ -250,18 +272,18 @@ void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quan
|
|
|
|
|
|
|
|
|
|
void FormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; }
|
|
|
|
|
|
|
|
|
|
int FormatTransPass::GetFormat(const std::unique_ptr<CNodeT> &node) {
|
|
|
|
|
switch (node->primitive->value.type) {
|
|
|
|
|
int FormatTransPass::GetFormat(const schema::CNodeT &node) {
|
|
|
|
|
switch (node.primitive->value.type) {
|
|
|
|
|
case schema::PrimitiveType_Conv2D:
|
|
|
|
|
return node->primitive->value.AsConv2D()->format;
|
|
|
|
|
return node.primitive->value.AsConv2D()->format;
|
|
|
|
|
case schema::PrimitiveType_DeConv2D:
|
|
|
|
|
return node->primitive->value.AsDeConv2D()->format;
|
|
|
|
|
return node.primitive->value.AsDeConv2D()->format;
|
|
|
|
|
case schema::PrimitiveType_DeDepthwiseConv2D:
|
|
|
|
|
return node->primitive->value.AsDeDepthwiseConv2D()->format;
|
|
|
|
|
return node.primitive->value.AsDeDepthwiseConv2D()->format;
|
|
|
|
|
case schema::PrimitiveType_DepthwiseConv2D:
|
|
|
|
|
return node->primitive->value.AsDepthwiseConv2D()->format;
|
|
|
|
|
return node.primitive->value.AsDepthwiseConv2D()->format;
|
|
|
|
|
case schema::PrimitiveType_Pooling:
|
|
|
|
|
return node->primitive->value.AsPooling()->format;
|
|
|
|
|
return node.primitive->value.AsPooling()->format;
|
|
|
|
|
default:
|
|
|
|
|
return schema::Format_NHWC;
|
|
|
|
|
}
|
|
|
|
|