|
|
|
@ -14,6 +14,8 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Context of auto parallel"""
|
|
|
|
|
import threading
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
|
|
|
|
|
from mindspore._c_expression import AutoParallelContext
|
|
|
|
|
from mindspore._extends.pynative_helper import args_type_check
|
|
|
|
|
|
|
|
|
@ -219,13 +221,15 @@ class _AutoParallelContext:
|
|
|
|
|
indices (list): Indices list.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If type of indices item is not int.
|
|
|
|
|
TypeError: If type of indices item is not int.
|
|
|
|
|
"""
|
|
|
|
|
self.check_context_handle()
|
|
|
|
|
for index in indices:
|
|
|
|
|
if not isinstance(index, int):
|
|
|
|
|
raise TypeError('indices has invalid value')
|
|
|
|
|
return self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
|
|
|
|
if context.get_context("device_target") == "Ascend":
|
|
|
|
|
_set_fusion_strategy_by_idx(indices)
|
|
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_indices(self):
|
|
|
|
|
"""Get allreduce fusion split indices."""
|
|
|
|
@ -240,13 +244,15 @@ class _AutoParallelContext:
|
|
|
|
|
sizes (list): Sizes list.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If type of sizes item is not int.
|
|
|
|
|
TypeError: If type of sizes item is not int.
|
|
|
|
|
"""
|
|
|
|
|
self.check_context_handle()
|
|
|
|
|
for size in sizes:
|
|
|
|
|
if not isinstance(size, int):
|
|
|
|
|
raise TypeError('sizes has invalid value')
|
|
|
|
|
return self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
|
|
|
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
|
|
|
|
if context.get_context("device_target") == "Ascend":
|
|
|
|
|
_set_fusion_strategy_by_size(sizes)
|
|
|
|
|
|
|
|
|
|
def get_all_reduce_fusion_split_sizes(self):
|
|
|
|
|
"""Get allreduce fusion split sizes."""
|
|
|
|
|