add inplace attr to bn

helinwang-patch-1
Yang Yang 7 years ago
parent 25317bd312
commit 54a8c04fab

@ -1483,6 +1483,7 @@ def batch_norm(input,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
in_place=False,
name=None,
moving_mean_name=None,
moving_variance_name=None):
@ -1538,7 +1539,7 @@ def batch_norm(input,
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True)
batch_norm_out = helper.create_tmp_variable(dtype)
batch_norm_out = input if in_place else helper.create_tmp_variable(dtype)
helper.append_op(
type="batch_norm",

Loading…
Cancel
Save