@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
    return backward_op


def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
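    # Reference NumPy implementation of inference-mode batch norm:
    #   y = scale * (x - mean) / sqrt(var + epsilon) + offset
    # applied per channel, for both the NCHW and NHWC layouts.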
    x_shape = x.shape
    if len(x_shape) == 2:
        if data_format == "NCHW":
            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
        else:
            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))

    if data_format == "NCHW":
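        # Reshape the per-channel statistics to (1, C, 1, 1) and tile them
        # across N, H, W so they line up elementwise with x.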
        n, c, h, w = x.shape
        mean_tile = np.reshape(mean, (1, c, 1, 1))
        mean_tile = np.tile(mean_tile, (n, 1, h, w))
        var_tile = np.reshape(var, (1, c, 1, 1))
        var_tile = np.tile(var_tile, (n, 1, h, w))
        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
        scale_tile = np.reshape(scale, (1, c, 1, 1))
        scale_tile = np.tile(scale_tile, (n, 1, h, w))
        offset_tile = np.reshape(offset, (1, c, 1, 1))
        offset_tile = np.tile(offset_tile, (n, 1, h, w))
        y = normalized * scale_tile + offset_tile
    elif data_format == "NHWC":
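        # With the channel axis last, NumPy broadcasting aligns the (C,)
        # statistics against x without explicit tiling.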
        normalized = (x - mean) / np.sqrt(var + epsilon)
        y = normalized * scale + offset
    else:
        raise ValueError("Unknown data layout.")

    if len(x_shape) == 2:
        y = np.reshape(y, x_shape)
    return y


def _reference_training(x, scale, offset, epsilon, data_format):
    x_shape = x.shape
    if len(x_shape) == 2:
@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
        __set_tensor__(output, data)


-class TestBatchNormOp(OpTest):
+class TestBatchNormOpInference(OpTest):
    def setUp(self):
        self.dtype = np.float32

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

-    def test_python(self):
    def check_with_place(self, place, data_layout, dtype, shape):
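        # Run the batch_norm operator in inference mode on `place` and
        # compare its output against the NumPy reference above.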
        epsilon = 0.00001
        if len(shape) == 2:
            x_shape = shape
            c = x_shape[1]
        else:
            n, h, w, c = shape[0], shape[1], shape[2], shape[3]
            if data_layout == "NHWC":
                x_shape = [n, h, w, c]
            elif data_layout == "NCHW":
                x_shape = [n, c, h, w]
            else:
                raise ValueError("Unknown data layout.")
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(dtype)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
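        # Scale, bias and the running statistics stay float32 even when x is
        # float16; only the input/output tensors use the test dtype.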

        mean = np.zeros(scale_shape).astype(np.float32)
        variance = np.ones(scale_shape).astype(np.float32)

        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                   epsilon, data_layout).astype(dtype)

        scope = core.Scope()

        # create input
        x_tensor = create_or_get_tensor(scope, "x_val",
                                        OpTest.np_dtype_to_fluid_dtype(x_val),
                                        place)
        scale_tensor = create_or_get_tensor(
            scope, "scale_val",
            OpTest.np_dtype_to_fluid_dtype(scale_val), place)
        bias_tensor = create_or_get_tensor(
            scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
        mean_tensor = create_or_get_tensor(scope, "mean",
                                           OpTest.np_dtype_to_fluid_dtype(mean),
                                           place)
        variance_tensor = create_or_get_tensor(
            scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)

        # create output
        y_tensor = create_or_get_tensor(scope, "y_out", None, place)
        saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
                                                 place)
        saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
                                                     None, place)
        mean_out_tensor = mean_tensor
        variance_out_tensor = variance_tensor
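        # MeanOut/VarianceOut alias the input tensors: the operator updates
        # the running statistics in place.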

        batch_norm_op = Operator(
            "batch_norm",
            # inputs
            X="x_val",
            Scale="scale_val",
            Bias="bias_val",
            Mean="mean",
            Variance="variance",
            # outputs
            Y="y_out",
            MeanOut="mean",
            VarianceOut="variance",
            SavedMean="saved_mean",
            SavedVariance="saved_variance",
            # attrs
            is_test=True,
            data_layout=data_layout,
            epsilon=epsilon)

        batch_norm_op.run(scope, place)

        # check inference result
        self.__assert_close(
            y_tensor,
            y_out,
            "inference output is different at " + str(place) + ", " +
            data_layout + ", " + str(np.dtype(dtype)) + " " +
            str(np.array(y_tensor)) + str(y_out),
            atol=1e-3)
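        # atol is looser here (1e-3) than the 1e-4 default so that the same
        # check also passes for the float16 subclass below.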

    def test_check_output(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
            places.append(core.CUDAPlace(0))

        for place in places:
            for data_format in ["NCHW", "NHWC"]:
                self.check_with_place(place, data_format, self.dtype,
                                      [2, 3, 4, 5])
                self.check_with_place(place, data_format, self.dtype, [2, 3])


class TestFP16BatchNormOpInference(TestBatchNormOpInference):
    def setUp(self):
        self.dtype = np.float16
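        # All checks are inherited from TestBatchNormOpInference; only the
        # tensor dtype changes.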

    def test_check_output(self):
        places = []
        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                places.append(place)
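        # No CPUPlace here: float16 batch norm is only exercised on CUDA
        # devices that support it, so the test is skipped elsewhere.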

        for place in places:
            for data_format in ["NCHW", "NHWC"]:
                self.check_with_place(place, data_format, self.dtype,
                                      [2, 3, 4, 5])
                self.check_with_place(place, data_format, self.dtype, [2, 3])


class TestBatchNormOpTraining(OpTest):
    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def test_python_testing(self):
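        # Check that the NHWC and NCHW paths of _reference_testing agree on
        # the same data.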
        data_format = "NHWC"
        epsilon = 0.00001

        n, h, w, c = 2, 3, 4, 5
        x_shape = [n, h, w, c]
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        bias_val = np.random.random_sample(scale_shape).astype(np.float32)

        mean = np.zeros(scale_shape).astype(np.float32)
        variance = np.ones(scale_shape).astype(np.float32)

        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                   epsilon, "NHWC")

        # running the (N, C, H, W) case should produce the same results
        x_shape2 = [n, c, h, w]
        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
        y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
                                    epsilon, "NCHW")

        # transfer (N, C, H, W) back to (N, H, W, C)
        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
        self.__assert_close(y_out, y_out2_trans, "inference output")
        print 'python: NHWC, NCHW, inference checking passed'

    def test_python_training(self):
        data_format = "NHWC"
        epsilon = 0.00001
        momentum = 0.9
@@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest):

        # transfer (N, C, H, W) back to (N, H, W, C)
        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
-        self.__assert_close(y_out, y_out2_trans, "batch variance")
+        self.__assert_close(y_out, y_out2_trans, "batch output")
        print 'python: NHWC, NCHW, forward checking passed'

        # test backward now