Add interface to launch parallel dygraph by multiprocessing (#26044)

* add dygraph parallel run interface

* polish implementation & unify env property names

* add print config arg

* refactor init_parallel_env function

* Compatible with multiprocessing and launch modes

* set default trainer start port

* support run in python 2

* polish python2 support code

* remove python2 support

* refine launch import

* polish demo design details

* refactor api implementation & path

* use new method _set_expected_place

* add spawn unittest framework & mnist test

* add more unittests & doc

* fix unittest failed

* polish english doc

* self review and polish details

* refactor code by reviewer's comments

* fix unittest failed

* fix parallel_env unittest

* fix several typos

* fix error introduced when fixing typos

* add non-public note for start_processes

* polish details by xiaoguang's comment

* verify correctly when spawn nprocs=-1

* refactor spawn & init_parallel_env design

* polish doc details

* open spawn unittests

* try to fix doc compile error

* try to fix unknown doc format error

* add skip unittest when not gpu
Author: Chen Weihang, committed via GitHub
parent 3390c7e260
commit 31f422ae5e
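In brief, this commit lets users launch data-parallel dygraph training with multiprocessing instead of the `launch` module. The following is a minimal usage sketch adapted from the `init_parallel_env` docstring added below; the linear model, hyperparameters, and `nprocs=2` are illustrative placeholders, not values required by the API.

# Minimal sketch of the spawn-based launch added in this commit.
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear1 = nn.Linear(10, 10)
        self._linear2 = nn.Linear(10, 1)

    def forward(self, x):
        return self._linear2(self._linear1(x))

def train():
    # enable dynamic mode and initialize the parallel environment
    paddle.disable_static()
    dist.init_parallel_env()

    # wrap the layer for data parallelism
    layer = LinearNet()
    dp_layer = paddle.DataParallel(layer)
    loss_fn = nn.MSELoss()
    adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())

    # one illustrative training step
    inputs = paddle.randn([10, 10], 'float32')
    outputs = dp_layer(inputs)
    labels = paddle.randn([10, 1], 'float32')
    loss = loss_fn(outputs, labels)
    loss = dp_layer.scale_loss(loss)
    loss.backward()
    dp_layer.apply_collective_grads()
    adam.step()
    adam.clear_grad()

if __name__ == '__main__':
    # launch 2 trainer processes by multiprocessing
    dist.spawn(train, nprocs=2)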

@@ -230,8 +230,6 @@ from .framework import grad #DEFINE_ALIAS
 from .framework import no_grad #DEFINE_ALIAS
 from .framework import save #DEFINE_ALIAS
 from .framework import load #DEFINE_ALIAS
-from .framework import prepare_context #DEFINE_ALIAS
-from .framework import ParallelEnv #DEFINE_ALIAS
 from .framework import DataParallel #DEFINE_ALIAS
 from .framework import NoamDecay #DEFINE_ALIAS

@@ -12,4 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from . import spawn
+from .spawn import spawn
+
+from . import parallel
+from .parallel import init_parallel_env
+from .parallel import get_rank
+from .parallel import get_world_size
+
+from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
+from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
+
+from . import collective
 from .collective import *
+
+# start multiprocess apis
+__all__ = ["spawn"]
+
+# dygraph parallel apis
+__all__ += [
+    "init_parallel_env", "get_rank", "get_world_size", "prepare_context",
+    "ParallelEnv"
+]
+
+# collective apis
+__all__ += collective.__all__

@@ -44,11 +44,9 @@ import time
 import six
 import copy
 from argparse import ArgumentParser, REMAINDER
-import paddle
-import paddle.fluid as fluid
 from paddle.distributed.utils import *
-import paddle.distributed.cloud_utils as cloud_utils
+from paddle.distributed import cloud_utils
 
 
 def _print_arguments(args):
@@ -167,7 +165,8 @@ def get_cluster_from_args(args, selected_gpus):
 def get_gpus(selected_gpus):
     if selected_gpus is None:
-        gpus_num = fluid.core.get_cuda_device_count()
+        from paddle.fluid import core
+        gpus_num = core.get_cuda_device_count()
         selected_gpus = [str(x) for x in range(0, gpus_num)]
     else:
         cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
@@ -190,7 +189,7 @@ def get_gpus(selected_gpus):
     return selected_gpus
 
 
-def launch(args):
+def get_cluster_and_pod(args):
     # parse arguments, used for cloud-single-machine and local
     selected_gpus = get_gpus(args.selected_gpus)
     trainers_num = cloud_utils.get_trainers_num()
@@ -209,6 +208,12 @@ def launch(args):
     cluster, pod = get_cluster_from_args(args, selected_gpus)
     logger.info("get cluster from args:{}".format(cluster))
 
+    return cluster, pod
+
+
+def launch(args):
+    cluster, pod = get_cluster_and_pod(args)
+
     procs = start_local_trainers(
         cluster,
         pod,

@@ -0,0 +1,184 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import warnings
from paddle import compat as cpt
# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph.parallel import ParallelEnv
__all__ = ["init_parallel_env"]
ParallelStrategy = core.ParallelStrategy
def init_parallel_env(backend='nccl'):
"""
Initialize the parallel training environment in dynamic mode.
Args:
backend(str, optional): The backend to communication between multiple devices.
Now only support ``nccl`` . Default value is ``nccl`` .
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
dist.spawn(train)
"""
# 1. input check
if not isinstance(backend, six.string_types):
raise TypeError("input `backend` type error, expected type is str, "
"but received type is %s." % type(backend))
if cpt.to_text(backend) != 'nccl':
raise ValueError(
"backend `%s` is not supported, now only supports `nccl` backend." %
backend)
# 2. check env
def _check_var_exists(var_name):
var = os.environ.get(var_name, None)
if var is None:
raise ValueError("paddle.distributed initialize error, "
"environment variable %s is needed, but not set." %
var_name)
_check_var_exists("FLAGS_selected_gpus")
_check_var_exists("PADDLE_TRAINER_ID")
_check_var_exists("PADDLE_CURRENT_ENDPOINT")
_check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS")
# 3. init ParallelStrategy
strategy = ParallelStrategy()
if cpt.to_text(backend) == 'nccl':
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config the global place here? ]
# dygraph mode will be the default mode, so users will not call
# `dygraph.guard` or `enable_dygraph` directly; if they want to switch
# the default place, they need to call a function to change it,
# so here we just set the correct place for users
place = core.CUDAPlace(ParallelEnv().device_id)
_set_expected_place(place)
# init nccl context
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
def get_rank():
"""
Returns the rank of the current trainer.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
The default value is 0.
Returns:
(int) The rank of the current trainer.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINER_ID=0
print("The rank is %d" % dist.get_rank())
# The rank is 0
"""
return ParallelEnv().rank
def get_world_size():
"""
Returns the number of trainers (the number of processes participating in the current job).
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
Returns:
(int) The number of trainers.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
print("The world_size is %d" % dist.get_world_size())
# The world_size is 4
"""
return ParallelEnv().world_size

File diff suppressed because it is too large.

@@ -327,6 +327,17 @@ def find_free_ports(num):
     return None
 
 
+def _prepare_trainer_env(cluster, trainer):
+    proc_env = {
+        "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]),
+        "PADDLE_TRAINER_ID": "%d" % trainer.rank,
+        "PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint,
+        "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
+        "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
+    }
+    return proc_env
+
+
 class TrainerProc(object):
     def __init__(self):
         self.proc = None
@@ -352,14 +363,7 @@ def start_local_trainers(cluster,
     procs = []
     for idx, t in enumerate(pod.trainers):
-        proc_env = {
-            "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
-            "PADDLE_TRAINER_ID": "%d" % t.rank,
-            "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
-            "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
-            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
-        }
+        proc_env = _prepare_trainer_env(cluster, t)
 
         current_env.update(proc_env)
         logger.debug("trainer proc env:{}".format(current_env))
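Note that the environment dict built by `_prepare_trainer_env` supplies exactly the variables that `init_parallel_env` validates through `_check_var_exists`. A small sketch of that relationship follows; the `check_trainer_env` helper is hypothetical and not part of this commit.

# Hypothetical helper: verifies that a proc_env dict (as returned by
# _prepare_trainer_env) carries every variable that init_parallel_env
# checks via _check_var_exists.
REQUIRED_ENV_VARS = [
    "FLAGS_selected_gpus",
    "PADDLE_TRAINER_ID",
    "PADDLE_CURRENT_ENDPOINT",
    "PADDLE_TRAINERS_NUM",
    "PADDLE_TRAINER_ENDPOINTS",
]

def check_trainer_env(proc_env):
    # returns the list of missing variables (empty when the env is complete)
    return [name for name in REQUIRED_ENV_VARS if name not in proc_env]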

File diff suppressed because it is too large.

@@ -23,6 +23,11 @@ def _is_data_parallel_mode():
         os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1
 
 
+def _is_parallel_ctx_initialized():
+    global __parallel_ctx__clz__
+    return __parallel_ctx__clz__ is not None
+
+
 def _set_parallel_ctx(nccl_parallel_context):
     global __parallel_ctx__clz__
     assert __parallel_ctx__clz__ is None, \

@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import numpy as np
import unittest
import paddle
# used by model.run_trainer in test_dist_base
from test_dist_base import RUN_STEP
# NOTE: compatible with TestParallelDyGraphRunnerBase args
class SpawnAssistTestArgs(object):
update_method = "local"
trainer_id = 0
class TestDistSpawnRunner(unittest.TestCase):
def setUp(self):
# NOTE(chenweihang): keep consistent with
# TestDistBase.check_with_place
self.nprocs = 2
def _run(self, model, args):
args.update_method = "local"
return model.run_trainer_with_spawn(args)
def _run_parallel(self, model, args):
args.update_method = "nccl2"
context = paddle.distributed.spawn(
func=model.run_trainer_with_spawn,
args=(args, ),
nprocs=self.nprocs,
join=True)
result_list = []
for res_queue in context.return_queues:
result_list.append(res_queue.get())
return result_list
def check_dist_result_with_spawn(self, test_class, delta=1e-3):
# 0. prepare model and args
model = test_class()
args = SpawnAssistTestArgs()
# 1. calc single card loss
losses = self._run(model, args)
# 2. calc multi card loss (nccl mode)
dist_losses_list = self._run_parallel(model, args)
# 3. compare losses
for step_id in range(RUN_STEP):
loss = losses[step_id]
dist_loss_sum = None
for dist_losses in dist_losses_list:
if dist_loss_sum is None:
dist_loss_sum = np.array(dist_losses[step_id])
else:
dist_loss_sum += np.array(dist_losses[step_id])
dist_loss = dist_loss_sum / self.nprocs
self.assertAlmostEqual(
loss,
dist_loss,
delta=delta,
msg="The results of single-card execution and multi-card execution are inconsistent."
"signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n".
format(loss, dist_loss))

@@ -38,9 +38,10 @@ class TestDirectory(unittest.TestCase):
             'paddle.enable_static', 'paddle.disable_static',
             'paddle.in_dynamic_mode', 'paddle.to_variable', 'paddle.grad',
             'paddle.no_grad', 'paddle.save', 'paddle.load',
-            'paddle.static.save', 'paddle.static.load', 'paddle.ParallelEnv',
-            'paddle.prepare_context', 'paddle.DataParallel', 'paddle.jit',
-            'paddle.jit.TracedLayer', 'paddle.jit.to_static',
+            'paddle.static.save', 'paddle.static.load',
+            'paddle.distributed.ParallelEnv',
+            'paddle.distributed.prepare_context', 'paddle.DataParallel',
+            'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
             'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
             'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig',
             'paddle.NoamDecay', 'paddle.PiecewiseDecay',

@@ -23,8 +23,11 @@ import subprocess
 import six
 import argparse
 import pickle
+import random
 import numpy as np
 import time
+
+import paddle
 import paddle.fluid as fluid
 from paddle.fluid import compiler
 import paddle.fluid.dygraph as dygraph
@@ -382,22 +385,22 @@ class TestParallelDyGraphRunnerBase(object):
         raise NotImplementedError(
             "train_one_loop should be implemented by the child classes.")
 
+    def _get_data(self, batch, args):
+        if args.update_method != "local":
+            new_batch = []
+            for offset, item in enumerate(batch):
+                if offset % 2 == args.trainer_id:
+                    new_batch.append(item)
+            return new_batch
+        else:
+            return batch
+
     def run_trainer(self, args):
         seed = 90
         device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
         place = fluid.CUDAPlace(device_id)
 
-        def _get_data(batch):
-            if args.update_method != "local":
-                new_batch = []
-                for offset, item in enumerate(batch):
-                    if offset % 2 == args.trainer_id:
-                        new_batch.append(item)
-                return new_batch
-            else:
-                return batch
-
         with fluid.dygraph.guard(place):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
@@ -422,7 +425,7 @@ class TestParallelDyGraphRunnerBase(object):
             out_losses = []
             print_to_err(type(self).__name__, "begin to run dygraph training")
             for step_id, data in enumerate(train_reader()):
-                data = _get_data(data)
+                data = self._get_data(data, args)
                 if step_id == RUN_STEP:
                     break
                 loss = self.run_one_loop(model, opt, data)
@@ -444,6 +447,47 @@ class TestParallelDyGraphRunnerBase(object):
                 model.clear_gradients()
         print_to_out(out_losses)
 
+    def run_trainer_with_spawn(self, args):
+        # 1. enable dygraph
+        paddle.disable_static()
+
+        # 2. init seed
+        seed = 90
+        paddle.static.default_startup_program().random_seed = seed
+        paddle.static.default_main_program().random_seed = seed
+        np.random.seed(seed)
+        random.seed(seed)
+        # get trainer id
+        args.trainer_id = paddle.distributed.get_rank()
+
+        # 3. init parallel env
+        if args.update_method == "nccl2":
+            paddle.distributed.init_parallel_env()
+
+        # 4. train model
+        model, train_reader, opt = self.get_model()
+        if args.update_method == "nccl2":
+            model = paddle.DataParallel(model)
+
+        out_losses = []
+        for step_id, data in enumerate(train_reader()):
+            data = self._get_data(data, args)
+            if step_id == RUN_STEP:
+                break
+            loss = self.run_one_loop(model, opt, data)
+            out_losses.append(loss.numpy())
+
+            if args.update_method == "nccl2":
+                loss = model.scale_loss(loss)
+
+            loss.backward()
+            if args.update_method == "nccl2":
+                model.apply_collective_grads()
+
+            opt.minimize(loss)
+            model.clear_gradients()
+        return out_losses
+
+
 def runtime_main(test_class):
     parser = argparse.ArgumentParser(description='Run dist test.')

@@ -43,7 +43,7 @@ class MLP(fluid.Layer):
 class TestDataParallelStateDict(unittest.TestCase):
     def test_data_parallel_state_dict(self):
         with fluid.dygraph.guard():
-            strategy = paddle.prepare_context()
+            strategy = paddle.distributed.prepare_context()
             mlp = MLP()
             parallel_mlp = dygraph.parallel.DataParallel(mlp, strategy)

@@ -13,11 +13,16 @@
 # limitations under the License.
 
 from __future__ import print_function
+
+import os
+import sys
 import unittest
-from test_dist_base import TestDistBase
 import paddle.fluid as fluid
-import os
+
+from test_dist_base import TestDistBase
+from spawn_runner_base import TestDistSpawnRunner
+from parallel_dygraph_mnist import TestMnist
 
 flag_name = os.path.splitext(__file__)[0]
@@ -36,5 +41,11 @@ class TestParallelDygraphMnist(TestDistBase):
             log_name=flag_name)
 
 
+class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
+    def test_mnist_with_spawn(self):
+        if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
+            self.check_dist_result_with_spawn(test_class=TestMnist, delta=1e-5)
+
+
 if __name__ == "__main__":
     unittest.main()

@@ -13,11 +13,16 @@
 # limitations under the License.
 
 from __future__ import print_function
+
+import os
+import sys
 import unittest
-from test_dist_base import TestDistBase
 import paddle.fluid as fluid
-import os
+
+from test_dist_base import TestDistBase
+from spawn_runner_base import TestDistSpawnRunner
+from parallel_dygraph_se_resnext import TestSeResNeXt
 
 flag_name = os.path.splitext(__file__)[0]
@@ -36,5 +41,12 @@ class TestParallelDygraphSeResNeXt(TestDistBase):
             log_name=flag_name)
 
 
+class TestParallelDygraphSeResNeXtSpawn(TestDistSpawnRunner):
+    def test_se_resnext_with_spawn(self):
+        if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
+            self.check_dist_result_with_spawn(
+                test_class=TestSeResNeXt, delta=0.01)
+
+
 if __name__ == "__main__":
     unittest.main()

@@ -15,10 +15,13 @@
 from __future__ import print_function
 
 import os
+import sys
 import unittest
-import paddle.fluid as fluid
 
+import paddle.fluid as fluid
 from test_dist_base import TestDistBase
+from spawn_runner_base import TestDistSpawnRunner
+from parallel_dygraph_sparse_embedding import TestSparseEmbedding
 
 flag_name = os.path.splitext(__file__)[0]
@@ -38,5 +41,12 @@ class TestParallelDygraphSparseEmdedding(TestDistBase):
             log_name=flag_name)
 
 
+class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner):
+    def test_sparse_embedding_with_spawn(self):
+        if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
+            self.check_dist_result_with_spawn(
+                test_class=TestSparseEmbedding, delta=1e-5)
+
+
 if __name__ == "__main__":
     unittest.main()

@@ -15,10 +15,13 @@
 from __future__ import print_function
 
 import os
+import sys
 import unittest
-import paddle.fluid as fluid
 
+import paddle.fluid as fluid
 from test_dist_base import TestDistBase
+from spawn_runner_base import TestDistSpawnRunner
+from parallel_dygraph_transformer import TestTransformer
 
 flag_name = os.path.splitext(__file__)[0]
@@ -38,5 +41,12 @@ class TestParallelDygraphTransformer(TestDistBase):
             log_name=flag_name)
 
 
+class TestParallelDygraphTransformerSpawn(TestDistSpawnRunner):
+    def test_transformer_with_spawn(self):
+        if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
+            self.check_dist_result_with_spawn(
+                test_class=TestTransformer, delta=1e-5)
+
+
 if __name__ == "__main__":
     unittest.main()

@@ -0,0 +1,87 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import unittest
import paddle
import paddle.distributed as dist
from paddle.distributed.spawn import _get_subprocess_env_list
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper
# NOTE(chenweihang): Coverage CI is currently not able to count python3
# unittests, so the unittests here cover some cases that will only be
# executed in the python3 sub-process.
class TestInitParallelEnv(unittest.TestCase):
def test_backend_type_error(self):
with self.assertRaises(TypeError):
dist.init_parallel_env(backend=1)
def test_backend_value_error(self):
with self.assertRaises(ValueError):
dist.init_parallel_env(backend="mpi")
def test_check_env_failed(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
with self.assertRaises(ValueError):
dist.init_parallel_env()
def test_init_parallel_env_break(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170'
# coverage success branch
dist.init_parallel_env()
self.assertFalse(parallel_helper._is_parallel_ctx_initialized())
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSpawnAssistMethod(unittest.TestCase):
def test_only_cluster_node_ips_error(self):
with self.assertRaises(ValueError):
options = dict()
options['cluster_node_ips'] = "127.0.0.1,127.0.0.2"
_get_subprocess_env_list(nprocs=1, options=options)
def test_nprocs_greater_than_device_num_error(self):
with self.assertRaises(RuntimeError):
_get_subprocess_env_list(nprocs=100, options=dict())
def test_selected_gpus_error(self):
with self.assertRaises(ValueError):
options = dict()
options['selected_gpus'] = "100,101"
_get_subprocess_env_list(nprocs=2, options=options)
def test_get_correct_env(self):
env_dict = _get_subprocess_env_list(nprocs=1, options=dict())[0]
self.assertEqual(env_dict['PADDLE_TRAINER_ID'], '0')
self.assertEqual(env_dict['PADDLE_TRAINERS_NUM'], '1')
if __name__ == "__main__":
unittest.main()

@@ -50,8 +50,6 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS
 from ..fluid.dygraph.base import grad #DEFINE_ALIAS
 from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS
 from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS
-from ..fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
-from ..fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
 from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS
 from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS
