|
|
|
@ -120,7 +120,9 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|
|
|
|
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
|
|
|
|
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
|
|
|
|
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
|
|
|
|
}
|
|
|
|
|
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
|
|
|
@ -165,15 +167,17 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|
|
|
|
}
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
|
|
|
|
|
std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast};
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
|
|
|
|
|
pm->AddPass(std::make_shared<opt::TensorPromotion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
|
|
|
|
// After Simplify and Splitter, a lot of redundant getitem/maketuple
|
|
|
|
|
// will be exposed; use the GetitemTuple pass to delete them.
|
|
|
|
|
pm->AddPass(std::make_shared<opt::GetitemTuple>());
|
|
|
|
|