add degree and mean to TrainOneStepCell

pull/13355/head
wangnan39@huawei.com 4 years ago
parent aa9bee0ce3
commit 0f844de1a9

@ -339,9 +339,9 @@ class TrainOneStepCell(Cell):
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
self.mean = _get_gradients_mean()
self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
def construct(self, *inputs):
weights = self.weights

Loading…
Cancel
Save