|
|
|
@ -44,7 +44,8 @@ class _BatchNorm(Cell):
|
|
|
|
|
moving_mean_init='zeros',
|
|
|
|
|
moving_var_init='ones',
|
|
|
|
|
use_batch_statistics=None,
|
|
|
|
|
device_num_each_group=1):
|
|
|
|
|
device_num_each_group=1,
|
|
|
|
|
input_dims='1d'):
|
|
|
|
|
super(_BatchNorm, self).__init__()
|
|
|
|
|
if num_features < 1:
|
|
|
|
|
raise ValueError("num_features must be at least 1")
|
|
|
|
@ -55,6 +56,7 @@ class _BatchNorm(Cell):
|
|
|
|
|
self.use_batch_statistics = use_batch_statistics
|
|
|
|
|
self.num_features = num_features
|
|
|
|
|
self.eps = eps
|
|
|
|
|
self.input_dims = input_dims
|
|
|
|
|
self.moving_mean = Parameter(initializer(
|
|
|
|
|
moving_mean_init, num_features), name="mean", requires_grad=False)
|
|
|
|
|
self.moving_variance = Parameter(initializer(
|
|
|
|
@ -145,6 +147,8 @@ class _BatchNorm(Cell):
|
|
|
|
|
return y
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
if self.input_dims == '2d':
|
|
|
|
|
_shape_check(self.shape(x))
|
|
|
|
|
if self.use_batch_statistics is None:
|
|
|
|
|
flag = self.training
|
|
|
|
|
else:
|
|
|
|
@ -253,10 +257,10 @@ class BatchNorm1d(_BatchNorm):
|
|
|
|
|
mean and variance. Default: None.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
|
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
|
|
|
|
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> net = nn.BatchNorm1d(num_features=16)
|
|
|
|
@ -282,7 +286,8 @@ class BatchNorm1d(_BatchNorm):
|
|
|
|
|
beta_init,
|
|
|
|
|
moving_mean_init,
|
|
|
|
|
moving_var_init,
|
|
|
|
|
use_batch_statistics)
|
|
|
|
|
use_batch_statistics,
|
|
|
|
|
input_dims='1d')
|
|
|
|
|
|
|
|
|
|
def _check_data_dim(self, x):
|
|
|
|
|
if x.dim() != 2:
|
|
|
|
@ -357,7 +362,8 @@ class BatchNorm2d(_BatchNorm):
|
|
|
|
|
beta_init,
|
|
|
|
|
moving_mean_init,
|
|
|
|
|
moving_var_init,
|
|
|
|
|
use_batch_statistics)
|
|
|
|
|
use_batch_statistics,
|
|
|
|
|
input_dims='2d')
|
|
|
|
|
|
|
|
|
|
def _check_data_dim(self, x):
|
|
|
|
|
if x.dim() != 4:
|
|
|
|
|