|
|
|
@ -228,6 +228,82 @@ class ConstantDuplicateMul : public AnfVisitor {
|
|
|
|
|
CNodePtr cnode_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// grad = AllReduce(grad) / worker_number
|
|
|
|
|
// grad = grad + weight * decay
|
|
|
|
|
// ->
|
|
|
|
|
// grad = grad + weight * decay
|
|
|
|
|
// grad = AllReduce(grad) / worker_number
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
|
|
|
|
|
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
|
|
|
|
|
class AdjustAllReduceMulAdd : public AnfVisitor {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
Reset();
|
|
|
|
|
// {prim::kPrimAddN, Zs}
|
|
|
|
|
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto addn = node->cast<CNodePtr>();
|
|
|
|
|
if (addn->size() != 2) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
|
|
|
|
|
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto fg = node->func_graph();
|
|
|
|
|
AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg);
|
|
|
|
|
AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg);
|
|
|
|
|
AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg);
|
|
|
|
|
return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override {
|
|
|
|
|
if (level_ == 0) {
|
|
|
|
|
level_ = 1;
|
|
|
|
|
is_reduce_match_ = false;
|
|
|
|
|
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMul)(node);
|
|
|
|
|
level_ = 0;
|
|
|
|
|
if (is_reduce_match_) {
|
|
|
|
|
y_ = tmp_;
|
|
|
|
|
} else {
|
|
|
|
|
z_ = node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (level_ == 1) {
|
|
|
|
|
// {prim::kPrimAllReduce, X}
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode->size() > 1) {
|
|
|
|
|
x_ = cnode->input(1);
|
|
|
|
|
is_reduce_match_ = true;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
tmp_ = node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Reset() {
|
|
|
|
|
level_ = 0;
|
|
|
|
|
is_reduce_match_ = false;
|
|
|
|
|
x_ = nullptr;
|
|
|
|
|
y_ = nullptr;
|
|
|
|
|
z_ = nullptr;
|
|
|
|
|
tmp_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int level_{0};
|
|
|
|
|
bool is_reduce_match_{false};
|
|
|
|
|
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ArithmeticSimplify {
|
|
|
|
|
public:
|
|
|
|
|
ArithmeticSimplify()
|
|
|
|
@ -243,6 +319,7 @@ class ArithmeticSimplify {
|
|
|
|
|
eliminaters_.emplace_back(identity_);
|
|
|
|
|
eliminaters_.emplace_back(opt_update_zero_tensor_);
|
|
|
|
|
eliminaters_.emplace_back(constant_duplicate_mul_);
|
|
|
|
|
eliminaters_.emplace_back(adjust_allreduce_mul_add_);
|
|
|
|
|
}
|
|
|
|
|
~ArithmeticSimplify() = default;
|
|
|
|
|
|
|
|
|
@ -264,6 +341,7 @@ class ArithmeticSimplify {
|
|
|
|
|
PrimEliminater identity_;
|
|
|
|
|
OptUpdateZeroTensor opt_update_zero_tensor_;
|
|
|
|
|
ConstantDuplicateMul constant_duplicate_mul_;
|
|
|
|
|
AdjustAllReduceMulAdd adjust_allreduce_mul_add_;
|
|
|
|
|
std::vector<TransformFuncType> eliminaters_{};
|
|
|
|
|
};
|
|
|
|
|
} // namespace irpass
|
|
|
|
|