|
|
|
@ -95,37 +95,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt
|
|
|
|
|
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
|
|
|
|
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
|
|
|
|
|
AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
|
|
|
|
Reset();
|
|
|
|
|
// {prim::kPrimAddN, Zs}
|
|
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto addn = node->cast<CNodePtr>();
|
|
|
|
|
if (addn->size() != 2) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
|
|
|
|
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto addn_maketuple = addn->input(1);
|
|
|
|
|
|
|
|
|
|
auto fg = all_reduce_fg_;
|
|
|
|
|
// addn inputs cross the graph, make the inputs same as allreduce node.
|
|
|
|
|
if (z_->isa<CNode>() && fg != z_->func_graph()) {
|
|
|
|
|
auto cnode_z = z_->cast<CNodePtr>();
|
|
|
|
|
z_ = NewCNode(cnode_z->inputs(), fg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto addn_op_node = addn->input(0);
|
|
|
|
|
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
|
|
|
|
|
PatternNode x, y, z;
|
|
|
|
|
auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
|
|
|
|
|
auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
|
|
|
|
|
auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
|
|
|
|
|
auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
|
|
|
|
|
auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
|
|
|
|
|
auto fg = all_reduce_pat.GetFuncGraph();
|
|
|
|
|
auto z_ = z.GetNode(node);
|
|
|
|
|
// If addn inputs cross the graph, make the inputs same as allreduce node.
|
|
|
|
|
if (z_->isa<CNode>() && fg != z_->func_graph()) {
|
|
|
|
|
auto cnode_z = z_->cast<CNodePtr>();
|
|
|
|
|
z_ = NewCNode(cnode_z->inputs(), fg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
|
|
|
|
|
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
|
|
|
|
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
|
|
|
|
|
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
|
|
|
|
|
ProcessDependEdge(fg, addn_maketuple, all_reduce);
|
|
|
|
|
return mul;
|
|
|
|
|
auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>();
|
|
|
|
|
auto addn_op_node = addn_cnode->input(0);
|
|
|
|
|
auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0);
|
|
|
|
|
auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0);
|
|
|
|
|
mul_cnode_ = mul_pat.GetOriginalNode();
|
|
|
|
|
auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
|
|
|
|
|
auto addn_maketuple = admktup_pat.GetOriginalNode();
|
|
|
|
|
|
|
|
|
|
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg);
|
|
|
|
|
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
|
|
|
|
|
AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
|
|
|
|
|
AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
|
|
|
|
|
ProcessDependEdge(fg, addn_maketuple, all_reduce);
|
|
|
|
|
return mul;
|
|
|
|
|
};
|
|
|
|
|
MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
|
|
|
|
@ -146,48 +146,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) {
|
|
|
|
|
if (level_ == 0) {
|
|
|
|
|
level_ = 1;
|
|
|
|
|
is_reduce_match_ = false;
|
|
|
|
|
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMul)(node);
|
|
|
|
|
level_ = 0;
|
|
|
|
|
if (is_reduce_match_) {
|
|
|
|
|
mul_ = node->cast<CNodePtr>()->input(0);
|
|
|
|
|
mul_cnode_ = node->cast<CNodePtr>();
|
|
|
|
|
y_ = tmp_;
|
|
|
|
|
} else {
|
|
|
|
|
z_ = node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (level_ == 1) {
|
|
|
|
|
// {prim::kPrimAllReduce, X}
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode->size() > 1) {
|
|
|
|
|
all_reduce_ = cnode->input(0);
|
|
|
|
|
x_ = cnode->input(1);
|
|
|
|
|
is_reduce_match_ = true;
|
|
|
|
|
all_reduce_fg_ = cnode->func_graph();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
tmp_ = node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AdjustAllReduceMulAdd::Reset() {
|
|
|
|
|
level_ = 0;
|
|
|
|
|
is_reduce_match_ = false;
|
|
|
|
|
x_ = nullptr;
|
|
|
|
|
y_ = nullptr;
|
|
|
|
|
z_ = nullptr;
|
|
|
|
|
tmp_ = nullptr;
|
|
|
|
|
all_reduce_fg_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace irpass
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|