@@ -254,6 +254,8 @@ class DistributedGradReducer(Cell):
         >>> from mindspore.context import ParallelMode
         >>> from mindspore import nn
         >>> from mindspore import ParameterTuple
+        >>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
+        >>>                                        _get_parallel_mode)
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
@@ -279,11 +281,8 @@ class DistributedGradReducer(Cell):
         >>>                                   ParallelMode.HYBRID_PARALLEL]:
         >>>             self.reducer_flag = True
         >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
+        >>>             mean = _get_gradients_mean()
+        >>>             degree = _get_device_num()
         >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
         >>>
         >>>     def construct(self, *args):