|
|
|
@ -1478,7 +1478,9 @@ def batch_norm(input,
|
|
|
|
|
param_attr=None,
|
|
|
|
|
bias_attr=None,
|
|
|
|
|
data_layout='NCHW',
|
|
|
|
|
name=None):
|
|
|
|
|
name=None,
|
|
|
|
|
moving_mean_name=None,
|
|
|
|
|
moving_variance_name=None):
|
|
|
|
|
"""
|
|
|
|
|
This function helps create an operator to implement
|
|
|
|
|
the BatchNorm layer using the configurations from the input parameters.
|
|
|
|
@ -1508,6 +1510,7 @@ def batch_norm(input,
|
|
|
|
|
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
|
|
|
|
|
|
|
|
|
|
mean = helper.create_global_variable(
|
|
|
|
|
name=moving_mean_name,
|
|
|
|
|
dtype=input.dtype,
|
|
|
|
|
shape=param_shape,
|
|
|
|
|
persistable=True,
|
|
|
|
@ -1515,6 +1518,7 @@ def batch_norm(input,
|
|
|
|
|
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
|
|
|
|
|
|
|
|
|
|
variance = helper.create_global_variable(
|
|
|
|
|
name=moving_variance_name,
|
|
|
|
|
dtype=input.dtype,
|
|
|
|
|
shape=param_shape,
|
|
|
|
|
persistable=True,
|
|
|
|
|