|
|
|
@ -25,10 +25,20 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace irpass {
|
|
|
|
|
AnfNodePtr TransThroughDepend(const AnfNodePtr &node) {
|
|
|
|
|
auto cur_node = node;
|
|
|
|
|
while (IsPrimitiveCNode(cur_node, prim::kPrimDepend)) {
|
|
|
|
|
cur_node = cur_node->cast<CNodePtr>()->input(1);
|
|
|
|
|
}
|
|
|
|
|
return cur_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsValueNode(const AnfNodePtr &node) { return IsVNode(TransThroughDepend(node)); }
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimCast, X, T}
|
|
|
|
|
AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
|
|
|
|
Reset();
|
|
|
|
|
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node);
|
|
|
|
|
AnfVisitor::Match(prim::kPrimCast, {IsNode, IsValueNode})(node);
|
|
|
|
|
|
|
|
|
|
// check pattern match
|
|
|
|
|
if (tgt_ == nullptr) {
|
|
|
|
@ -50,6 +60,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_type->type_id() == tgt_type->type_id()) {
|
|
|
|
|
if (IsPrimitiveCNode(node->cast<CNodePtr>()->input(2), prim::kPrimDepend)) {
|
|
|
|
|
auto new_depend =
|
|
|
|
|
node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), src_, node->cast<CNodePtr>()->input(2)});
|
|
|
|
|
return new_depend;
|
|
|
|
|
}
|
|
|
|
|
return src_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -57,10 +72,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CastSameTypeEliminater::Visit(const AnfNodePtr &node) {
|
|
|
|
|
auto cur_node = TransThroughDepend(node);
|
|
|
|
|
if (src_ == nullptr) {
|
|
|
|
|
src_ = node;
|
|
|
|
|
src_ = cur_node;
|
|
|
|
|
} else {
|
|
|
|
|
tgt_ = node;
|
|
|
|
|
tgt_ = cur_node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|