From 094f0b2a07755223c4a43e89187608c37dbe9a90 Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Sat, 21 Nov 2020 15:01:42 +0800
Subject: [PATCH] bugfix: fused batch norm op's input channel count should be a multiple of 4

---
 .../ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc | 4 ++++
 .../backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc  | 4 ++++
 .../ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc     | 4 ++++
 .../backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc      | 4 ++++
 mindspore/ccsrc/utils/utils.h                                 | 1 +
 5 files changed, 17 insertions(+)

diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc
index 466aeb39e6..0e231576cc 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc
@@ -52,6 +52,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
   if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
     return nullptr;
   }
+  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
+  if (shape.back() % kBNChannelMultipleFactor != 0) {
+    return nullptr;
+  }

   auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
   auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
index d80f865260..3745b28006 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc
@@ -120,6 +120,10 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
     return false;
   }
+  auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
+  if (shape.back() % kBNChannelMultipleFactor != 0) {
+    return false;
+  }

   auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
   MS_EXCEPTION_IF_NULL(relu_grad);
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc
index 629dd17714..92faf0f325 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc
@@ -49,6 +49,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
   if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
     return nullptr;
   }
+  auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
+  if (shape.back() % kBNChannelMultipleFactor != 0) {
+    return nullptr;
+  }

   auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
   auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
index e8dc539591..eeff8830da 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
@@ -44,6 +44,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
   if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
     return nullptr;
   }
+  auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
+  if (shape.back() % kBNChannelMultipleFactor != 0) {
+    return nullptr;
+  }

   auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
   MS_EXCEPTION_IF_NULL(relu_grad);
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 8cb1170ae6..808ea22373 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -344,6 +344,7 @@ const size_t kShape5dDims = 5;
 const size_t kShape1dDims = 1;
 const size_t kCubeSize = 16;
 const size_t kMemAlignSize = 512;
+const size_t kBNChannelMultipleFactor = 4;
 const int kParameterDataTensorMask = 0;
 const int kParameterWeightTensorMask = 1;
 const int kValueNodeTensorMask = 2;
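
For readers skimming the diff: the guard added to each of the four GPU fusion passes
reduces to the standalone check sketched below. The helper name and scaffolding are
illustrative only (they are not part of this patch or of the MindSpore sources); the
constant mirrors kBNChannelMultipleFactor introduced in mindspore/ccsrc/utils/utils.h.

    #include <cstddef>
    #include <vector>

    // Mirrors the constant added to mindspore/ccsrc/utils/utils.h in this patch.
    constexpr size_t kBNChannelMultipleFactor = 4;

    // Hypothetical helper: returns true only when the channel axis (the last
    // dimension of an NHWC device shape) is a multiple of 4, which is the
    // condition the passes now require before matching the fused
    // BatchNorm + ReLU patterns; otherwise they bail out and keep the
    // unfused ops.
    bool ChannelCountAllowsFusion(const std::vector<size_t> &nhwc_shape) {
      return !nhwc_shape.empty() && nhwc_shape.back() % kBNChannelMultipleFactor == 0;
    }

Under this check, an NHWC input shape such as {32, 14, 14, 30} keeps the unfused
kernels (30 % 4 != 0), while {32, 14, 14, 32} still allows the fusion to fire.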