Add Group Normalization

pull/379/head
zhaojichen 5 years ago
parent 0b7de6968f
commit 04c522d0c6

@ -57,6 +57,7 @@ def test_compile():
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
_executor.compile(net, input_data)
class GroupNet(nn.Cell):
def __init__(self):
super(GroupNet, self).__init__()
@ -64,6 +65,7 @@ class GroupNet(nn.Cell):
def construct(self, x):
return self.group_bn(x)
def test_compile_groupnorm():
net = nn.GroupNorm(16, 64)
input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))

Loading…
Cancel
Save