|
|
@ -30,7 +30,7 @@ class TestHWLambNextMVWithDecayRule : public BackendCommon {
|
|
|
|
UT::PyFuncGraphFetcher get_py_fun_;
|
|
|
|
UT::PyFuncGraphFetcher get_py_fun_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_matched) {
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
@ -55,7 +55,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* return output
|
|
|
|
* return output
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before");
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before");
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
@ -66,15 +66,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
|
|
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_add3) {
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
@ -99,7 +99,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* return output
|
|
|
|
* return output
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_add3");
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_add3");
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
@ -111,15 +111,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
|
|
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_mul4) {
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
@ -144,7 +144,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* return output
|
|
|
|
* return output
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_mul4");
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_mul4");
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
@ -156,15 +156,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
|
|
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div0) {
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
@ -189,7 +189,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* return output
|
|
|
|
* return output
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div0");
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div0");
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
@ -201,15 +201,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
|
|
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {
|
|
|
|
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div1) {
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
|
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
|
|
|
@ -234,7 +234,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* output = tuple_getitem(outputs, 0)
|
|
|
|
* return output
|
|
|
|
* return output
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div1");
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div1");
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
std::vector<int> shp{2, 32, 224, 224};
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
@ -246,11 +246,11 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
|
|
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
|
|
|
|
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|