mv ParallelMode to context

pull/5351/head
yao_yf 5 years ago
parent 78386683bf
commit 07117e4dd4

@ -28,7 +28,7 @@ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context
_reset_auto_parallel_context
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context']
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
GRAPH_MODE = 0
PYNATIVE_MODE = 1
@ -647,3 +647,26 @@ def get_context(attr_key):
raise ValueError(
"Get context keyword %s is not recognized!" % attr_key)
return getattr(_context(), attr_key)
class ParallelMode:
"""
Parallel mode options.
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
- STAND_ALONE: Only one processor working.
- DATA_PARALLEL: Distributing the data across different processors.
- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
- AUTO_PARALLEL: Achieving parallelism automatically.
MODE_LIST: The list for all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
DATA_PARALLEL = "data_parallel"
HYBRID_PARALLEL = "hybrid_parallel"
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]

@ -20,7 +20,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator
from mindspore.communication.management import get_group_size
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from ..cell import Cell
from ..._checkparam import Validator as validator, Rel
@ -129,9 +129,9 @@ class EmbeddingLookup(Cell):
embedding_size (int): The size of each embedding vector.
param_init (str): The initialize way of embedding table. Default: 'normal'.
target (str): Specify the target where the op is executed. The value should in
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode.
Inputs:

@ -29,7 +29,7 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

@ -15,7 +15,7 @@
"""Cell_wrapper."""
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
_get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C

@ -251,8 +251,9 @@ class DistributedGradReducer(Cell):
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>> from mindspore import ParameterTuple
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,

@ -15,7 +15,7 @@
"""Loss scale cell for loss scale training."""
import mindspore.context as context
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from ..cell import Cell
from ...common import Tensor, RowTensor

@ -18,8 +18,7 @@ High-Level training interfaces.
Helper functions in train piplines.
"""
from .model import Model
from .parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper
from . import amp
__all__ = ["Model", "ParallelMode", "DatasetHelper", "amp"]
__all__ = ["Model", "DatasetHelper", "amp"]

@ -23,7 +23,7 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from .parallel_utils import ParallelMode
from ..context import ParallelMode
from .. import context
__all__ = ["build_train_network"]

@ -30,7 +30,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ..context import ParallelMode
from ..parallel._utils import _need_to_full, _to_full_tensor
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper

@ -1,41 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel utils"""
__all__ = ["ParallelMode"]
class ParallelMode:
"""
Parallel mode options.
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
- STAND_ALONE: Only one processor working.
- DATA_PARALLEL: Distributing the data across different processors.
- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
- AUTO_PARALLEL: Achieving parallelism automatically.
MODE_LIST: The list for all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
DATA_PARALLEL = "data_parallel"
HYBRID_PARALLEL = "hybrid_parallel"
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]

@ -17,7 +17,8 @@ import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode
from mindspore import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset

@ -25,7 +25,8 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD
import mindspore.dataset.engine as de

@ -28,7 +28,8 @@ from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cifar_cfg as cfg

@ -21,7 +21,7 @@ import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

@ -24,7 +24,8 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD
import mindspore.dataset.engine as de

@ -30,7 +30,8 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net

@ -22,7 +22,8 @@ import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint

@ -28,7 +28,8 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net

@ -22,7 +22,8 @@ from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager

@ -21,7 +21,8 @@ from mindspore import context
from mindspore import Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint

@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,

@ -18,7 +18,7 @@ import math
from mindspore.train.callback import RunContext
from mindspore import context
from mindspore import nn
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.train.model import Model
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype

@ -22,7 +22,7 @@ from mindspore import context
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import ParallelMode
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_rank, get_group_size

@ -20,7 +20,7 @@ import datetime
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint

@ -19,6 +19,7 @@ import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter, context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
@ -388,7 +389,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save