|
|
|
@ -302,6 +302,11 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
|
|
|
return map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
|
|
|
OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
|
|
|
|
|
return map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
|
|
|
|
|
|
|
|
|
|
void InitOpt(const ResourcePtr &res) {
|
|
|
|
@ -323,6 +328,8 @@ void InitOpt(const ResourcePtr &res) {
|
|
|
|
|
g_pass_opts["opt_grad_epilogue"] =
|
|
|
|
|
Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
|
|
|
|
|
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
|
|
|
|
g_pass_opts["opt_after_recompute"] =
|
|
|
|
|
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
|
|
|
@ -367,6 +374,7 @@ bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res,
|
|
|
|
|
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
|
|
|
|
|
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
|
|
|
|
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
|
|
|
|
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
|
|
|
|
|
|
|
|
|
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
|
|
|
|
|
|
|
|
|
@ -525,7 +533,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|
|
|
|
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
|
|
|
|
{"add_cache_embedding", AddCacheEmbeddingPass},
|
|
|
|
|
{"add_control_depend", AddControlDependPass},
|
|
|
|
|
{"add_recomputation", AddRecomputationPass}};
|
|
|
|
|
{"add_recomputation", AddRecomputationPass},
|
|
|
|
|
{"cse_after_recomputation", OptAfterRecomputeGroup}};
|
|
|
|
|
|
|
|
|
|
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
|
|
|
|
{"opt_a", OptPassAGroup},
|
|
|
|
|