|
|
|
@ -971,11 +971,17 @@ def batch_norm(input,
|
|
|
|
|
attr=helper.param_attr, shape=param_shape, dtype=dtype, is_bias=True)
|
|
|
|
|
|
|
|
|
|
mean = helper.create_global_variable(
|
|
|
|
|
dtype=input.dtype, shape=param_shape, persistable=True)
|
|
|
|
|
dtype=input.dtype,
|
|
|
|
|
shape=param_shape,
|
|
|
|
|
persistable=True,
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))
|
|
|
|
|
|
|
|
|
|
variance = helper.create_global_variable(
|
|
|
|
|
dtype=input.dtype, shape=param_shape, persistable=True)
|
|
|
|
|
dtype=input.dtype,
|
|
|
|
|
shape=param_shape,
|
|
|
|
|
persistable=True,
|
|
|
|
|
stop_gradient=True)
|
|
|
|
|
helper.set_variable_initializer(var=variance, initializer=Constant(1.0))
|
|
|
|
|
|
|
|
|
|
# create output
|
|
|
|
@ -983,8 +989,8 @@ def batch_norm(input,
|
|
|
|
|
mean_out = mean
|
|
|
|
|
# variance and variance out share the same memory
|
|
|
|
|
variance_out = variance
|
|
|
|
|
saved_mean = helper.create_tmp_variable(dtype)
|
|
|
|
|
saved_variance = helper.create_tmp_variable(dtype)
|
|
|
|
|
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
|
|
|
|
|
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
|
|
|
|
|
|
|
|
|
|
batch_norm_out = helper.create_tmp_variable(dtype)
|
|
|
|
|
|
|
|
|
|