add global batch normalization

pull/399/head
zhaojichen 5 years ago
parent f7872774f3
commit c5120e770c

@ -27,14 +27,6 @@ from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell
class _GlobalBNHelper(Cell):
def __init__(self, group):
super(_GlobalBNHelper, self).__init__()
self.group = group
self.reduce = P.AllReduce(P.ReduceOp.SUM, group=self.group).add_prim_attr('fusion', 1)
def construct(self, x):
x = self.reduce(x)
return x
class _BatchNorm(Cell):
"""Batch Normalization base class."""

Loading…
Cancel
Save