diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 5faf046d18..ffde5cecec 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -166,6 +166,8 @@ class _BatchNorm(Cell): def extend_repr(self): return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance) + +@constexpr def _channel_check(channel, num_channel): if channel != num_channel: raise ValueError("the input channel is not equal with num_channels")