From 2b1837afa5de0e37d791dde4b634d9415ea1e2e5 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Tue, 3 Nov 2020 17:25:15 +0800 Subject: [PATCH] fuse 2 quant_dtype_cast when srcAttr->srcT != dstAttr->dstT --- .../fusion/quant_cast_fusion_pass.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc index d65b23224c..6894f5b4fa 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc @@ -62,7 +62,7 @@ STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patte auto dstAttr = dstNode->primitive->value.AsQuantDTypeCast(); MS_ASSERT(srcAttr != nullptr); MS_ASSERT(dstAttr != nullptr); - if (srcAttr->dstT != dstAttr->srcT || srcAttr->srcT != dstAttr->dstT) { + if (srcAttr->dstT != dstAttr->srcT) { MS_LOG(ERROR) << "srcNode and dstNode can not been fused"; return RET_ERROR; } @@ -73,10 +73,14 @@ STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patte return status; } - status = IsolateOneWayNode(graph, dstPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name.c_str() << ", error: " << status; - return status; + if (srcAttr->srcT == dstAttr->dstT) { + status = IsolateOneWayNode(graph, dstPath->nodeIdx); + if (status != RET_OK) { + MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name.c_str() << ", error: " << status; + return status; + } + } else { + dstAttr->srcT = srcAttr->dstT; } return RET_OK;