|
|
|
@ -52,6 +52,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
|
|
|
|
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
|
|
|
|
|
if (shape.back() % kBNChannelMultipleFactor != 0) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
|
|
|
|
|
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
|
|
|
|
|