From b1585f862d9730b95e06c73c21f9b4e441cb05ae Mon Sep 17 00:00:00 2001
From: liubuyu
Date: Thu, 16 Apr 2020 10:03:58 +0800
Subject: [PATCH] auto mix precision

---
 .../device/ascend/kernel_select_ascend.cc     |  89 ++++---------
 .../ccsrc/kernel/tbe/tbe_kernel_select.cc     |  61 ++++++++-
 .../ascend/ascend_backend_optimization.cc     |   2 +
 .../ir_fusion/parameter_and_transop_fusion.cc | 120 ++++++++++++++++++
 .../ir_fusion/parameter_and_transop_fusion.h  |  41 ++++++
 mindspore/ops/_op_impl/tbe/cast.py            |   6 +
 6 files changed, 253 insertions(+), 66 deletions(-)
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h

diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
index f02e677163..36c622cbc5 100644
--- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
@@ -45,64 +45,6 @@ enum MatchCountPriority : int {
 const size_t kMaxCount = 0xffffffff;
 const int kUnSupportMixedDataTypeIndex = -1;
 
-const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
-                                             kOpFormat_NCHW,    kOpFormat_NHWC,        kOpFormat_HWCN,
-                                             kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,      kOpFormat_C1HWNCoC0,
-                                             kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
-
-bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
-  // if format is default, it remarkes support all format
-  if (kOpFormatList.find(format) == kOpFormatList.end()) {
-    MS_LOG(EXCEPTION) << "got the unknown format " << format;
-  }
-  if (format == kOpFormat_DEFAULT) {
-    return true;
-  }
-  // if shape size is 0, the shape will be a scalar
-  if (shape.empty()) {
-    return true;
-  }
-  if (shape.size() > kShapeSupportFormatMap.size()) {
-    return false;
-  }
-  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
-    return true;
-  }
-  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
-}
-
-bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
-    if (!IsShapeMatchFormat(shape, format)) {
-      return false;
-    }
-    for (auto shape_value : shape) {
-      if (shape_value == 0) {
-        MS_LOG(EXCEPTION) << "dimension size of the tensor shape should be a positive integer, but got " << shape_value;
-      }
-    }
-    return true;
-  };
-  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
-    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
-           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
-  }
-  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
-    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
-      return false;
-    }
-  }
-  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
-    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
-      return false;
-    }
-  }
-  return true;
-}
-
 bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
   MS_EXCEPTION_IF_NULL(cnode);
   // Check input data type
@@ -459,6 +401,29 @@ int PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
   // raise precision
   int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                     kernel_support_datatype, kernel_match_datatype_idx);
+  if (selected_index != -1) {
+    int max_match = 0;
+    auto iter = kernel_match_datatype_idx->begin();
+    while (iter != kernel_match_datatype_idx->end()) {
+      auto kernel_datatypes = kernel_support_datatype.find(iter->first);
+      if (kernel_datatypes == kernel_support_datatype.end()) {
+        MS_LOG(EXCEPTION) << "Can not find kernel index " << iter->first << "'s datatype.";
+      }
+      if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) {
+        MS_LOG(EXCEPTION) << "Kernel datatype size is less than node datatype size!";
+      }
+      // Reset the counter for each candidate kernel.
+      int match_count = 0;
+      for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) {
+        if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) {
+          ++match_count;
+        }
+      }
+      if (match_count > max_match) {
+        max_match = match_count;
+        selected_index = SizeToInt(iter->first);
+      }
+      ++iter;
+    }
+  }
   if (selected_index == -1 && context_ptr->enable_reduce_precision()) {
     selected_index = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
@@ -507,9 +472,6 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
   kernel::KernelQuery(kernel_node, &kernel_info_list);
   std::vector<int> most_match_counts = {-1, -1, -1, -1};
   int selected_index = -1;
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  bool auto_mixed_precision = context_ptr->auto_mixed_precision_flag();
   std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx;
   std::unordered_map<size_t, std::vector<TypeId>> kernel_support_datatype;
   std::vector<int> node_mix_precision_datatype_index;
@@ -517,16 +479,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
   for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
     std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
     auto kernel_build_info = *(kernel_info_list[info_index]);
-    if (!IsValidKernelInfo(kernel_node, kernel_build_info)) {
-      continue;
-    }
     std::vector<int> support_indexes;
     std::vector<TypeId> support_datatypes;
     AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype,
                              &support_datatypes, &node_mix_precision_datatype_index);
     kernel_match_datatype_idx[info_index] = support_indexes;
     kernel_support_datatype[info_index] = support_datatypes;
-    if (!auto_mixed_precision && !MatchInferOutputDataType(kernel_node, kernel_build_info)) {
+    if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) {
       continue;
     }
     std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
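The block added to PrecisionReduce above implements a "most matching datatypes wins" rule after precision raising: pick the kernel whose supported datatypes agree with the node's datatypes in the most positions. A minimal standalone sketch of that rule, using plain standard containers in place of the real kernel tables (BestMatch and the integer dtype ids are illustrative, not MindSpore API):

#include <cstddef>
#include <iostream>
#include <unordered_map>
#include <vector>

// Returns the index of the kernel whose datatype list matches the node's
// datatypes in the most positions, or -1 if no kernel matches anywhere.
int BestMatch(const std::vector<int> &node_dtypes,
              const std::unordered_map<size_t, std::vector<int>> &kernel_dtypes) {
  int best = -1;
  int max_match = 0;
  for (const auto &kv : kernel_dtypes) {
    // The counter must restart at zero for every candidate kernel.
    int match_count = 0;
    for (size_t i = 0; i < node_dtypes.size() && i < kv.second.size(); ++i) {
      if (node_dtypes[i] == kv.second[i]) {
        ++match_count;
      }
    }
    if (match_count > max_match) {
      max_match = match_count;  // without this update, the last partial match would win
      best = static_cast<int>(kv.first);
    }
  }
  return best;
}

int main() {
  // dtype ids: 0 = float16, 1 = float32; kernel 1 matches both positions.
  std::unordered_map<size_t, std::vector<int>> kernels = {{0, {1, 1}}, {1, {1, 0}}};
  std::cout << BestMatch({1, 0}, kernels) << std::endl;  // prints 1
}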
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
index e818f503c0..127451851e 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <set>
 #include "session/anf_runtime_algorithm.h"
 #include "kernel/oplib/oplib.h"
@@ -510,6 +511,64 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr
+
+bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
+  const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
+                                               kOpFormat_NCHW,    kOpFormat_NHWC,        kOpFormat_HWCN,
+                                               kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,      kOpFormat_C1HWNCoC0,
+                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
+
+  // if format is default, it means the op supports all formats
+  if (kOpFormatList.find(format) == kOpFormatList.end()) {
+    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
+  }
+  if (format == kOpFormat_DEFAULT) {
+    return true;
+  }
+  // if the shape is empty, the tensor is a scalar
+  if (shape.empty()) {
+    return true;
+  }
+  if (shape.size() > kShapeSupportFormatMap.size()) {
+    return false;
+  }
+  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
+    return true;
+  }
+  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
+}
+
+bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
+    if (!IsShapeMatchFormat(shape, format)) {
+      return false;
+    }
+    for (auto shape_value : shape) {
+      if (shape_value == 0) {
+        MS_LOG(EXCEPTION) << "Dimension size of the tensor shape should be a positive integer, but got " << shape_value;
+      }
+    }
+    return true;
+  };
+  for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
+    if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
+      return false;
+    }
+  }
+  for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
+    if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
+      return false;
+    }
+  }
+  if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
+    return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
+           AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
+  }
+  return true;
+}
+
 void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(kernel_info_list);
@@ -534,7 +593,7 @@ void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<k
     if (context_ptr->execution_mode() == kPynativeMode) {
       kernel_info_list->push_back(parse_info);
     } else {
-      if (CheckSupported(kernel_node, parse_info)) {
+      if (IsValidKernelInfo(kernel_node, *(parse_info)) && CheckSupported(kernel_node, parse_info)) {
         kernel_info_list->push_back(parse_info);
       } else {
         MS_LOG(INFO) << "CheckSupported Failed for TBE op " << op_name << " kernel info.";
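The relocated IsShapeMatchFormat check boils down to a rank-indexed table lookup with two shortcuts: the default format and scalars always match, and FRAC_NZ only requires rank >= 2. A self-contained sketch of that rule, with a made-up support table standing in for kShapeSupportFormatMap (the real table lives in MindSpore and differs from this):

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical rank -> supported-formats table; purely illustrative.
const std::vector<std::set<std::string>> kSupportTable = {
    {"NCHW"},                               // rank 1
    {"NCHW", "NHWC"},                       // rank 2
    {"NCHW", "NHWC"},                       // rank 3
    {"NCHW", "NHWC", "NC1HWC0", "FRAC_Z"},  // rank 4
};

bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  if (format == "DefaultFormat" || shape.empty()) {
    return true;  // default format and scalars are compatible with everything
  }
  if (shape.size() > kSupportTable.size()) {
    return false;  // no entry for this rank
  }
  if (format == "FRAC_NZ" && shape.size() >= 2) {
    return true;  // FRAC_NZ only needs at least two dimensions
  }
  const auto &formats = kSupportTable[shape.size() - 1];
  return formats.find(format) != formats.end();
}

int main() {
  std::cout << IsShapeMatchFormat({32, 16, 224, 224}, "NC1HWC0") << "\n";  // 1
  std::cout << IsShapeMatchFormat({32, 16}, "NC1HWC0") << "\n";            // 0
}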
diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index a72fb9dc9a..7a35627e25 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -37,6 +37,7 @@
 #include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h"
 #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
 #include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
+#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
 #include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
 #include "pre_activate/ascend/ir_fusion/transdata_split.h"
 #include "pre_activate/ascend/ir_fission/topk_split.h"
@@ -243,6 +244,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto other_pm = std::make_shared<PassManager>("other_pm");
   other_pm->AddPass(std::make_shared());
+  other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
   other_pm->AddPass(std::make_shared());
   other_pm->AddPass(std::make_shared());
   other_pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc
new file mode 100644
index 0000000000..faa1308f8b
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
+#include <memory>
+#include "session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "operator/ops.h"
+#include "device/kernel_info.h"
+#include "pre_activate/common/helper.h"
+#include "pre_activate/common/optimizer.h"
+#include "pre_activate/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+// Walk backwards from a node input through a chain of Cast/Transpose/Reshape/TransData
+// ops, recording the chain in trans_road, and return the Parameter or ValueNode that
+// feeds the chain (nullptr if the chain is shared by other users or ends elsewhere).
+const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
+                                std::vector<CNodePtr> *trans_road) {
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "Node is nullptr";
+    return nullptr;
+  }
+  if (node->isa<CNode>()) {
+    auto cnode = node->cast<CNodePtr>();
+    auto op_name = AnfAlgo::GetCNodeName(cnode);
+    auto manager = func_graph->manager();
+    if (manager == nullptr) {
+      return nullptr;
+    }
+    if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
+        op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
+      auto users = manager->node_users()[node];
+      if (users.size() > 1 && !first_flag) {
+        return nullptr;
+      }
+      trans_road->push_back(cnode);
+      first_flag = false;
+      auto next_node = AnfAlgo::GetInputNode(cnode, 0);
+      if (next_node->isa<Parameter>() || next_node->isa<ValueNode>()) {
+        return next_node;
+      }
+      return ParamTransRoad(func_graph, next_node, first_flag, trans_road);
+    }
+  } else if (node->isa<Parameter>() || node->isa<ValueNode>()) {
+    return node;
+  }
+  return nullptr;
+}
+
+bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Func graph is nullptr";
+    return false;
+  }
+  auto manager = func_graph->manager();
+  if (manager == nullptr) {
+    return false;
+  }
+  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
+  bool changed = false;
+  for (auto node : node_list) {
+    if (node == nullptr || !node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    auto node_name = AnfAlgo::GetCNodeName(cnode);
+    if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() ||
+        node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) {
+      MS_LOG(DEBUG) << "Skip trans op";
+      continue;
+    }
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
+      std::vector<CNodePtr> trans_road;
+      bool first_flag = true;
+      auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);
+      // Only fuse the exact chain Param -> TransData -> Cast -> TransData.
+      if (final_node != nullptr && trans_road.size() == 3 &&
+          AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName &&
+          AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() &&
+          AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) {
+        auto cur_transop = trans_road[0];
+        auto format = AnfAlgo::GetOutputFormat(cur_transop, 0);
+        auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0);
+        auto param_format = AnfAlgo::GetOutputFormat(final_node, 0);
+        auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
+
+        // Rebuild the Cast so it consumes the parameter's dtype and produces the
+        // chain's output dtype directly in the fused format.
+        auto cast = trans_road[1];
+        auto cast_format = AnfAlgo::GetOutputFormat(cast, 0);
+        auto cast_build_info = cast->kernel_info()->select_kernel_build_info();
+        kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+        builder.SetOutputsFormat({format});
+        builder.SetInputsFormat({format});
+        builder.SetInputsDeviceType({param_dtype});
+        builder.SetOutputsDeviceType({dtype});
+        builder.SetKernelType(cast_build_info->kernel_type());
+        builder.SetFusionType(cast_build_info->fusion_type());
+        builder.SetProcessor(cast_build_info->processor());
+        AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
+        if (param_format == format && param_dtype != dtype) {
+          manager->Replace(trans_road[2], final_node);
+          manager->Replace(cur_transop, cast);
+        }
+        changed = true;
+      }
+    }
+  }
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
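The pass above looks for a Param -> TransData -> Cast -> TransData chain (trans_road[2] next to the parameter, trans_road[0] next to the consumer) and, when the parameter already carries the consumer-side format, shrinks it to Param -> Cast. A toy model of that rewrite on a minimal node structure (nothing here is the MindSpore API; it only illustrates the pointer surgery and the rebuilt Cast):

#include <iostream>
#include <memory>
#include <string>

struct Node {
  std::string name;
  std::string format;  // e.g. "NC1HWC0" (5HD)
  std::string dtype;   // e.g. "fp16"
  std::shared_ptr<Node> input;
};

int main() {
  // Param(fp32, 5HD) <- TransData(to NCHW) <- Cast(to fp16) <- TransData(back to 5HD)
  auto param = std::make_shared<Node>(Node{"Param", "NC1HWC0", "fp32", nullptr});
  auto trans2 = std::make_shared<Node>(Node{"TransData", "NCHW", "fp32", param});
  auto cast = std::make_shared<Node>(Node{"Cast", "NCHW", "fp16", trans2});
  auto trans0 = std::make_shared<Node>(Node{"TransData", "NC1HWC0", "fp16", cast});

  // If the parameter already holds the consumer-side format and only the dtype
  // differs, both TransData hops are redundant: rebuild Cast to work in that
  // format and feed it straight from the parameter.
  if (param->format == trans0->format && param->dtype != trans0->dtype) {
    cast->format = trans0->format;  // stands in for the rebuilt kernel build info
    cast->input = param;            // stands in for Replace(trans_road[2], final_node)
    // consumers of trans0 would now be rewired to cast: Replace(cur_transop, cast)
  }
  std::cout << cast->name << "(" << cast->format << ", " << cast->dtype << ") <- "
            << cast->input->name << std::endl;  // Cast(NC1HWC0, fp16) <- Param
  return 0;
}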
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h
new file mode 100644
index 0000000000..823ec083b1
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+#include <memory>
+#include "ir/anf.h"
+#include "pre_activate/common/pass.h"
+
+namespace mindspore {
+namespace opt {
+class ParameterTransOpFusion : public Pass {
+ public:
+  explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {}
+  ~ParameterTransOpFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  size_t groups_ = 1;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ops/_op_impl/tbe/cast.py b/mindspore/ops/_op_impl/tbe/cast.py
index a18dcddfbf..07e14139da 100644
--- a/mindspore/ops/_op_impl/tbe/cast.py
+++ b/mindspore/ops/_op_impl/tbe/cast.py
@@ -44,6 +44,12 @@ cast_op_info = TBERegOp("Cast") \
     .dtype_format(DataType.F16_Default, DataType.U8_Default) \
     .dtype_format(DataType.F16_Default, DataType.F32_Default) \
     .dtype_format(DataType.F16_Default, DataType.I32_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ) \
+    .dtype_format(DataType.F32_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F16_FracNZ) \
     .dtype_format(DataType.F32_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.I32_Default) \
     .get_op_info()