group_norm support data_layout:NHWC, test=develop, test=document_preview (#19614)

1. group_norm now supports data_layout=NHWC
2. updated the group_norm documentation
Zhang Ting 6 years ago committed by Aurelius84
parent e117114289
commit 93364b45c1

@@ -170,7 +170,7 @@ paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', '
paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1d8a1c8b686b55631ba1b77805e4eacf'))
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23'))
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', '79797f827d89ae72c77960e9696883a9'))
-paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '96b24820e8863d6044d5be4eaaddb9fd'))
+paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '65231cc8281815124934b1439fbb750c'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '9461e67095a6fc5d568fb2ce8fef66ff'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '54e1675aa0364f4a78fa72804ec0f413'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ecb75c1b00c4c76c98b482f633b7a10c'))
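The ArgSpec itself is unchanged, since data_layout already existed with default 'NCHW'; only the docstring hash moves. A minimal sketch of the newly supported NHWC call against the fluid 1.x API (variable names and sizes are made up):

import paddle.fluid as fluid

# Per-sample NHWC shape (height, width, channels); the channel count
# (32 here) must be divisible by groups.
data = fluid.layers.data(name='img', shape=[8, 8, 32], dtype='float32')
out = fluid.layers.group_norm(input=data, groups=4, data_layout='NHWC')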

@@ -38,9 +38,11 @@ class GroupNormOp : public framework::OperatorWithKernel {
"Output(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Variance"),
"Output(Variance) of GroupNormOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
-    auto channel_num = x_dim[1];
+    const DataLayout data_layout = framework::StringToDataLayout(
+        ctx->Attrs().Get<std::string>("data_layout"));
+    const int64_t channel_num =
+        (data_layout == DataLayout::kNCHW ? x_dim[1] : x_dim[x_dim.size() - 1]);
auto batch_size = x_dim[0];
auto groups = ctx->Attrs().Get<int>("groups");
PADDLE_ENFORCE_LE(
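The substantive change: the channel count now comes from axis 1 for NCHW but from the last axis for NHWC, before the groups-divisibility check. The same selection as a plain Python sketch (illustration only, not the operator source):

def infer_channel_num(x_shape, data_layout):
    # channels sit at axis 1 for NCHW and at the last axis for NHWC
    return x_shape[1] if data_layout == 'NCHW' else x_shape[-1]

assert infer_channel_num((2, 4, 3, 3), 'NCHW') == 4
assert infer_channel_num((2, 3, 3, 4), 'NHWC') == 4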
@@ -91,7 +93,9 @@ class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker {
.AddCustomChecker([](const int &groups) {
PADDLE_ENFORCE_GT(groups, 0, "'groups' should be greater than zero.");
});
+    AddAttr<std::string>("data_layout",
+                         "An optional string from: \"NHWC\", \"NCHW\". ")
+        .SetDefault("NCHW");
AddComment(R"DOC(
Group Normalization
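For reference, Group Normalization (Wu & He, 2018) splits the C channels into G groups and normalizes each group jointly with the spatial dimensions; a sketch of the computation, consistent with the naive NumPy implementation in the tests below:

\mu_g = \frac{1}{m}\sum_{i \in g} x_i, \qquad \sigma_g^2 = \frac{1}{m}\sum_{i \in g}(x_i - \mu_g)^2, \qquad y_i = \gamma_c \, \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}} + \beta_c

where each group holds m = (C/G) * H * W elements and \gamma_c, \beta_c are the per-channel Scale and Bias inputs.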

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -3764,7 +3764,7 @@ def group_norm(input,
bias :math:`b`. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act(str): Activation to be applied to the output of group normalization.
-data_layout(string|NCHW): Only NCHW is supported.
+data_layout(string, default NCHW): NCHW(num_batch, channels, h, w) or NHWC(num_batch, h, w, channels).
name (str): The name of this layer. It is optional.
Returns:
@@ -3783,9 +3783,12 @@
# create input and parameters
inputs = {'X': input}
input_shape = input.shape
-    if data_layout != 'NCHW':
-        raise ValueError("unsupported data layout:" + data_layout)
-    param_shape = [input_shape[1]]
+    if data_layout != 'NCHW' and data_layout != 'NHWC':
+        raise ValueError(
+            "Param(data_layout) of Op(fluid.layers.group_norm) got wrong value: received "
+            + data_layout + " but only NCHW or NHWC supported.")
+    channel_num = input_shape[1] if data_layout == 'NCHW' else input_shape[-1]
+    param_shape = [channel_num]
if param_attr:
scale = helper.create_parameter(
attr=helper.param_attr,
@@ -3811,8 +3814,11 @@
"Mean": mean_out,
"Variance": variance_out,
},
-        attrs={"epsilon": epsilon,
-               "groups": groups})
+        attrs={
+            "epsilon": epsilon,
+            "groups": groups,
+            "data_layout": data_layout
+        })
return helper.append_activation(group_norm_out)
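With the validation above in place, an unsupported layout now fails fast with a descriptive ValueError at graph-construction time; a hypothetical snippet:

import paddle.fluid as fluid

data = fluid.layers.data(name='x', shape=[3, 3, 4], dtype='float32')
try:
    fluid.layers.group_norm(input=data, groups=2, data_layout='NDHW')
except ValueError as err:
    print(err)  # ... but only NCHW or NHWC supported.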

@@ -24,7 +24,9 @@ from op_test import OpTest
from testsuite import create_op
-def group_norm_naive(x, scale, bias, epsilon, groups):
+def group_norm_naive(x, scale, bias, epsilon, groups, data_layout):
+    if data_layout == "NHWC":
+        x = np.transpose(x, (0, 3, 1, 2))  # NHWC => NCHW
N, C, H, W = x.shape
G = groups
x = x.reshape((N * G, -1))
@@ -33,6 +35,8 @@ def group_norm_naive(x, scale, bias, epsilon, groups):
output = (x - mean) / np.sqrt(var + epsilon)
output = output.reshape((N, C, H, W)) * scale.reshape(
(-1, 1, 1)) + bias.reshape((-1, 1, 1))
+    if data_layout == "NHWC":
+        output = np.transpose(output, (0, 2, 3, 1))  # NCHW => NHWC
return output, mean.reshape((N, G)), var.reshape((N, G))
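Assembled from the hunks above, the reference implementation is short. A self-contained version follows; the mean/variance lines are context not shown in this diff and are assumed here to be plain np.mean/np.var over each flattened group:

import numpy as np

def group_norm_naive(x, scale, bias, epsilon, groups, data_layout):
    # Convert to NCHW so a single code path handles both layouts.
    if data_layout == "NHWC":
        x = np.transpose(x, (0, 3, 1, 2))  # NHWC => NCHW
    N, C, H, W = x.shape
    G = groups
    x = x.reshape((N * G, -1))
    # Statistics over each group of C/G channels plus spatial dims (assumed lines).
    mean = np.mean(x, axis=1, keepdims=True)
    var = np.var(x, axis=1, keepdims=True)
    output = (x - mean) / np.sqrt(var + epsilon)
    # Per-channel affine transform with the Scale and Bias inputs.
    output = output.reshape((N, C, H, W)) * scale.reshape(
        (-1, 1, 1)) + bias.reshape((-1, 1, 1))
    if data_layout == "NHWC":
        output = np.transpose(output, (0, 2, 3, 1))  # NCHW => NHWC
    return output, mean.reshape((N, G)), var.reshape((N, G))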
@@ -42,15 +46,18 @@ class TestGroupNormOp(OpTest):
self.data_format = "NCHW"
self.dtype = np.float32
self.shape = (2, 4, 3, 3)
-        self.attrs = {'epsilon': 1e-5, 'groups': 2}
+        self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"}
self.compare_between_place = False
self.init_test_case()
input = np.random.random(self.shape).astype(self.dtype)
+        if self.data_format == "NHWC":
+            input = np.transpose(input, (0, 2, 3, 1))
scale = np.random.random([self.shape[1]]).astype(self.dtype)
bias = np.random.random([self.shape[1]]).astype(self.dtype)
output, mean, var = group_norm_naive(
-            input, scale, bias, self.attrs['epsilon'], self.attrs['groups'])
+            input, scale, bias, self.attrs['epsilon'], self.attrs['groups'],
+            self.data_format)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(input),
@@ -58,6 +65,7 @@ class TestGroupNormOp(OpTest):
'Bias': OpTest.np_dtype_to_fluid_dtype(bias)
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
+        self.attrs['data_layout'] = self.data_format
def test_check_output(self):
atol = 1e-4
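One subtlety in the NHWC variants below: self.shape stays declared in NCHW order; setUp transposes the random input to NHWC, while Scale and Bias keep length self.shape[1], the channel count. A toy trace of that bookkeeping (illustrative values):

import numpy as np

shape = (2, 4, 3, 3)                    # declared NCHW
x = np.random.random(shape).astype("float32")
x_nhwc = np.transpose(x, (0, 2, 3, 1))  # fed to the op as NHWC
assert x_nhwc.shape == (2, 3, 3, 4)
assert shape[1] == 4                    # scale/bias length under either layout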
@@ -66,6 +74,7 @@
# add inplace_atol because group_norm doesn't ensure computational consistency
self.check_output_with_place(
place, atol=atol, inplace_atol=inplace_atol)
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(
@@ -94,6 +103,7 @@
if self.compare_between_place:
self.do_compare_between_place()
return
place = core.CPUPlace()
self.check_grad_with_place(
place, set(['X', 'Scale', 'Bias']), 'Y', max_relative_error=0.01)
@@ -143,5 +153,85 @@ class TestGroupNormOpLargeData(TestGroupNormOp):
self.compare_between_place = True
+class TestGroupNormOp1_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 1
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOp2_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 4
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 1
+        self.attrs['epsilon'] = 0.5
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpBigEps2_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 4
+        self.attrs['epsilon'] = 0.5
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpBigEps3_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['epsilon'] = 0.5
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.shape = (2, 64, 32, 32)  # declared NCHW; setUp transposes to NHWC
+        self.attrs['groups'] = 8
+        self.data_format = "NHWC"
+        self.compare_between_place = True
+
+
+class TestGroupNormAPI_With_NHWC(OpTest):
+    def test_case1(self):
+        data1 = fluid.layers.data(
+            name='data1', shape=[3, 3, 4], dtype='float32')
+        out1 = fluid.layers.group_norm(
+            input=data1, groups=2, data_layout="NHWC")
+        data2 = fluid.layers.data(
+            name='data2', shape=[4, 3, 3], dtype='float32')
+        out2 = fluid.layers.group_norm(
+            input=data2, groups=2, data_layout="NCHW")
+
+        data1_np = np.random.random((2, 3, 3, 4)).astype("float32")
+        data2_np = np.random.random((2, 4, 3, 3)).astype("float32")
+        scale = np.array([1]).astype("float32")
+        bias = np.array([0]).astype("float32")
+
+        place = core.CPUPlace()
+        exe = fluid.Executor(place)
+        results = exe.run(fluid.default_main_program(),
+                          feed={"data1": data1_np,
+                                "data2": data2_np},
+                          fetch_list=[out1, out2],
+                          return_numpy=True)
+        expect_res1 = group_norm_naive(
+            data1_np, scale, bias, epsilon=1e-5, groups=2, data_layout="NHWC")
+        expect_res2 = group_norm_naive(
+            data2_np, scale, bias, epsilon=1e-5, groups=2, data_layout="NCHW")
+        self.assertTrue(np.allclose(results[0], expect_res1[0]))
+        self.assertTrue(np.allclose(results[1], expect_res2[0]))
+
+    def test_case2(self):
+        # a data_layout that is neither NHWC nor NCHW must be rejected
+        data = fluid.layers.data(name='data', shape=[3, 3, 4], dtype="float32")
+        with self.assertRaises(ValueError):
+            fluid.layers.group_norm(input=data, groups=2, data_layout="NDHW")
if __name__ == '__main__':
unittest.main()
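As a standalone sanity check of the layout equivalence test_case1 exercises end to end, the group_norm_naive helper from this file gives, by construction, identical results for an NHWC tensor and its NCHW transpose (made-up data):

import numpy as np

x_nhwc = np.random.random((2, 3, 3, 4)).astype("float32")
scale = np.ones(4, dtype="float32")
bias = np.zeros(4, dtype="float32")

y_nhwc, _, _ = group_norm_naive(x_nhwc, scale, bias, 1e-5, 2, "NHWC")
y_nchw, _, _ = group_norm_naive(
    np.transpose(x_nhwc, (0, 3, 1, 2)), scale, bias, 1e-5, 2, "NCHW")
assert np.allclose(y_nhwc, np.transpose(y_nchw, (0, 2, 3, 1)))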
