|
|
|
@@ -65,6 +65,8 @@ void GPUSession::StartKernelRT() const {
|
|
|
|
|
|
|
|
|
|
void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
auto context_ptr = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
|
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
|
|
|
@@ -73,9 +75,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
|
|
|
|
if (context_ptr->execution_mode() != kPynativeMode) {
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
|
|
|
|
|
}
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
(void)optimizer->Optimize(kernel_graph);
|
|
|
|
|
kernel_graph->SetExecOrderByDefault();
|
|
|
|
|