|
|
@ -349,9 +349,9 @@ class TrainOneStepCell(Cell):
|
|
|
|
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
|
|
|
if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
|
|
|
self.reducer_flag = True
|
|
|
|
self.reducer_flag = True
|
|
|
|
if self.reducer_flag:
|
|
|
|
if self.reducer_flag:
|
|
|
|
mean = _get_gradients_mean()
|
|
|
|
self.mean = _get_gradients_mean()
|
|
|
|
degree = _get_device_num()
|
|
|
|
self.degree = _get_device_num()
|
|
|
|
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
|
|
|
|
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
|
|
|
|
|
|
|
|
|
|
|
|
def construct(self, *inputs):
|
|
|
|
def construct(self, *inputs):
|
|
|
|
weights = self.weights
|
|
|
|
weights = self.weights
|
|
|
|