mv ParallelMode to context

pull/5351/head
yao_yf 5 years ago
parent 78386683bf
commit 07117e4dd4
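
In short, this change moves `ParallelMode` out of `mindspore.train.parallel_utils` (the module is deleted below) and into `mindspore.context`, and drops the re-export from `mindspore.train`. A minimal migration sketch of what callers switch to, assuming MindSpore at this revision; the chosen mode and `device_num` are illustrative:

    # Old import paths, removed by this commit:
    #   from mindspore.train.parallel_utils import ParallelMode
    #   from mindspore.train import ParallelMode

    # New import path introduced here:
    from mindspore import context
    from mindspore.context import ParallelMode

    # ParallelMode constants are plain strings accepted by the auto-parallel context.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, device_num=8)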

@@ -28,7 +28,7 @@ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context
     _reset_auto_parallel_context
 
 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
-           'get_auto_parallel_context', 'reset_auto_parallel_context']
+           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
 
 GRAPH_MODE = 0
 PYNATIVE_MODE = 1
@@ -647,3 +647,26 @@ def get_context(attr_key):
         raise ValueError(
             "Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
+
+
+class ParallelMode:
+    """
+    Parallel mode options.
+
+    There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
+    "HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
+
+    - STAND_ALONE: Only one processor working.
+    - DATA_PARALLEL: Distributing the data across different processors.
+    - HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
+    - SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
+    - AUTO_PARALLEL: Achieving parallelism automatically.
+
+    MODE_LIST: The list for all supported parallel modes.
+    """
+    STAND_ALONE = "stand_alone"
+    DATA_PARALLEL = "data_parallel"
+    HYBRID_PARALLEL = "hybrid_parallel"
+    SEMI_AUTO_PARALLEL = "semi_auto_parallel"
+    AUTO_PARALLEL = "auto_parallel"
+    MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
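
For downstream code, the mode constants above are plain strings, so the usual pattern (mirrored by the wrapper-cell changes later in this diff) is to read the configured mode back from the auto-parallel context and compare it against `ParallelMode` members. A small sketch, assuming a distributed context has already been configured:

    from mindspore import context
    from mindspore.context import ParallelMode

    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    # Explicit gradient reduction (e.g. DistributedGradReducer) is only needed in these two modes.
    reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
    assert parallel_mode in ParallelMode.MODE_LIST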

@@ -20,7 +20,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore._checkparam import Validator
 from mindspore.communication.management import get_group_size
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
 from ..cell import Cell
 from ..._checkparam import Validator as validator, Rel
@@ -129,9 +129,9 @@ class EmbeddingLookup(Cell):
         embedding_size (int): The size of each embedding vector.
         param_init (str): The initialize way of embedding table. Default: 'normal'.
         target (str): Specify the target where the op is executed. The value should in
             ['DEVICE', 'CPU']. Default: 'CPU'.
-        slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
-            nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
+        slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through
+            nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
         manual_shapes (tuple): The accompaniment array in field slice mode.
 
     Inputs:

@@ -29,7 +29,7 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore import log as logger
 from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule

@@ -15,7 +15,7 @@
 """Cell_wrapper."""
 from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
                                        _get_parallel_mode)
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C

@@ -251,8 +251,9 @@ class DistributedGradReducer(Cell):
         >>> from mindspore.ops import operations as P
         >>> from mindspore.ops import functional as F
         >>> from mindspore import context
+        >>> from mindspore.context import ParallelMode
         >>> from mindspore import nn
-        >>> from mindspore import ParallelMode, ParameterTuple
+        >>> from mindspore import ParameterTuple
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,

@@ -15,7 +15,7 @@
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
 from ...common import Tensor, RowTensor

@@ -18,8 +18,7 @@ High-Level training interfaces.
 Helper functions in train piplines.
 """
 from .model import Model
-from .parallel_utils import ParallelMode
 from .dataset_helper import DatasetHelper
 from . import amp
 
-__all__ = ["Model", "ParallelMode", "DatasetHelper", "amp"]
+__all__ = ["Model", "DatasetHelper", "amp"]

@@ -23,7 +23,7 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
-from .parallel_utils import ParallelMode
+from ..context import ParallelMode
 from .. import context
 
 __all__ = ["build_train_network"]

@@ -30,7 +30,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
 from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
-from .parallel_utils import ParallelMode
+from ..context import ParallelMode
 from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper

@@ -1,41 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Parallel utils"""
-
-__all__ = ["ParallelMode"]
-
-
-class ParallelMode:
-    """
-    Parallel mode options.
-
-    There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
-    "HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
-
-    - STAND_ALONE: Only one processor working.
-    - DATA_PARALLEL: Distributing the data across different processors.
-    - HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
-    - SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
-    - AUTO_PARALLEL: Achieving parallelism automatically.
-
-    MODE_LIST: The list for all supported parallel modes.
-    """
-    STAND_ALONE = "stand_alone"
-    DATA_PARALLEL = "data_parallel"
-    HYBRID_PARALLEL = "hybrid_parallel"
-    SEMI_AUTO_PARALLEL = "semi_auto_parallel"
-    AUTO_PARALLEL = "auto_parallel"
-    MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]

@@ -17,7 +17,8 @@ import argparse
 from mindspore import context
 from mindspore.communication.management import init
 from mindspore.nn.optim.momentum import Momentum
-from mindspore import Model, ParallelMode
+from mindspore import Model
+from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
 from src.md_dataset import create_dataset

@@ -25,7 +25,8 @@ import mindspore.common.dtype as mstype
 from mindspore import context, Tensor
 from mindspore.communication.management import init
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
-from mindspore.train import Model, ParallelMode
+from mindspore.train import Model
+from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn import SGD
 import mindspore.dataset.engine as de

@@ -28,7 +28,8 @@ from mindspore import context
 from mindspore.communication.management import init, get_rank
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.config import cifar_cfg as cfg

@@ -21,7 +21,7 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
-from mindspore import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.nn.optim.rmsprop import RMSProp
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

@@ -24,7 +24,8 @@ import mindspore.common.dtype as mstype
 from mindspore import context, Tensor
 from mindspore.communication.management import init
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
-from mindspore.train import Model, ParallelMode
+from mindspore.train import Model
+from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn import SGD
 import mindspore.dataset.engine as de

@@ -30,7 +30,8 @@ from mindspore.nn.loss.loss import _Loss
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

@@ -22,7 +22,8 @@ import numpy as np
 from mindspore import context
 from mindspore import Tensor
 from mindspore import nn
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_checkpoint

@@ -28,7 +28,8 @@ from mindspore.nn.loss.loss import _Loss
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

@@ -22,7 +22,8 @@ from mindspore import Tensor
 from mindspore import dataset as de
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.loss_scale_manager import FixedLossScaleManager

@@ -21,7 +21,8 @@ from mindspore import context
 from mindspore import Tensor
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint

@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell):
         >>> from mindspore.ops import functional as F
         >>> from mindspore import context
         >>> from mindspore import nn
-        >>> from mindspore import ParallelMode, ParameterTuple
+        >>> from mindspore import ParameterTuple
+        >>> from mindspore.context import ParallelMode
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,

@@ -18,7 +18,7 @@ import math
 from mindspore.train.callback import RunContext
 from mindspore import context
 from mindspore import nn
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.train.model import Model
 from mindspore.parallel._utils import _need_to_full, _to_full_tensor
 from mindspore.common.dtype import pytype_to_dtype

@@ -22,7 +22,7 @@ from mindspore import context
 from mindspore import Tensor
 from mindspore import dataset as de
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
-from mindspore.train.model import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.communication.management import init, get_rank, get_group_size

@@ -20,7 +20,7 @@ import datetime
 import mindspore.nn as nn
 from mindspore import Tensor, context
-from mindspore import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.nn.optim import Momentum
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.callback import ModelCheckpoint

@@ -19,6 +19,7 @@ import mindspore.common.dtype as mstype
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Parameter, context, Tensor
+from mindspore.context import ParallelMode
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.communication.management import get_group_size
 from mindspore.ops import operations as P
@@ -388,7 +389,7 @@ class TrainingWrapper(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer = None
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
         if self.reducer_flag:
             mean = context.get_auto_parallel_context("mirror_mean")

Some files were not shown because too many files have changed in this diff.