|
|
@ -82,6 +82,7 @@ class _BatchNorm(Cell):
|
|
|
|
self.dtype = P.DType()
|
|
|
|
self.dtype = P.DType()
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
self.is_ascend = context.get_context("device_target") == "Ascend"
|
|
|
|
self.is_ascend = context.get_context("device_target") == "Ascend"
|
|
|
|
|
|
|
|
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
|
|
|
|
|
|
|
|
|
|
|
|
if context.get_context("enable_ge"):
|
|
|
|
if context.get_context("enable_ge"):
|
|
|
|
self.is_ge_backend = True
|
|
|
|
self.is_ge_backend = True
|
|
|
@ -89,7 +90,7 @@ class _BatchNorm(Cell):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.is_ge_backend = False
|
|
|
|
self.is_ge_backend = False
|
|
|
|
self.momentum = 1.0 - momentum
|
|
|
|
self.momentum = 1.0 - momentum
|
|
|
|
if self.is_ge_backend or self.is_ascend:
|
|
|
|
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
|
|
|
|
self.bn_train = P.BatchNorm(is_training=True,
|
|
|
|
self.bn_train = P.BatchNorm(is_training=True,
|
|
|
|
epsilon=self.eps)
|
|
|
|
epsilon=self.eps)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -147,7 +148,7 @@ class _BatchNorm(Cell):
|
|
|
|
if self.is_ge_backend and self.is_global:
|
|
|
|
if self.is_ge_backend and self.is_global:
|
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
|
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
|
|
|
|
y = self._global_sync(x, axes, re_shape)
|
|
|
|
y = self._global_sync(x, axes, re_shape)
|
|
|
|
elif self.is_ge_backend or self.is_ascend:
|
|
|
|
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
|
|
|
|
y, batch_mean, batch_var, _, _ = \
|
|
|
|
y, batch_mean, batch_var, _, _ = \
|
|
|
|
self.bn_train(x,
|
|
|
|
self.bn_train(x,
|
|
|
|
self.gamma,
|
|
|
|
self.gamma,
|
|
|
|