|
|
|
@ -366,15 +366,15 @@ class GlobalBatchNorm(_BatchNorm):
|
|
|
|
|
use_batch_statistics=True,
|
|
|
|
|
group=1):
|
|
|
|
|
super(GlobalBatchNorm, self).__init__(num_features,
|
|
|
|
|
eps,
|
|
|
|
|
momentum,
|
|
|
|
|
affine,
|
|
|
|
|
gamma_init,
|
|
|
|
|
beta_init,
|
|
|
|
|
moving_mean_init,
|
|
|
|
|
moving_var_init,
|
|
|
|
|
use_batch_statistics,
|
|
|
|
|
group)
|
|
|
|
|
eps,
|
|
|
|
|
momentum,
|
|
|
|
|
affine,
|
|
|
|
|
gamma_init,
|
|
|
|
|
beta_init,
|
|
|
|
|
moving_mean_init,
|
|
|
|
|
moving_var_init,
|
|
|
|
|
use_batch_statistics,
|
|
|
|
|
group)
|
|
|
|
|
self.group = check_int_positive(group)
|
|
|
|
|
if self.group <= 1:
|
|
|
|
|
raise ValueError("the number of group must be greater than 1.")
|
|
|
|
|