|
|
|
@ -26,7 +26,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
|
|
|
|
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
|
|
|
|
|
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
|
|
|
|
|
weight_decay_ = std::make_shared<Var>();
|
|
|
|
|
scale_ = std::make_shared<Var>();
|
|
|
|
|
scale_ = std::make_shared<CondVar>(IsScalar);
|
|
|
|
|
variable_ = std::make_shared<Var>();
|
|
|
|
|
accumulation_ = std::make_shared<Var>();
|
|
|
|
|
learning_rate_ = std::make_shared<Var>();
|
|
|
|
@ -38,9 +38,10 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
|
|
|
|
// Override of an inherited virtual Process(); const, so it does not modify this pass object.
// NOTE(review): the three unnamed parameters are presumably the graph being optimized, the
// node matched by the pattern, and the pattern-variable bindings, with the return value being
// the replacement node — confirm against the implementation file.
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// Static predicate over a BaseRef. The constructor passes it to the CondVar it creates for
// scale_, so it acts as the match condition for that pattern variable.
// NOTE(review): by its name it should accept only scalar values; the definition is not in
// this header — confirm there.
static bool IsScalar(const BaseRef &n);
|
|
|
|
|
|
|
|
|
|
// Pattern variable; the constructor initializes it with a plain Var.
VarPtr weight_decay_;
|
|
|
|
|
// Pattern variable; the constructor leaves it holding a CondVar built with IsScalar.
VarPtr scale_;
|
|
|
|
|
|
|
|
|
|
// Pattern variable; the constructor initializes it with a plain Var.
VarPtr variable_;
|
|
|
|
|
// Pattern variable; the constructor initializes it with a plain Var.
VarPtr accumulation_;
|
|
|
|
|
// Pattern variable; the constructor initializes it with a plain Var.
VarPtr learning_rate_;
|
|
|
|
|