@@ -15,12 +15,14 @@
 """Utils of auto parallel"""
 
 import numpy as np
+from mindspore import log as logger
 from mindspore._c_expression import reset_op_id
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.common.seed import get_seed
 
 
 def _get_parallel_mode():
@@ -136,16 +138,11 @@ def _get_global_rank():
 
 def _get_parameter_broadcast():
     """Get the parameter broadcast."""
     parallel_mode = auto_parallel_context().get_parallel_mode()
-    if parallel_mode == "stand_alone":
-        parameter_broadcast = False
-        return parameter_broadcast
-
-    if auto_parallel_context().get_parameter_broadcast_is_set() is True:
-        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
-    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
-        parameter_broadcast = True
-    else:
-        parameter_broadcast = False
+    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
+
+    if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
+        logger.warning("You are suggested to use mindspore.common.set_seed() to share"
+                       " parameters among devices.")
 
     return parameter_broadcast
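
For context, a minimal sketch of the condition the new branch keys on. This is not part of the patch; it only assumes the mindspore.common.set_seed()/get_seed() pair that the hunk above imports and mentions in its warning message:

    # Hedged sketch, not part of the patch: the new warning fires only when
    # get_seed() returns None, i.e. when set_seed() was never called.
    from mindspore.common import set_seed
    from mindspore.common.seed import get_seed

    assert get_seed() is None   # fresh process: no global seed, warning path taken
    set_seed(1)                 # what the warning asks the user to do
    assert get_seed() == 1      # seed is now shared, so the warning is skipped

The design point of the rewrite: instead of _get_parameter_broadcast() silently deciding the broadcast flag per parallel mode, it now returns whatever the user configured and warns when parameters could diverge across devices (data_parallel/hybrid_parallel, broadcast off, no shared seed).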