change type of BN's 'mean' and 'variance' from persistable variable to Parameter

tonyyang-svail-patch-1
fengjiayi 7 years ago
parent d4dabe3e0b
commit cf7c745c48

@ -1519,21 +1519,21 @@ def batch_norm(input,
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
mean = helper.create_global_variable(
name=moving_mean_name,
dtype=input.dtype,
mean = helper.create_parameter(
attr=ParamAttr(
name=moving_mean_name, initializer=Constant(0.0), trainable=False),
shape=param_shape,
persistable=True,
stop_gradient=True)
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
dtype=input.dtype)
mean.stop_gradient = True
variance = helper.create_global_variable(
name=moving_variance_name,
dtype=input.dtype,
variance = helper.create_parameter(
attr=ParamAttr(
name=moving_variance_name,
initializer=Constant(1.0),
trainable=False),
shape=param_shape,
persistable=True,
stop_gradient=True)
helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
dtype=input.dtype)
variance.stop_gradient = True
# create output
# mean and mean_out share the same memory

Loading…
Cancel
Save