!8168 fuse 2 quant_dtype_cast when srcAttr->srcT != dstAttr->dstT

Merge pull request !8168 from cjh9368/add_uint8_default_dtype
pull/8168/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 30ae0c9779

@ -62,7 +62,7 @@ STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patte
auto dstAttr = dstNode->primitive->value.AsQuantDTypeCast();
MS_ASSERT(srcAttr != nullptr);
MS_ASSERT(dstAttr != nullptr);
if (srcAttr->dstT != dstAttr->srcT || srcAttr->srcT != dstAttr->dstT) {
if (srcAttr->dstT != dstAttr->srcT) {
MS_LOG(ERROR) << "srcNode and dstNode can not been fused";
return RET_ERROR;
}
@ -73,11 +73,15 @@ STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patte
return status;
}
if (srcAttr->srcT == dstAttr->dstT) {
status = IsolateOneWayNode(graph, dstPath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name.c_str() << ", error: " << status;
return status;
}
} else {
dstAttr->srcT = srcAttr->dstT;
}
return RET_OK;
}

Loading…
Cancel
Save