diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index b957e3974d..580f15e6e2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -241,8 +241,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector &shape, const return true; } // not support format: - // 1 NCDHW with shape size != 5 - if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) { + // 1 3d formats with shape size > 5 + if (k3DFormatSet.find(format) != k3DFormatSet.end() && shape.size() > kShape5dDims) { return false; } return true; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 37804d3848..95d3fc961e 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -516,7 +516,8 @@ const std::set kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat const std::set kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName, kCTCGreedyDecoderOpName}; -const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; +const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D, + kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC}; const std::set DynamicShapeConstInputToAttr = { kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName,