|
|
|
@ -310,6 +310,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
|
|
|
|
|
self.fuse_with_relu = False
|
|
|
|
|
self.data_formats = ["NCHW", "NHWC"]
|
|
|
|
|
self.momentum = 0.9
|
|
|
|
|
self.use_momentum_variable = False
|
|
|
|
|
self.epsilon = 0.00001
|
|
|
|
|
self.init_kernel_type()
|
|
|
|
|
self.init_test_case()
|
|
|
|
@ -367,6 +368,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
|
|
|
|
|
bias = np.random.random_sample(scale_shape).astype(np.float32)
|
|
|
|
|
mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
|
|
|
|
|
y_grad = np.random.random_sample(shape).astype(np.float32)
|
|
|
|
|
momentum_var = np.array([momentum]).astype(np.float32)
|
|
|
|
|
|
|
|
|
|
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
|
|
|
|
|
x, y_grad, scale, bias, mean, variance, epsilon, momentum,
|
|
|
|
@ -380,7 +382,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
var_names = [
|
|
|
|
|
'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
|
|
|
|
|
'saved_variance'
|
|
|
|
|
'saved_variance', 'momentum_var'
|
|
|
|
|
]
|
|
|
|
|
ground_truth = {name: var_dict[name] for name in var_names}
|
|
|
|
|
|
|
|
|
@ -392,15 +394,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
|
|
|
|
|
name=name,
|
|
|
|
|
dtype='float32',
|
|
|
|
|
shape=ground_truth[name].shape)
|
|
|
|
|
inputs = {
|
|
|
|
|
"X": block.var('x'),
|
|
|
|
|
"Scale": block.var('scale'),
|
|
|
|
|
"Bias": block.var('bias'),
|
|
|
|
|
"Mean": block.var('mean'),
|
|
|
|
|
"Variance": block.var('variance')
|
|
|
|
|
}
|
|
|
|
|
attrs = {
|
|
|
|
|
"epsilon": epsilon,
|
|
|
|
|
"is_test": False,
|
|
|
|
|
"data_layout": data_layout,
|
|
|
|
|
"use_mkldnn": self.use_mkldnn,
|
|
|
|
|
"fuse_with_relu": self.fuse_with_relu,
|
|
|
|
|
"use_global_stats": self.use_global_stats
|
|
|
|
|
}
|
|
|
|
|
if self.use_momentum_variable:
|
|
|
|
|
inputs['MomentumTensor'] = block.var('momentum_var')
|
|
|
|
|
else:
|
|
|
|
|
attrs['momentum'] = momentum
|
|
|
|
|
bn_op = block.append_op(
|
|
|
|
|
type="batch_norm",
|
|
|
|
|
inputs={
|
|
|
|
|
"X": block.var('x'),
|
|
|
|
|
"Scale": block.var('scale'),
|
|
|
|
|
"Bias": block.var('bias'),
|
|
|
|
|
"Mean": block.var('mean'),
|
|
|
|
|
"Variance": block.var('variance')
|
|
|
|
|
},
|
|
|
|
|
inputs=inputs,
|
|
|
|
|
outputs={
|
|
|
|
|
"Y": block.var('y'),
|
|
|
|
|
"MeanOut": block.var('mean'), # share memory
|
|
|
|
@ -408,15 +423,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
|
|
|
|
|
"SavedMean": block.var('saved_mean'),
|
|
|
|
|
"SavedVariance": block.var('saved_variance')
|
|
|
|
|
},
|
|
|
|
|
attrs={
|
|
|
|
|
"momentum": momentum,
|
|
|
|
|
"epsilon": epsilon,
|
|
|
|
|
"is_test": False,
|
|
|
|
|
"data_layout": data_layout,
|
|
|
|
|
"use_mkldnn": self.use_mkldnn,
|
|
|
|
|
"fuse_with_relu": self.fuse_with_relu,
|
|
|
|
|
"use_global_stats": self.use_global_stats
|
|
|
|
|
})
|
|
|
|
|
attrs=attrs)
|
|
|
|
|
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
|
|
|
|
|
|
|
|
|
|
# generate backward op_desc
|
|
|
|
@ -434,14 +441,15 @@ class TestBatchNormOpTraining(unittest.TestCase):
|
|
|
|
|
grad_var.set_dtype(core.VarDesc.VarType.FP32)
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
out = exe.run(
|
|
|
|
|
program,
|
|
|
|
|
feed={
|
|
|
|
|
name: var_dict[name]
|
|
|
|
|
for name in
|
|
|
|
|
['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD']
|
|
|
|
|
},
|
|
|
|
|
fetch_list=self.fetch_list)
|
|
|
|
|
out = exe.run(program,
|
|
|
|
|
feed={
|
|
|
|
|
name: var_dict[name]
|
|
|
|
|
for name in [
|
|
|
|
|
'x', 'scale', 'bias', 'mean', 'variance',
|
|
|
|
|
'y@GRAD', 'momentum_var'
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
fetch_list=self.fetch_list)
|
|
|
|
|
|
|
|
|
|
for id, name in enumerate(self.fetch_list):
|
|
|
|
|
if name == 'variance':
|
|
|
|
@ -471,6 +479,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
|
|
|
|
|
self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.use_momentum_variable = True
|
|
|
|
|
self.use_global_stats = False
|
|
|
|
|
self.no_grad_set = set()
|
|
|
|
|
self.fetch_list = [
|
|
|
|
|
'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
|
|
|
|
|
'scale@GRAD', 'bias@GRAD'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
|
|
|
|
|
def init_test_case(self):
|
|
|
|
|
self.use_global_stats = True
|
|
|
|
|