|
|
|
@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
|
|
|
opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_,
|
|
|
|
|
irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
|
|
|
|
|
|
|
|
|
|
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
|
|
|
|
|
OptPassGroupMap map_a({{"a_1", a_1},
|
|
|
|
|
{"a_2", a_2},
|
|
|
|
|
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
|
|
|
|
@ -272,6 +273,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
|
|
|
return map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
|
|
|
auto opt_a = GetOptPassesA(irpass);
|
|
|
|
|
auto a3 = opt_a[opt_a.size() - 1];
|
|
|
|
|
OptPassGroupMap map({
|
|
|
|
|
{"renormalize", opt::OptPassConfig::Renormalize()},
|
|
|
|
|
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
|
|
|
|
|
{a3},
|
|
|
|
|
});
|
|
|
|
|
return map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OptPassGroupMap GetInferenceOptPreparePhases() {
|
|
|
|
|
opt::irpass::InferenceOptPrepareLib irpass;
|
|
|
|
|
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
|
|
|
|
@ -303,6 +315,8 @@ void InitOpt(const ResourcePtr &res) {
|
|
|
|
|
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
|
|
|
|
|
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
|
|
|
|
|
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
|
|
|
|
|
g_pass_opts["opt_grad_epilogue"] =
|
|
|
|
|
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
|
|
|
|
|
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
@ -351,6 +365,8 @@ bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepar
|
|
|
|
|
|
|
|
|
|
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
|
|
|
|
|
|
|
|
|
|
bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); }
|
|
|
|
|
|
|
|
|
|
bool AddControlDependPass(const ResourcePtr &res) {
|
|
|
|
|
FuncGraphPtr func_graph = res->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
@ -469,7 +485,8 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
|
|
|
|
|
{"opt_prepare", PrepareGroup},
|
|
|
|
|
{"cconv", CconvPass}};
|
|
|
|
|
|
|
|
|
|
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
|
|
|
|
|
std::vector<PassItem> kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup},
|
|
|
|
|
{"opt_a", OptPassAGroup},
|
|
|
|
|
{"opt_b", OptPassBGroup},
|
|
|
|
|
{"cconv", CconvPass},
|
|
|
|
|
{"transform_top", TransformTopGraphPass},
|
|
|
|
|