|
|
@ -59,10 +59,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, input4_, add3});
|
|
|
|
VectorRef mul5({prim::kPrimMul, input4_, add3});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
return sub0;
|
|
|
|
return sub0;
|
|
|
@ -79,10 +79,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
return sub0;
|
|
|
|
return sub0;
|
|
|
@ -99,10 +99,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
|
|
|
|
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
return sub0;
|
|
|
|
return sub0;
|
|
|
@ -119,10 +119,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
return sub0;
|
|
|
|
return sub0;
|
|
|
@ -139,10 +139,10 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
VectorRef sub0({prim::kPrimSub, input3_, mul5});
|
|
|
|
return sub0;
|
|
|
|
return sub0;
|
|
|
@ -159,10 +159,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, input4_, add3});
|
|
|
|
VectorRef mul5({prim::kPrimMul, input4_, add3});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
@ -184,10 +184,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
@ -209,10 +209,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
|
|
|
|
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
@ -234,10 +234,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef add2({prim::kPrimAdd, add2_y_, sqrt0});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
@ -259,10 +259,10 @@ const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const {
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef add1({add1_var_, mul2, mul3});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef sqrt0({sqrt, add1});
|
|
|
|
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef add2({prim::kPrimAdd, sqrt0, add2_y_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef real_div0({real_div, add0, add2});
|
|
|
|
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
|
|
|
|
VectorRef add3({prim::kPrimAdd, mul4, real_div0});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef mul5({prim::kPrimMul, add3, input4_});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef sub0({sub0_var_, input3_, mul5});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
|
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
|
|
|
|