|
|
|
@ -60,10 +60,11 @@ const BaseRef AdamFusion::DefinePattern() const {
|
|
|
|
|
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
|
|
|
|
|
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
|
|
|
|
|
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
|
|
|
|
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
|
|
|
|
|
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
|
|
|
|
|
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
|
|
|
|
|
return depend3;
|
|
|
|
|
|
|
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
|
|
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
|
|
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
|
|
|
|
|
return next_param;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
|
|
|
|
|