!12539 [ME]fix bug of eliminate cast pass

From: @chenfei52
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
pull/12539/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7582ed0390

@ -25,10 +25,20 @@
namespace mindspore {
namespace opt {
namespace irpass {
AnfNodePtr TransThroughDepend(const AnfNodePtr &node) {
auto cur_node = node;
while (IsPrimitiveCNode(cur_node, prim::kPrimDepend)) {
cur_node = cur_node->cast<CNodePtr>()->input(1);
}
return cur_node;
}
bool IsValueNode(const AnfNodePtr &node) { return IsVNode(TransThroughDepend(node)); }
// {prim::kPrimCast, X, T}
AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
Reset();
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsValueNode})(node);
// check pattern match
if (tgt_ == nullptr) {
@ -50,6 +60,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
}
if (src_type->type_id() == tgt_type->type_id()) {
if (IsPrimitiveCNode(node->cast<CNodePtr>()->input(2), prim::kPrimDepend)) {
auto new_depend =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), src_, node->cast<CNodePtr>()->input(2)});
return new_depend;
}
return src_;
}
@ -57,10 +72,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
}
void CastSameTypeEliminater::Visit(const AnfNodePtr &node) {
auto cur_node = TransThroughDepend(node);
if (src_ == nullptr) {
src_ = node;
src_ = cur_node;
} else {
tgt_ = node;
tgt_ = cur_node;
}
}

Loading…
Cancel
Save