|
|
|
@ -156,19 +156,23 @@ class _BatchNorm(Cell):
|
|
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
|
|
|
|
|
y = self._global_sync(x, axes, re_shape)
|
|
|
|
|
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
|
|
|
|
|
y, batch_mean, batch_var, _, _ = \
|
|
|
|
|
self.bn_train(x,
|
|
|
|
|
self.gamma,
|
|
|
|
|
self.beta,
|
|
|
|
|
None,
|
|
|
|
|
None)
|
|
|
|
|
|
|
|
|
|
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
|
|
|
|
|
temp_mean = self.mul_mean(mean_sub, self.momentum)
|
|
|
|
|
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
|
|
|
|
|
temp_variance = self.mul_var(mean_sub2, self.momentum)
|
|
|
|
|
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
|
|
|
|
|
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
|
|
|
|
|
if self.is_global:
|
|
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
|
|
|
|
|
y = self._global_sync(x, axes, re_shape)
|
|
|
|
|
else:
|
|
|
|
|
y, batch_mean, batch_var, _, _ = \
|
|
|
|
|
self.bn_train(x,
|
|
|
|
|
self.gamma,
|
|
|
|
|
self.beta,
|
|
|
|
|
None,
|
|
|
|
|
None)
|
|
|
|
|
|
|
|
|
|
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
|
|
|
|
|
temp_mean = self.mul_mean(mean_sub, self.momentum)
|
|
|
|
|
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
|
|
|
|
|
temp_variance = self.mul_var(mean_sub2, self.momentum)
|
|
|
|
|
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
|
|
|
|
|
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
|
|
|
|
|
else:
|
|
|
|
|
y = self.bn_train(x,
|
|
|
|
|
self.gamma,
|
|
|
|
|