|
|
|
@ -62,7 +62,7 @@ STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patte
|
|
|
|
|
auto dstAttr = dstNode->primitive->value.AsQuantDTypeCast();
|
|
|
|
|
MS_ASSERT(srcAttr != nullptr);
|
|
|
|
|
MS_ASSERT(dstAttr != nullptr);
|
|
|
|
|
if (srcAttr->dstT != dstAttr->srcT || srcAttr->srcT != dstAttr->dstT) {
|
|
|
|
|
if (srcAttr->dstT != dstAttr->srcT) {
|
|
|
|
|
MS_LOG(ERROR) << "srcNode and dstNode can not been fused";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
@ -73,11 +73,15 @@ STATUS QuantCastFusionPass::DoFusion(MetaGraphT *graph, const std::string &patte
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (srcAttr->srcT == dstAttr->dstT) {
|
|
|
|
|
status = IsolateOneWayNode(graph, dstPath->nodeIdx);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name.c_str() << ", error: " << status;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
dstAttr->srcT = srcAttr->dstT;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|