From 146ac1263e068c7be39bbe8057acab2893882d5c Mon Sep 17 00:00:00 2001 From: YuJianfeng Date: Wed, 22 Apr 2020 10:40:59 +0800 Subject: [PATCH] Overlength functions rectification --- .../ascend/ascend_backend_optimization.cc | 53 +++++++++++-------- .../ir_fusion/parameter_and_transop_fusion.cc | 31 +++++++---- 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 496a9b276f..a2d82525e9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -70,6 +70,35 @@ namespace mindspore { namespace opt { +namespace { +void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { + MS_EXCEPTION_IF_NULL(ir_fusion_pm); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); +} +} // namespace + void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); auto optimizer = std::make_shared(); @@ -164,29 +193,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); if (context_ptr->ir_fusion_flag()) { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); + AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); } if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc index faa1308f8b..fe9b35a5e9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ -26,6 +26,7 @@ namespace mindspore { namespace opt { +namespace { const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag, std::vector *trans_road) { if (node == nullptr) { @@ -59,6 +60,24 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr return nullptr; } +kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, + TypeId output_type) { + MS_EXCEPTION_IF_NULL(cast); + auto kernel_info = cast->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto cast_build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(cast_build_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsFormat({format}); + builder.SetInputsFormat({format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + builder.SetKernelType(cast_build_info->kernel_type()); + builder.SetFusionType(cast_build_info->fusion_type()); + builder.SetProcessor(cast_build_info->processor()); + return builder.Build(); +} +} // namespace bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { MS_LOG(ERROR) << "Func graph is nullptr"; @@ -95,17 +114,7 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); auto cast = trans_road[1]; - auto cast_format = AnfAlgo::GetOutputFormat(cast, 0); - auto cast_build_info = cast->kernel_info()->select_kernel_build_info(); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsFormat({format}); - builder.SetInputsFormat({format}); - builder.SetInputsDeviceType({param_dtype}); - builder.SetOutputsDeviceType({dtype}); - builder.SetKernelType(cast_build_info->kernel_type()); - builder.SetFusionType(cast_build_info->fusion_type()); - builder.SetProcessor(cast_build_info->processor()); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); + AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); if (param_format == format && param_dtype != dtype) { manager->Replace(trans_road[2], final_node); manager->Replace(cur_transop, cast);