fix globalbatchnorm bug

pull/832/head
zhaojichen 5 years ago
parent eb46dd9198
commit 0ba35eaec3

@ -117,6 +117,7 @@ class _BatchNorm(Cell):
return group_list
def _global_sync(self, x):
"""calculate global batch normalization output"""
if len(self.shape(x)) == 4:
axes = (0, 2, 3)
re_shape = (1, self.num_features, 1, 1)

Loading…
Cancel
Save