From 097f53bed9208b5b9a27b2fe0eab12e2bb972bcb Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Tue, 25 Aug 2020 21:52:47 +0800 Subject: [PATCH] add attr for transdata node --- .../kernel_compiler/kernel_build_info.cc | 12 ++++++++---- .../backend/kernel_compiler/kernel_build_info.h | 2 ++ .../backend/kernel_compiler/kernel_query.cc | 3 ++- .../backend/kernel_compiler/tbe/tbe_adapter.cc | 16 ---------------- .../backend/kernel_compiler/tbe/tbe_adapter.h | 2 +- .../tbe/tbe_kernel_parallel_build.cc | 1 - .../backend/optimizer/ascend/ascend_helper.cc | 17 +++++++++++++++++ mindspore/ccsrc/utils/utils.h | 1 + 8 files changed, 31 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc index 6daade89ea..b534fd112b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc @@ -108,10 +108,7 @@ std::string KernelBuildInfo::ToString() const { return output_buffer.str(); } -bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { - if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) { - return false; - } +bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const { if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) { if (op_pattern_ != kFormatAgnosticPattern) { return false; @@ -123,6 +120,13 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); } +bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { + if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) { + return false; + } + return IsSimilarityKernelBuildInfo(other); +} + bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); } bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h index 845930f815..6c4ee24323 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h @@ -91,6 +91,8 @@ class KernelBuildInfo { std::string ToString() const; + bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const; + bool operator==(const KernelBuildInfo &other) const; bool operator!=(const KernelBuildInfo &other) const; diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc index 2dfda849e9..d3555bad07 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc @@ -130,6 +130,7 @@ void AICPUQuery(const CNodePtr &kernel_node, std::vectorIsSimilarityKernelBuildInfo(*select_kernel_build_info); }); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc index f3ef4e24f4..07e2893294 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc @@ -178,22 +178,6 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) { } } -void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) { - std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0); - std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0); - if (input_format == kOpFormat_DEFAULT) { - input_format = kOpFormat_NCHW; - } - if (output_format == kOpFormat_DEFAULT) { - output_format = kOpFormat_NCHW; - } - AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node); - AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node); - } -} - std::unordered_set input_order_adjusted_ops = { "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h index f72de02e8f..b37cf68da6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h @@ -36,7 +36,7 @@ class TbeAdapter { TbeAdapter() = default; ~TbeAdapter() = default; static void NormalizeFuncName(std::string *func_name); - static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node); + static void InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, nlohmann::json *inputs_json); static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc index 4114e6729b..79a538acd3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -75,7 +75,6 @@ bool TbeOpParallelBuild(const std::vector &anf_nodes) { set processed_kernel; for (const auto &anf_node : anf_nodes) { // gen kernel json - tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node); if (AnfAlgo::GetKernelMod(anf_node) != nullptr) { continue; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 39ebc39612..1a81c82843 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -48,6 +48,22 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i return reshape; } +void SetTransNodeAttr(const CNodePtr &trans_node) { + MS_EXCEPTION_IF_NULL(trans_node); + if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) { + std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0); + std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0); + if (input_format == kOpFormat_DEFAULT) { + input_format = kOpFormat_NCHW; + } + if (output_format == kOpFormat_DEFAULT) { + output_format = kOpFormat_NCHW; + } + AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node); + AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node); + } +} + AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { AnfNodePtr trans_node = nullptr; @@ -173,6 +189,7 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & builder->SetInputsDeviceType({type_id}); } AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); + SetTransNodeAttr(trans_data->cast()); } CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index af1243e9fc..da7fa32ae9 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -224,6 +224,7 @@ constexpr auto kAttrEventId = "event_id"; constexpr auto kAttrDynInput = "dynamic"; constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; constexpr auto kAttrSrcFormat = "src_format"; +constexpr auto kAttrDstFormat = "dst_format"; constexpr auto kAttrMultiples = "multiples"; constexpr auto kAttrFixPrecision = "fix_precision"; constexpr auto kAttrOutputPrecision = "output_precision";