bugfix: the fused batch norm op's input channel count should be a multiple of 4

pull/8867/head
lizhenyu 4 years ago
parent ac0b1aa960
commit 094f0b2a07

@ -52,6 +52,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);

@ -120,6 +120,10 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return false;
}
auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return false;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);

@ -49,6 +49,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);

@ -44,6 +44,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(node, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);

@ -344,6 +344,7 @@ const size_t kShape5dDims = 5;
const size_t kShape1dDims = 1;
const size_t kCubeSize = 16;
const size_t kMemAlignSize = 512;
// Required divisor for the fused batch norm input channel count. The batch
// norm fusion passes take the last dimension of the input device shape and
// skip the fusion when it is not a multiple of this factor.
const size_t kBNChannelMultipleFactor = 4;
const int kParameterDataTensorMask = 0;
const int kParameterWeightTensorMask = 1;
const int kValueNodeTensorMask = 2;

Loading…
Cancel
Save