|
|
|
@ -27,14 +27,6 @@ from mindspore.communication import management
|
|
|
|
|
from mindspore._checkparam import check_int_positive
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
|
|
|
|
|
class _GlobalBNHelper(Cell):
|
|
|
|
|
def __init__(self, group):
|
|
|
|
|
super(_GlobalBNHelper, self).__init__()
|
|
|
|
|
self.group = group
|
|
|
|
|
self.reduce = P.AllReduce(P.ReduceOp.SUM, group=self.group).add_prim_attr('fusion', 1)
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
x = self.reduce(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
class _BatchNorm(Cell):
|
|
|
|
|
"""Batch Normalization base class."""
|
|
|
|
|