Merge pull request #8069 from qingqing01/bn_name

Allow uers to specify the name of moving mean and variance in batch_norm interface.
emailweixu-patch-1
qingqing01 7 years ago committed by GitHub
commit 4e7e39b4bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1478,7 +1478,9 @@ def batch_norm(input,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
name=None):
name=None,
moving_mean_name=None,
moving_variance_name=None):
"""
This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters.
@ -1508,6 +1510,7 @@ def batch_norm(input,
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable(
name=moving_mean_name,
dtype=input.dtype,
shape=param_shape,
persistable=True,
@ -1515,6 +1518,7 @@ def batch_norm(input,
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
variance = helper.create_global_variable(
name=moving_variance_name,
dtype=input.dtype,
shape=param_shape,
persistable=True,

Loading…
Cancel
Save