|
|
|
@ -195,12 +195,10 @@ class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp):
|
|
|
|
|
|
|
|
|
|
class TestGroupNormAPI_With_NHWC(OpTest):
|
|
|
|
|
def test_case1(self):
|
|
|
|
|
data1 = fluid.layers.data(
|
|
|
|
|
name='data1', shape=[3, 3, 4], dtype='float32')
|
|
|
|
|
data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float32')
|
|
|
|
|
out1 = fluid.layers.group_norm(
|
|
|
|
|
input=data1, groups=2, data_layout="NHWC")
|
|
|
|
|
data2 = fluid.layers.data(
|
|
|
|
|
name='data2', shape=[4, 3, 3], dtype='float32')
|
|
|
|
|
data2 = fluid.data(name='data2', shape=[None, 4, 3, 3], dtype='float32')
|
|
|
|
|
out2 = fluid.layers.group_norm(
|
|
|
|
|
input=data2, groups=2, data_layout="NCHW")
|
|
|
|
|
|
|
|
|
@ -223,14 +221,17 @@ class TestGroupNormAPI_With_NHWC(OpTest):
|
|
|
|
|
self.assertTrue(np.allclose(results[0], expect_res1[0]))
|
|
|
|
|
self.assertTrue(np.allclose(results[1], expect_res2[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGroupNormException(OpTest):
|
|
|
|
|
# data_layout is not NHWC or NCHW
|
|
|
|
|
def test_case2(self):
|
|
|
|
|
data = fluid.layers.data(name='data', shape=[3, 3, 4], dtype="float32")
|
|
|
|
|
try:
|
|
|
|
|
def test_exception(self):
|
|
|
|
|
data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float32")
|
|
|
|
|
|
|
|
|
|
def attr_data_format():
|
|
|
|
|
out = fluid.layers.group_norm(
|
|
|
|
|
input=data, groups=2, data_layout="NDHW")
|
|
|
|
|
except:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
self.assertRaises(ValueError, attr_data_format)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|