From b83c3f7db56b00427ab499da889c02d7b9b5c1ce Mon Sep 17 00:00:00 2001 From: liubuyu Date: Thu, 18 Feb 2021 19:04:57 +0800 Subject: [PATCH] update inout format for kernel json --- .../backend/kernel_compiler/tbe/tbe_kernel_build.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index 8fe80ce78f..b502cb6e38 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -595,6 +595,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { MS_EXCEPTION_IF_NULL(anf_node); std::string format = kOpFormat_NCHW; + if (anf_node->isa() && IsNeedChangeDefaultFormat(anf_node->cast())) { + format = kOpFormat_NCDHW; + } if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { format = AnfAlgo::GetInputFormat(anf_node, real_index); if (format == kOpFormat_FRAC_Z) { @@ -603,9 +606,6 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod format = kOpFormat_NCHW; } } - if (anf_node->isa() && IsNeedChangeDefaultFormat(anf_node->cast())) { - format = kOpFormat_NCDHW; - } return format; } @@ -637,6 +637,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { MS_EXCEPTION_IF_NULL(anf_node); std::string format = kOpFormat_NCHW; + if (anf_node->isa() && IsNeedChangeDefaultFormat(anf_node->cast())) { + format = kOpFormat_NCDHW; + } if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { format = AnfAlgo::GetOutputFormat(anf_node, real_index); if (format == kOpFormat_FRAC_Z) { @@ -645,9 +648,6 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no format = kOpFormat_NCHW; } } - if (anf_node->isa() && IsNeedChangeDefaultFormat(anf_node->cast())) { - format = kOpFormat_NCDHW; - } return format; }