fix bn train & eval loss problem

pull/1136/head
zhaojichen 5 years ago
parent fcb33834d2
commit 59993c4843

@@ -43,7 +43,7 @@ class _BatchNorm(Cell):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
@@ -147,7 +147,11 @@ class _BatchNorm(Cell):
         return y

     def construct(self, x):
-        if self.training and self.use_batch_statistics:
+        if self.use_batch_statistics is None:
+            flag = self.training
+        else:
+            flag = self.use_batch_statistics
+        if flag:
             if self.is_ge_backend and self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
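For reference, the patched control flow is a three-way switch on use_batch_statistics. A minimal standalone sketch of the same decision (plain Python, independent of MindSpore; training and use_batch_statistics stand in for the cell's attributes):

def resolve_batch_statistics(training, use_batch_statistics):
    """Decide whether to normalize with current-batch statistics.

    Mirrors the patched construct(): with the new default of None,
    batch statistics are used exactly when the cell is in training
    mode; an explicit True/False overrides the mode.
    """
    if use_batch_statistics is None:
        return training
    return use_batch_statistics

# The four relevant (mode, flag) combinations:
assert resolve_batch_statistics(True, None) is True    # train -> batch stats
assert resolve_batch_statistics(False, None) is False  # eval  -> running stats
assert resolve_batch_statistics(False, True) is True   # forced batch stats
assert resolve_batch_statistics(True, False) is False  # forced running stats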
@@ -236,8 +240,10 @@ class BatchNorm1d(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
+            use the tracked (moving) mean and variance instead. If None, the training process uses the mean and
+            variance of the current batch while tracking the running statistics, and the evaluation process uses
+            the tracked running mean and variance. Default: None.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -259,7 +265,7 @@ class BatchNorm1d(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm1d, self).__init__(num_features,
                                           eps,
                                           momentum,
@@ -307,8 +313,10 @@ class BatchNorm2d(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
+            use the tracked (moving) mean and variance instead. If None, the training process uses the mean and
+            variance of the current batch while tracking the running statistics, and the evaluation process uses
+            the tracked running mean and variance. Default: None.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -330,7 +338,7 @@ class BatchNorm2d(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm2d, self).__init__(num_features,
                                           eps,
                                           momentum,
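A hedged usage sketch of what the new default changes in practice (assumes a working MindSpore install; nn.BatchNorm2d and Cell.set_train are existing MindSpore APIs, the tensor shapes here are purely illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

bn = nn.BatchNorm2d(num_features=3)  # use_batch_statistics now defaults to None
x = Tensor(np.random.randn(4, 3, 8, 8).astype(np.float32))

bn.set_train(True)   # training mode: normalize with batch statistics
y_train = bn(x)      # and update moving_mean / moving_variance

bn.set_train(False)  # eval mode: normalize with the tracked running
y_eval = bn(x)       # statistics, fixing the train/eval loss mismatch

Before this commit, the True default made eval mode keep using batch statistics unless the flag was overridden by hand, which is the train/eval loss problem the commit message refers to.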
@@ -379,8 +387,10 @@ class GlobalBatchNorm(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
+            use the tracked (moving) mean and variance instead. If None, the training process uses the mean and
+            variance of the current batch while tracking the running statistics, and the evaluation process uses
+            the tracked running mean and variance. Default: None.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -402,7 +412,7 @@ class GlobalBatchNorm(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(GlobalBatchNorm, self).__init__(num_features,
                                               eps,
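For completeness, "use the running mean and variance" in eval mode means the standard inference-time normalization formula. A minimal NumPy sketch (not MindSpore code; gamma, beta, and eps correspond to the parameters described in the docstrings above):

import numpy as np

def batch_norm_eval(x, running_mean, running_var, gamma, beta, eps=1e-5):
    # Eval-mode normalization: per-channel running statistics are
    # applied instead of statistics computed from the current batch.
    # x: (N, C, H, W); running_mean, running_var, gamma, beta: (C,)
    shape = (1, -1, 1, 1)  # broadcast channel parameters over N, H, W
    x_hat = (x - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + eps)
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)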
