!10489 Add grad epilogue passes for PyNative mode.

From: @zh_qh
Reviewed-by: @chujinjin,@ginfung
Signed-off-by: @ginfung
pull/10489/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ae46e0216f

@ -151,6 +151,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_,
irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
OptPassGroupMap map_a({{"a_1", a_1},
{"a_2", a_2},
{"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
@ -272,6 +273,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
return map;
}
OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
auto opt_a = GetOptPassesA(irpass);
auto a3 = opt_a[opt_a.size() - 1];
OptPassGroupMap map({
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
{a3},
});
return map;
}
OptPassGroupMap GetInferenceOptPreparePhases() {
opt::irpass::InferenceOptPrepareLib irpass;
auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
@ -303,6 +315,8 @@ void InitOpt(const ResourcePtr &res) {
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_grad_epilogue"] =
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -351,6 +365,8 @@ bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepar
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); }
bool AddControlDependPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@ -469,7 +485,8 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
std::vector<PassItem> kPynativePasses = {{"opt_grad_epilogue", OptPassGradEpilogueGroup},
{"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"transform_top", TransformTopGraphPass},

Loading…
Cancel
Save