|
|
|
@ -1483,6 +1483,7 @@ def batch_norm(input,
|
|
|
|
|
param_attr=None,
|
|
|
|
|
bias_attr=None,
|
|
|
|
|
data_layout='NCHW',
|
|
|
|
|
in_place=False,
|
|
|
|
|
name=None,
|
|
|
|
|
moving_mean_name=None,
|
|
|
|
|
moving_variance_name=None):
|
|
|
|
@ -1538,7 +1539,7 @@ def batch_norm(input,
|
|
|
|
|
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
|
|
|
|
|
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
|
|
|
|
|
|
|
|
|
|
batch_norm_out = helper.create_tmp_variable(dtype)
|
|
|
|
|
batch_norm_out = input if in_place else helper.create_tmp_variable(dtype)
|
|
|
|
|
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type="batch_norm",
|
|
|
|
|