From 2e417b6011b05662602e70f9564681c7e4a7cfd1 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Wed, 25 Oct 2017 16:23:46 -0700
Subject: [PATCH 1/4] batch norm

---
 .../v2/framework/tests/test_batch_norm_op.py | 143 +++++++++++++++---
 1 file changed, 121 insertions(+), 22 deletions(-)

diff --git a/python/paddle/v2/framework/tests/test_batch_norm_op.py b/python/paddle/v2/framework/tests/test_batch_norm_op.py
index b7b071c24d..76c1ff018a 100644
--- a/python/paddle/v2/framework/tests/test_batch_norm_op.py
+++ b/python/paddle/v2/framework/tests/test_batch_norm_op.py
@@ -6,16 +6,36 @@ from paddle.v2.framework.op import Operator
 
 
 def _reference_training(x, scale, offset, epsilon, data_format):
-    if data_format != "NHWC":
-        raise ValueError("data_format must be NHWC, got %s." % data_format)
-    x_square = x * x
-    x_square_sum = np.sum(x_square, (0, 1, 2))
-    x_sum = np.sum(x, axis=(0, 1, 2))
-    element_count = np.size(x) / int(np.shape(x)[-1])
-    mean = x_sum / element_count
-    var = x_square_sum / element_count - mean * mean
-    normalized = (x - mean) / np.sqrt(var + epsilon)
-    return (normalized * scale + offset), mean, var
+    if data_format == "NCHW":
+        n, c, h, w = x.shape
+        x_square = x * x
+        x_square_sum = np.sum(x_square, (0, 2, 3))
+        x_sum = np.sum(x, axis=(0, 2, 3))
+        element_count = np.size(x) / int(np.shape(x)[1])
+        mean = x_sum / element_count
+        var = x_square_sum / element_count - mean * mean
+        mean_tile = np.reshape(mean, (1, c, 1, 1))
+        mean_tile = np.tile(mean_tile, (n, 1, h, w))
+        var_tile = np.reshape(var, (1, c, 1, 1))
+        var_tile = np.tile(var_tile, (n, 1, h, w))
+        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
+        scale_tile = np.reshape(scale, (1, c, 1, 1))
+        scale_tile = np.tile(scale_tile, (n, 1, h, w))
+        offset_tile = np.reshape(offset, (1, c, 1, 1))
+        offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
+        y = normalized * scale_tile + offset_tile
+        return y, mean, var
+    elif data_format == "NHWC":
+        x_square = x * x
+        x_square_sum = np.sum(x_square, (0, 1, 2))
+        x_sum = np.sum(x, axis=(0, 1, 2))
+        element_count = np.size(x) / int(np.shape(x)[-1])
+        mean = x_sum / element_count
+        var = x_square_sum / element_count - mean * mean
+        normalized = (x - mean) / np.sqrt(var + epsilon)
+        return (normalized * scale + offset), mean, var
+    else:
+        raise ValueError("Unknown data order.")
 
 
 def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
@@ -28,8 +48,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
     # grad_x =
     #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
     #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
-    if data_format != "NHWC":
-        raise ValueError("data_format must be NHWC, got %s." % data_format)
+
+    # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
+    if data_format == "NCHW":
+        x = np.transpose(x, (0, 2, 3, 1))
+        grad_y = np.transpose(grad_y, (0, 2, 3, 1))
+
+    # raise ValueError("data_format must be NHWC, got %s." % data_format)
     grad_x = scale * (grad_y - np.mean(
         grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean(
             grad_y * (x - mean), axis=(0, 1, 2)) /
@@ -37,6 +62,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
     grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon),
                         axis=(0, 1, 2))
     grad_offset = np.sum(grad_y, axis=(0, 1, 2))
+
+    # transfer back to N, C, H, W
+    if data_format == "NCHW":
+        grad_x = np.transpose(grad_x, (0, 3, 1, 2))
+        x = np.transpose(x, (0, 3, 1, 2))
+        grad_y = np.transpose(grad_y, (0, 3, 1, 2))
     return grad_x, grad_scale, grad_offset
 
 
@@ -72,39 +103,104 @@ class TestBatchNormOp(OpTest):
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def test_forward_backward(self):
-        # attr
+    def test_python(self):
         data_format = "NHWC"
         epsilon = 0.00001
         momentum = 0.9
 
+        # N, H, W, C: 2, 3, 4, 2
         channel_num = 2
         x_shape = [2, 3, 4, channel_num]
         scale_shape = [channel_num]
 
-        # input
         x_val = np.random.random_sample(x_shape).astype(np.float32)
         scale_val = np.random.random_sample(scale_shape).astype(np.float32)
         bias_val = np.random.random_sample(scale_shape).astype(np.float32)
 
         mean = np.zeros(scale_shape).astype(np.float32)
-        variance = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
+
+        # run forward
+        y_out, saved_mean, var_ref = _reference_training(
+            x_val, scale_val, bias_val, epsilon, "NHWC")
+
+        #
+        mean_out = saved_mean * (1. - momentum) + momentum * mean
+        variance_out = var_ref * (1. - momentum) + momentum * variance
+        saved_variance = 1. / np.sqrt(var_ref + epsilon)
+
+        # running N, C, H, W case
+        # should produce the same results
+        x_shape2 = [2, channel_num, 3, 4]
+        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
+        y_out2, saved_mean2, var_ref2 = _reference_training(
+            x_val2, scale_val, bias_val, epsilon, "NCHW")
+
+        self.__assert_close(saved_mean, saved_mean2, "batch mean")
+        self.__assert_close(var_ref, var_ref2, "batch variance")
+
+        # transfer (N, C, H, W) back to (N, H, W, C)
+        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
+        self.__assert_close(y_out, y_out2_trans, "batch variance")
+        print 'python: NHWC, NCHW, forward checking passed'
+
+        # test backward now
+        # NHWC
+        y_grad = np.ones(x_shape).astype(np.float32)
+        x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
+            x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC")
+
+        # NCHW
+        y_grad2 = np.ones(x_shape2).astype(np.float32)
+        x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad(
+            x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW")
+
+        self.__assert_close(scale_grad_ref, scale_grad_ref2, "scale gradient")
+        self.__assert_close(bias_grad_ref, bias_grad_ref2, "bias gradient")
+
+        x_grad_transpose = np.transpose(x_grad_ref2, (0, 2, 3, 1))
+        self.__assert_close(x_grad_ref, x_grad_transpose, "x gradient")
+        print 'python: NHWC, NCHW, backward checking passed'
+
+    def test_forward_backward(self):
+        # attr
+        data_format = "NCHW"
+        epsilon = 0.00001
+        momentum = 0.9
+
+        # N, H, W, C: 2, 3, 4, 2
+        n, h, w, c = 2, 3, 4, 2
+
+        if data_format == "NHWC":
+            x_shape = [n, h, w, c]
+        elif data_format == "NCHW":
+            x_shape = [n, c, h, w]
+        else:
+            raise ValueError("Unknown data type.")
+        scale_shape = [c]
+
+        x_val = np.random.random_sample(x_shape).astype(np.float32)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
 
         # run forward
         y_out, saved_mean, var_ref = _reference_training(
             x_val, scale_val, bias_val, epsilon, data_format)
 
-        # run backward
-        mean_out = saved_mean * (1 - momentum)
-        variance_out = var_ref * (1 - momentum)
-        saved_variance = 1 / np.sqrt(var_ref + epsilon)
+        # update moving mean and variance
+        mean_out = saved_mean * (1. - momentum) + momentum * mean
+        variance_out = var_ref * (1. - momentum) + momentum * variance
+        saved_variance = 1. / np.sqrt(var_ref + epsilon)
 
         # for gradient test
         y_grad = np.ones(x_shape).astype(np.float32)
         x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
             x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
 
-        def test_with_place(place):
+        def test_with_place(place, tensor_format=data_format):
             scope = core.Scope()
 
             # create input
@@ -142,7 +238,7 @@ class TestBatchNormOp(OpTest):
                 SavedVariance="saved_variance",
                 # attrs
                 is_test=False,
-                tensor_format=data_format,
+                tensor_format=tensor_format,
                 momentum=momentum,
                 epsilon=epsilon)
 
@@ -162,6 +258,7 @@ class TestBatchNormOp(OpTest):
                 atol = 1e-4
             self.__assert_close(variance_out_tensor, variance_out,
                                 "variance_out", atol)
+            print "op test forward passed: ", tensor_format
 
             # run backward
             batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
 
@@ -185,12 +282,14 @@ class TestBatchNormOp(OpTest):
             self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
             self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
             self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
+            print "op test backward passed: ", tensor_format
 
         places = [core.CPUPlace()]
         if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
             places.append(core.GPUPlace(0))
         for place in places:
             test_with_place(place)
+            print "test forward passed"
 
 
 if __name__ == '__main__':
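The NCHW and NHWC branches of _reference_training in this first patch differ only in which axes the per-channel moments are reduced over, and test_python checks that by transposing the same data between the two layouts. The following is a standalone, layout-agnostic NumPy sketch of that forward reference, not part of the patch; batch_norm_forward_ref is a hypothetical name and only plain NumPy is assumed:

import numpy as np

def batch_norm_forward_ref(x, scale, offset, epsilon, data_format):
    # reduce over every axis except the channel axis (1 for NCHW, -1 for NHWC)
    axis = (0, 2, 3) if data_format == "NCHW" else (0, 1, 2)
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)   # biased, matching E[x^2] - E[x]^2
    shape = mean.shape                          # (1, C, 1, 1) or (1, 1, 1, C)
    normalized = (x - mean) / np.sqrt(var + epsilon)
    y = normalized * scale.reshape(shape) + offset.reshape(shape)
    return y, mean.ravel(), var.ravel()

# the property test_python exercises: transposing NHWC data to NCHW must leave
# the per-channel moments and the (re-transposed) output unchanged
x = np.random.random_sample((2, 3, 4, 2)).astype(np.float32)
scale = np.random.random_sample(2).astype(np.float32)
offset = np.random.random_sample(2).astype(np.float32)
y1, m1, v1 = batch_norm_forward_ref(x, scale, offset, 1e-5, "NHWC")
y2, m2, v2 = batch_norm_forward_ref(
    np.transpose(x, (0, 3, 1, 2)), scale, offset, 1e-5, "NCHW")
assert np.allclose(m1, m2, atol=1e-6) and np.allclose(v1, v2, atol=1e-6)
assert np.allclose(y1, np.transpose(y2, (0, 2, 3, 1)), atol=1e-5)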
From 822cf9785b42ab6b9316b6bcdd3fb63f11773036 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Fri, 27 Oct 2017 10:28:48 -0700
Subject: [PATCH 2/4] more test and bn fix

---
 paddle/operators/batch_norm_op.cu             |  3 ---
 .../v2/framework/tests/test_batch_norm_op.py  | 21 ++++++++++++-------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/paddle/operators/batch_norm_op.cu b/paddle/operators/batch_norm_op.cu
index 6ba6ee12ec..6cbbb33438 100644
--- a/paddle/operators/batch_norm_op.cu
+++ b/paddle/operators/batch_norm_op.cu
@@ -117,9 +117,6 @@ class BatchNormKernel : public framework::OpKernel {
     math::SetConstant functor;
     functor(ctx.device_context(), saved_mean, 0);
     functor(ctx.device_context(), saved_variance, 0);
-    // FIXME(qiao) should not set zero self
-    functor(ctx.device_context(), mean_out, 0);
-    functor(ctx.device_context(), variance_out, 0);
 
     auto handle = ctx.cuda_device_context().cudnn_handle();
 
diff --git a/python/paddle/v2/framework/tests/test_batch_norm_op.py b/python/paddle/v2/framework/tests/test_batch_norm_op.py
index 76c1ff018a..a82aaa4d39 100644
--- a/python/paddle/v2/framework/tests/test_batch_norm_op.py
+++ b/python/paddle/v2/framework/tests/test_batch_norm_op.py
@@ -104,14 +104,14 @@ class TestBatchNormOp(OpTest):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
     def test_python(self):
-        data_format = "NHWC"
+        data_format = "NCHW"
         epsilon = 0.00001
         momentum = 0.9
 
         # N, H, W, C: 2, 3, 4, 2
-        channel_num = 2
-        x_shape = [2, 3, 4, channel_num]
-        scale_shape = [channel_num]
+        n, h, w, c = 2, 3, 4, 2
+        x_shape = [n, h, w, c]
+        scale_shape = [c]
 
         x_val = np.random.random_sample(x_shape).astype(np.float32)
         scale_val = np.random.random_sample(scale_shape).astype(np.float32)
@@ -131,7 +131,7 @@ class TestBatchNormOp(OpTest):
 
         # running N, C, H, W case
         # should produce the same results
-        x_shape2 = [2, channel_num, 3, 4]
+        x_shape2 = [n, c, h, w]
         x_val2 = np.transpose(x_val, (0, 3, 1, 2))
         y_out2, saved_mean2, var_ref2 = _reference_training(
             x_val2, scale_val, bias_val, epsilon, "NCHW")
@@ -146,12 +146,15 @@ class TestBatchNormOp(OpTest):
 
         # test backward now
         # NHWC
-        y_grad = np.ones(x_shape).astype(np.float32)
+        self.y_grad = np.random.random_sample(x_shape).astype(np.float32)
+        y_grad = self.y_grad
+        # y_grad = np.ones(x_shape).astype(np.float32)
         x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
             x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC")
 
         # NCHW
-        y_grad2 = np.ones(x_shape2).astype(np.float32)
+        y_grad2 = np.transpose(y_grad, (0, 3, 1, 2))
+        # y_grad2 = np.ones(x_shape2).astype(np.float32)
         x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad(
             x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW")
 
@@ -168,7 +171,7 @@ class TestBatchNormOp(OpTest):
         epsilon = 0.00001
         momentum = 0.9
 
-        # N, H, W, C: 2, 3, 4, 2
+        # N, H, W, C: 12, 3, 4, 2
         n, h, w, c = 2, 3, 4, 2
 
         if data_format == "NHWC":
@@ -279,6 +282,8 @@ class TestBatchNormOp(OpTest):
                                                   None, place)
 
             # check gradient output
+            print 'var x_grad tensor: ', str(place), np.array(x_grad_tensor)
+            print 'var x_grad by python: ', str(place), x_grad_ref
             self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
             self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
             self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
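This second patch stops the CUDA kernel from zeroing MeanOut/VarianceOut and switches test_python to a random upstream gradient. The running statistics the tests check follow the convention mean_out = saved_mean * (1. - momentum) + momentum * mean, with saved_variance stored as the inverse standard deviation 1/sqrt(var + epsilon). A small sketch of that convention, with a hypothetical update_running_stats helper and values chosen only for illustration:

import numpy as np

def update_running_stats(batch_mean, batch_var, running_mean, running_var,
                         momentum, epsilon):
    # momentum weights the *old* running value; the batch value contributes (1 - momentum)
    mean_out = batch_mean * (1. - momentum) + momentum * running_mean
    variance_out = batch_var * (1. - momentum) + momentum * running_var
    # the "saved" variance handed to the backward pass is the inverse std
    saved_inv_std = 1. / np.sqrt(batch_var + epsilon)
    return mean_out, variance_out, saved_inv_std

# with momentum = 0.9 the running value moves 10% of the way per batch:
m, v, inv_std = update_running_stats(
    batch_mean=np.array([2.0]), batch_var=np.array([4.0]),
    running_mean=np.array([0.0]), running_var=np.array([1.0]),
    momentum=0.9, epsilon=1e-5)
# m == [0.2], v == [1.3], inv_std ~= [0.5]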
From 03789a7df4beb929aa67ea9892c214d68fd6e7d8 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Fri, 27 Oct 2017 14:55:15 -0700
Subject: [PATCH 3/4] batch norm fully tortured and passed

---
 paddle/operators/batch_norm_op.cu             | 11 ++++--
 .../v2/framework/tests/test_batch_norm_op.py  | 35 +++++++++++--------
 2 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/paddle/operators/batch_norm_op.cu b/paddle/operators/batch_norm_op.cu
index 6cbbb33438..726d1ea1b8 100644
--- a/paddle/operators/batch_norm_op.cu
+++ b/paddle/operators/batch_norm_op.cu
@@ -208,8 +208,15 @@ class BatchNormGradKernel
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif
 
-    std::vector dims = {N, C, H, W, D};
-    std::vector strides = {H * W * C * D, 1, W * D * C, D * C, C};
+    std::vector dims;
+    std::vector strides;
+    if (tensor_format == TensorFormat::NCHW) {
+      dims = {N, C, H, W, D};
+      strides = {C * H * W * D, H * W * D, W * D, D, 1};
+    } else {
+      dims = {N, C, H, W, D};
+      strides = {H * W * C * D, 1, W * D * C, D * C, C};
+    }
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
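The descriptor fix above gives the backward kernel different stride vectors for NCHW and NHWC while keeping the dims in (N, C, H, W, D) order. An illustration only, in Python rather than the CUDA code, of how those two stride sets address the same logical element in the two memory layouts (offset is a hypothetical helper; the sizes are arbitrary):

N, C, H, W, D = 2, 3, 4, 5, 1

nchw_strides = [C * H * W * D, H * W * D, W * D, D, 1]   # dims ordered N, C, H, W, D
nhwc_strides = [H * W * C * D, 1, W * D * C, D * C, C]   # same dims, NHWC memory

def offset(idx, strides):
    # flat position of element (n, c, h, w, d) in the underlying buffer
    return sum(i * s for i, s in zip(idx, strides))

# neighbouring elements along W are contiguous in NCHW (with D == 1) ...
assert offset((0, 0, 0, 1, 0), nchw_strides) - offset((0, 0, 0, 0, 0), nchw_strides) == 1
# ... while neighbouring elements along C are contiguous in NHWC
assert offset((0, 1, 0, 0, 0), nhwc_strides) - offset((0, 0, 0, 0, 0), nhwc_strides) == 1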
diff --git a/python/paddle/v2/framework/tests/test_batch_norm_op.py b/python/paddle/v2/framework/tests/test_batch_norm_op.py
index f0e7f1e523..fedb48eee8 100644
--- a/python/paddle/v2/framework/tests/test_batch_norm_op.py
+++ b/python/paddle/v2/framework/tests/test_batch_norm_op.py
@@ -96,22 +96,25 @@ def create_or_get_tensor(scope, var_name, var, place):
     return tensor
 
 
-def set_output_grad(scope, outputs, place):
-    def __set_tensor__(name):
+def set_output_grad(scope, outputs, place, feed_dict=None):
+    def __set_tensor__(name, data=None):
         out_tensor = scope.find_var(name).get_tensor()
         grad_tensor = scope.var(grad_var_name(name)).get_tensor()
         out_dtype = out_tensor.dtype()
-        if out_dtype == core.DataType.FP64:
-            data = np.ones(out_tensor.shape(), dtype=np.float64)
-        elif out_dtype == core.DataType.FP32:
-            data = np.ones(out_tensor.shape(), dtype=np.float32)
-        else:
-            raise ValueError("Not supported data type " + str(out_dtype))
-
+        if data is None:
+            if out_dtype == core.DataType.FP64:
+                data = np.ones(out_tensor.shape(), dtype=np.float64)
+            elif out_dtype == core.DataType.FP32:
+                data = np.ones(out_tensor.shape(), dtype=np.float32)
+            else:
+                raise ValueError("Not supported data type " + str(out_dtype))
         grad_tensor.set(data, place)
 
     for output in outputs:
-        __set_tensor__(output)
+        data = None
+        if output in feed_dict:
+            data = feed_dict[output]
+        __set_tensor__(output, data)
 
 
 class TestBatchNormOp(OpTest):
@@ -119,7 +122,7 @@ class TestBatchNormOp(OpTest):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
     def test_python(self):
-        data_format = "NCHW"
+        data_format = "NHWC"
        epsilon = 0.00001
         momentum = 0.9
 
@@ -214,7 +217,10 @@ class TestBatchNormOp(OpTest):
         saved_variance = 1. / np.sqrt(var_ref + epsilon)
 
         # for gradient test
-        y_grad = np.ones(x_shape).astype(np.float32)
+        # y_grad = np.ones(x_shape).astype(np.float32)
+        y_grad = np.zeros(x_shape).astype(np.float32)
+        y_grad[0, 0, 0, 0] = 1.
+        # y_grad = np.random.random_sample(x_shape).astype(np.float32)
         x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
             x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
 
@@ -283,7 +289,8 @@ class TestBatchNormOp(OpTest):
             set_output_grad(
                 scope,
                 ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
-                place)
+                place,
+                feed_dict={"y_out": y_grad})
             batch_norm_op_grad.run(scope, ctx)
 
             x_grad_tensor = create_or_get_tensor(scope,
@@ -297,8 +304,6 @@ class TestBatchNormOp(OpTest):
                                                   None, place)
 
             # check gradient output
-            print 'var x_grad tensor: ', str(place), np.array(x_grad_tensor)
-            print 'var x_grad by python: ', str(place), x_grad_ref
             self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
             self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
             self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
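With the one-hot y_grad that this patch feeds through set_output_grad(..., feed_dict={"y_out": y_grad}), the analytic x gradient can also be cross-checked against finite differences of the forward reference: a one-hot upstream gradient at a single output element makes x_grad exactly d y[0,0,0,0] / d x. A minimal sketch of such a check, assuming only the _reference_training and _reference_grad helpers defined in this test file (finite_diff_check is a hypothetical name, intended for the tiny shapes used here; use float64 inputs for a tight comparison):

import numpy as np

def finite_diff_check(x, scale, offset, epsilon, data_format, delta=1e-3):
    y, mean, var = _reference_training(x, scale, offset, epsilon, data_format)
    y_grad = np.zeros_like(y)
    y_grad.flat[0] = 1.                      # pick a single output element
    x_grad, _, _ = _reference_grad(x, y_grad, scale, mean, var, epsilon,
                                   data_format)

    num_grad = np.zeros_like(x)
    for i in range(x.size):
        x_pos, x_neg = x.copy(), x.copy()
        x_pos.flat[i] += delta
        x_neg.flat[i] -= delta
        # central difference of y.flat[0] w.r.t. each input element
        y_pos, _, _ = _reference_training(x_pos, scale, offset, epsilon, data_format)
        y_neg, _, _ = _reference_training(x_neg, scale, offset, epsilon, data_format)
        num_grad.flat[i] = (y_pos.flat[0] - y_neg.flat[0]) / (2. * delta)
    return np.abs(x_grad - num_grad).max()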
From f456a4e938c443d68484848a1aeece71f5e0cbd3 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Fri, 27 Oct 2017 15:31:36 -0700
Subject: [PATCH 4/4] batch-norm forward backward nchw, nhwc passed

---
 .../v2/framework/tests/test_batch_norm_op.py | 89 +++++++++----------
 1 file changed, 44 insertions(+), 45 deletions(-)

diff --git a/python/paddle/v2/framework/tests/test_batch_norm_op.py b/python/paddle/v2/framework/tests/test_batch_norm_op.py
index fedb48eee8..dee339f43c 100644
--- a/python/paddle/v2/framework/tests/test_batch_norm_op.py
+++ b/python/paddle/v2/framework/tests/test_batch_norm_op.py
@@ -184,47 +184,47 @@ class TestBatchNormOp(OpTest):
         print 'python: NHWC, NCHW, backward checking passed'
 
     def test_forward_backward(self):
-        # attr
-        data_format = "NCHW"
-        epsilon = 0.00001
-        momentum = 0.9
-
-        # N, H, W, C: 12, 3, 4, 2
-        n, h, w, c = 2, 3, 4, 2
-
-        if data_format == "NHWC":
-            x_shape = [n, h, w, c]
-        elif data_format == "NCHW":
-            x_shape = [n, c, h, w]
-        else:
-            raise ValueError("Unknown data type.")
-        scale_shape = [c]
-
-        x_val = np.random.random_sample(x_shape).astype(np.float32)
-        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
-        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
-
-        mean = np.zeros(scale_shape).astype(np.float32)
-        variance = np.ones(scale_shape).astype(np.float32)
-
-        # run forward
-        y_out, saved_mean, var_ref = _reference_training(
-            x_val, scale_val, bias_val, epsilon, data_format)
-
-        # update moving mean and variance
-        mean_out = saved_mean * (1. - momentum) + momentum * mean
-        variance_out = var_ref * (1. - momentum) + momentum * variance
-        saved_variance = 1. / np.sqrt(var_ref + epsilon)
-
-        # for gradient test
-        # y_grad = np.ones(x_shape).astype(np.float32)
-        y_grad = np.zeros(x_shape).astype(np.float32)
-        y_grad[0, 0, 0, 0] = 1.
-        # y_grad = np.random.random_sample(x_shape).astype(np.float32)
-        x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
-            x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
+        def test_with_place(place, tensor_format):
+            # attr
+            epsilon = 0.00001
+            momentum = 0.9
+
+            # N, H, W, C: 12, 3, 4, 2
+            n, h, w, c = 2, 3, 4, 2
+
+            if data_format == "NHWC":
+                x_shape = [n, h, w, c]
+            elif data_format == "NCHW":
+                x_shape = [n, c, h, w]
+            else:
+                raise ValueError("Unknown data type.")
+            scale_shape = [c]
+
+            x_val = np.random.random_sample(x_shape).astype(np.float32)
+            scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+            bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+            mean = np.zeros(scale_shape).astype(np.float32)
+            variance = np.ones(scale_shape).astype(np.float32)
+
+            # run forward
+            y_out, saved_mean, var_ref = _reference_training(
+                x_val, scale_val, bias_val, epsilon, data_format)
+
+            # update moving mean and variance
+            mean_out = saved_mean * (1. - momentum) + momentum * mean
+            variance_out = var_ref * (1. - momentum) + momentum * variance
+            saved_variance = 1. / np.sqrt(var_ref + epsilon)
+
+            # for gradient test
+            # y_grad = np.ones(x_shape).astype(np.float32)
+            y_grad = np.zeros(x_shape).astype(np.float32)
+            y_grad[0, 0, 0, 0] = 1.
+            # y_grad = np.random.random_sample(x_shape).astype(np.float32)
+            x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
+                x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
+                data_format)
 
-        def test_with_place(place, tensor_format=data_format):
             scope = core.Scope()
 
             # create input
@@ -275,14 +275,13 @@ class TestBatchNormOp(OpTest):
             self.__assert_close(saved_variance_tensor, saved_variance,
                                 "saved_variance")
             self.__assert_close(mean_out_tensor, mean_out, "mean_out")
-            # FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
             if isinstance(place, core.GPUPlace):
                 atol = 5e-2
             else:
                 atol = 1e-4
             self.__assert_close(variance_out_tensor, variance_out,
                                 "variance_out", atol)
-            print "op test forward passed: ", tensor_format
+            print "op test forward passed: ", str(place), tensor_format
 
             # run backward
             batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
@@ -307,14 +306,14 @@ class TestBatchNormOp(OpTest):
             self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
             self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
             self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
-            print "op test backward passed: ", tensor_format
+            print "op test backward passed: ", str(place), tensor_format
 
         places = [core.CPUPlace()]
         if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
             places.append(core.GPUPlace(0))
         for place in places:
-            test_with_place(place)
-            print "test forward passed"
+            for data_format in ["NCHW", "NHWC"]:
+                test_with_place(place, data_format)
 
 
 if __name__ == '__main__':
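After this final patch, test_forward_backward runs the same forward/backward comparison against the NumPy reference for every available place and both tensor formats. A condensed view of that test matrix, sketch only, with a check callable standing in for the real test_with_place and core passed in as used by the test file:

import itertools

def run_matrix(check, core):
    places = [core.CPUPlace()]
    if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
        places.append(core.GPUPlace(0))
    for place, fmt in itertools.product(places, ["NCHW", "NHWC"]):
        check(place, fmt)   # forward and backward compared against the reference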