batch-norm forward backward nchw, nhwc passed

fix-typo
zchen0211 7 years ago
parent 03789a7df4
commit f456a4e938

@ -184,8 +184,8 @@ class TestBatchNormOp(OpTest):
print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self):
def test_with_place(place, tensor_format):
# attr
data_format = "NCHW"
epsilon = 0.00001
momentum = 0.9
@ -222,9 +222,9 @@ class TestBatchNormOp(OpTest):
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
data_format)
def test_with_place(place, tensor_format=data_format):
scope = core.Scope()
# create input
@ -275,14 +275,13 @@ class TestBatchNormOp(OpTest):
self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out")
# FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
if isinstance(place, core.GPUPlace):
atol = 5e-2
else:
atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol)
print "op test forward passed: ", tensor_format
print "op test forward passed: ", str(place), tensor_format
# run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
@ -307,14 +306,14 @@ class TestBatchNormOp(OpTest):
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
print "op test backward passed: ", tensor_format
print "op test backward passed: ", str(place), tensor_format
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0))
for place in places:
test_with_place(place)
print "test forward passed"
for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format)
if __name__ == '__main__':

Loading…
Cancel
Save