Add interface to launch parallel dygraph by multiprocessing (#26044)

* add dygraph parallel run interface

* polish implementation & unify env property name

* add print config arg

* refactor init_parallel_env function

* Compatible with multiprocessing and launch modes

* set default trainer start port

* support run in python 2

* polish python2 support code

* remove python2 support

* refine launch import

* polish demo design details

* refactor api implementation & path

* use new method _set_expected_place

* add spawn unittest framework & mnist test

* add more unittests & doc

* fix unittest failed

* polish english doc

* self review and polish details

* refactor code by reviewer's comments

* fix unittest failed

* fix parallel_env unittest

* fix several typos

* fix error introduced when fixing typos

* add note that start_processes is not public

* polish details by xiaoguang's comment

* verify correct behavior when spawn nprocs=-1

* refactor spawn & init_parallel_env design

* polish doc details

* open spawn unittests

* try to fix doc compile error

* try to fix unknown doc format error

* add skip unittest when not gpu
Chen Weihang 5 years ago committed by GitHub
parent 3390c7e260
commit 31f422ae5e
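For orientation, a minimal sketch of how the new multiprocessing entry point is meant to be used, adapted from the init_parallel_env docstring added below; the two-process count and the body of train() are illustrative, not part of this change:

# Minimal usage sketch (adapted from the docstring below); `train` is any
# function that sets up dygraph mode and the parallel environment itself.
import paddle
import paddle.distributed as dist

def train():
    paddle.disable_static()        # 1. enable dynamic (dygraph) mode
    dist.init_parallel_env()       # 2. initialize the NCCL parallel env
    # 3. build the model, wrap it with paddle.DataParallel, run training ...

if __name__ == '__main__':
    # launch workers via multiprocessing instead of
    # `python -m paddle.distributed.launch ...`; nprocs=2 is illustrative
    dist.spawn(train, nprocs=2)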

@@ -230,8 +230,6 @@ from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS
from .framework import prepare_context #DEFINE_ALIAS
from .framework import ParallelEnv #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS
from .framework import NoamDecay #DEFINE_ALIAS

@@ -12,4 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import spawn
from .spawn import spawn
from . import parallel
from .parallel import init_parallel_env
from .parallel import get_rank
from .parallel import get_world_size
from paddle.fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from paddle.fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
from . import collective
from .collective import *
# start multiprocess apis
__all__ = ["spawn"]
# dygraph parallel apis
__all__ += [
"init_parallel_env", "get_rank", "get_world_size", "prepare_context",
"ParallelEnv"
]
# collective apis
__all__ += collective.__all__

@@ -44,11 +44,9 @@ import time
import six
import copy
from argparse import ArgumentParser, REMAINDER
import paddle
import paddle.fluid as fluid
from paddle.distributed.utils import *
import paddle.distributed.cloud_utils as cloud_utils
from paddle.distributed import cloud_utils
def _print_arguments(args):
@@ -167,7 +165,8 @@ def get_cluster_from_args(args, selected_gpus):
def get_gpus(selected_gpus):
if selected_gpus is None:
gpus_num = fluid.core.get_cuda_device_count()
from paddle.fluid import core
gpus_num = core.get_cuda_device_count()
selected_gpus = [str(x) for x in range(0, gpus_num)]
else:
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
@@ -190,7 +189,7 @@ def get_gpus(selected_gpus):
return selected_gpus
def launch(args):
def get_cluster_and_pod(args):
# parse arguments, used for cloud-single-machine and local
selected_gpus = get_gpus(args.selected_gpus)
trainers_num = cloud_utils.get_trainers_num()
@@ -209,6 +208,12 @@ def launch(args):
cluster, pod = get_cluster_from_args(args, selected_gpus)
logger.info("get cluster from args:{}".format(cluster))
return cluster, pod
def launch(args):
cluster, pod = get_cluster_and_pod(args)
procs = start_local_trainers(
cluster,
pod,

@@ -0,0 +1,184 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import warnings
from paddle import compat as cpt
# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph.parallel import ParallelEnv
__all__ = ["init_parallel_env"]
ParallelStrategy = core.ParallelStrategy
def init_parallel_env(backend='nccl'):
"""
Initialize the parallel training environment in dynamic graph mode.
Args:
backend(str, optional): The backend for communication between multiple devices.
Now only ``nccl`` is supported. Default value is ``nccl`` .
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
dist.spawn(train)
"""
# 1. input check
if not isinstance(backend, six.string_types):
raise TypeError("input `backend` type error, expected type is str, "
"but received type is %s." % type(backend))
if cpt.to_text(backend) != 'nccl':
raise ValueError(
"backend `%s` is not supported, now only supports `nccl` backend." %
backend)
# 2. check env
def _check_var_exists(var_name):
var = os.environ.get(var_name, None)
if var is None:
raise ValueError("paddle.distributed initialize error, "
"environment variable %s is needed, but not set." %
var_name)
_check_var_exists("FLAGS_selected_gpus")
_check_var_exists("PADDLE_TRAINER_ID")
_check_var_exists("PADDLE_CURRENT_ENDPOINT")
_check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS")
# 3. init ParallelStrategy
strategy = ParallelStrategy()
if cpt.to_text(backend) == 'nccl':
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config global place here? ]
# dygraph mode will be the default mode, so users will not call
# `dygraph.guard` or `enable_dygraph` directly; if they want to switch
# the default place, they need to call a function to change it,
# so we just set the correct place for users here
place = core.CUDAPlace(ParallelEnv().device_id)
_set_expected_place(place)
# init nccl context
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
def get_rank():
"""
Returns the rank of the current trainer.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
The default value is 0.
Returns:
(int) The rank of the current trainer.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINER_ID=0
print("The rank is %d" % dist.get_rank())
# The rank is 0
"""
return ParallelEnv().rank
def get_world_size():
"""
Returns the number of trainers (the number of processes participating in the current job).
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
Returns:
(int) The number of trainers.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
print("The world_size is %d" % dist.get_world_size())
# The world_size is 4
"""
return ParallelEnv().world_size
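To make the environment checks in init_parallel_env concrete, here is a hedged example of what a worker process is expected to see before calling it; the values are illustrative for rank 0 of a hypothetical two-GPU, single-node job, and in practice they are exported per worker by paddle.distributed.launch or spawn rather than set by hand:

# Illustrative values only (rank 0 of a hypothetical 2-GPU job); these are
# normally populated by the launcher, see _prepare_trainer_env further down.
import os
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '2'
os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170,127.0.0.1:6171'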

File diff suppressed because it is too large

@@ -327,6 +327,17 @@ def find_free_ports(num):
return None
def _prepare_trainer_env(cluster, trainer):
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]),
"PADDLE_TRAINER_ID": "%d" % trainer.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
return proc_env
class TrainerProc(object):
def __init__(self):
self.proc = None
@@ -352,14 +363,7 @@ def start_local_trainers(cluster,
procs = []
for idx, t in enumerate(pod.trainers):
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
"PADDLE_TRAINER_ID": "%d" % t.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
proc_env = _prepare_trainer_env(cluster, t)
current_env.update(proc_env)
logger.debug("trainer proc env:{}".format(current_env))
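For reference, a sketch of the dictionary _prepare_trainer_env builds for a single trainer; the concrete values below are hypothetical, for the second trainer of a two-GPU, single-node pod, and the actual endpoints depend on the ports chosen by the launcher:

# Hypothetical return value of _prepare_trainer_env(cluster, trainer)
# for trainer rank 1 of a 2-GPU pod on one node.
proc_env = {
    "FLAGS_selected_gpus": "1",
    "PADDLE_TRAINER_ID": "1",
    "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:6171",
    "PADDLE_TRAINERS_NUM": "2",
    "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:6170,127.0.0.1:6171",
}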

File diff suppressed because it is too large

@@ -23,6 +23,11 @@ def _is_data_parallel_mode():
os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1
def _is_parallel_ctx_initialized():
global __parallel_ctx__clz__
return __parallel_ctx__clz__ is not None
def _set_parallel_ctx(nccl_parallel_context):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, \

@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function, division
import numpy as np
import unittest
import paddle
# used by model.run_trainer in test_dist_base
from test_dist_base import RUN_STEP
# NOTE: compatible with TestParallelDyGraphRunnerBase args
class SpawnAssistTestArgs(object):
update_method = "local"
trainer_id = 0
class TestDistSpawnRunner(unittest.TestCase):
def setUp(self):
# NOTE(chenweihang): keep consistent with
# TestDistBase.check_with_place
self.nprocs = 2
def _run(self, model, args):
args.update_method = "local"
return model.run_trainer_with_spawn(args)
def _run_parallel(self, model, args):
args.update_method = "nccl2"
context = paddle.distributed.spawn(
func=model.run_trainer_with_spawn,
args=(args, ),
nprocs=self.nprocs,
join=True)
result_list = []
for res_queue in context.return_queues:
result_list.append(res_queue.get())
return result_list
def check_dist_result_with_spawn(self, test_class, delta=1e-3):
# 0. prepare model and args
model = test_class()
args = SpawnAssistTestArgs()
# 1. calc single card loss
losses = self._run(model, args)
# 2. calc multi card loss (nccl mode)
dist_losses_list = self._run_parallel(model, args)
# 3. compare losses
for step_id in range(RUN_STEP):
loss = losses[step_id]
dist_loss_sum = None
for dist_losses in dist_losses_list:
if dist_loss_sum is None:
dist_loss_sum = np.array(dist_losses[step_id])
else:
dist_loss_sum += np.array(dist_losses[step_id])
dist_loss = dist_loss_sum / self.nprocs
self.assertAlmostEqual(
loss,
dist_loss,
delta=delta,
msg="The results of single-card execution and multi-card execution are inconsistent."
"signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n".
format(loss, dist_loss))

@@ -38,9 +38,10 @@ class TestDirectory(unittest.TestCase):
'paddle.enable_static', 'paddle.disable_static',
'paddle.in_dynamic_mode', 'paddle.to_variable', 'paddle.grad',
'paddle.no_grad', 'paddle.save', 'paddle.load',
'paddle.static.save', 'paddle.static.load', 'paddle.ParallelEnv',
'paddle.prepare_context', 'paddle.DataParallel', 'paddle.jit',
'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.static.save', 'paddle.static.load',
'paddle.distributed.ParallelEnv',
'paddle.distributed.prepare_context', 'paddle.DataParallel',
'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig',
'paddle.NoamDecay', 'paddle.PiecewiseDecay',

@@ -23,8 +23,11 @@ import subprocess
import six
import argparse
import pickle
import random
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.dygraph as dygraph
@@ -382,13 +385,7 @@ class TestParallelDyGraphRunnerBase(object):
raise NotImplementedError(
"train_one_loop should be implemented by the child classes.")
def run_trainer(self, args):
seed = 90
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
def _get_data(batch):
def _get_data(self, batch, args):
if args.update_method != "local":
new_batch = []
for offset, item in enumerate(batch):
@@ -398,6 +395,12 @@ class TestParallelDyGraphRunnerBase(object):
else:
return batch
def run_trainer(self, args):
seed = 90
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
@@ -422,7 +425,7 @@ class TestParallelDyGraphRunnerBase(object):
out_losses = []
print_to_err(type(self).__name__, "begin to run dygraph training")
for step_id, data in enumerate(train_reader()):
data = _get_data(data)
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
@@ -444,6 +447,47 @@ class TestParallelDyGraphRunnerBase(object):
model.clear_gradients()
print_to_out(out_losses)
def run_trainer_with_spawn(self, args):
# 1. enable dygraph
paddle.disable_static()
# 2. init seed
seed = 90
paddle.static.default_startup_program().random_seed = seed
paddle.static.default_main_program().random_seed = seed
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# 3. init parallel env
if args.update_method == "nccl2":
paddle.distributed.init_parallel_env()
# 4. train model
model, train_reader, opt = self.get_model()
if args.update_method == "nccl2":
model = paddle.DataParallel(model)
out_losses = []
for step_id, data in enumerate(train_reader()):
data = self._get_data(data, args)
if step_id == RUN_STEP:
break
loss = self.run_one_loop(model, opt, data)
out_losses.append(loss.numpy())
if args.update_method == "nccl2":
loss = model.scale_loss(loss)
loss.backward()
if args.update_method == "nccl2":
model.apply_collective_grads()
opt.minimize(loss)
model.clear_gradients()
return out_losses
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run dist test.')

@@ -43,7 +43,7 @@ class MLP(fluid.Layer):
class TestDataParallelStateDict(unittest.TestCase):
def test_data_parallel_state_dict(self):
with fluid.dygraph.guard():
strategy = paddle.prepare_context()
strategy = paddle.distributed.prepare_context()
mlp = MLP()
parallel_mlp = dygraph.parallel.DataParallel(mlp, strategy)

@@ -13,11 +13,16 @@
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_mnist import TestMnist
import os
flag_name = os.path.splitext(__file__)[0]
@@ -36,5 +41,11 @@ class TestParallelDygraphMnist(TestDistBase):
log_name=flag_name)
class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
def test_mnist_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(test_class=TestMnist, delta=1e-5)
if __name__ == "__main__":
unittest.main()

@@ -13,11 +13,16 @@
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_se_resnext import TestSeResNeXt
import os
flag_name = os.path.splitext(__file__)[0]
@@ -36,5 +41,12 @@ class TestParallelDygraphSeResNeXt(TestDistBase):
log_name=flag_name)
class TestParallelDygraphSeResNeXtSpawn(TestDistSpawnRunner):
def test_se_resnext_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSeResNeXt, delta=0.01)
if __name__ == "__main__":
unittest.main()

@@ -15,10 +15,13 @@
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_sparse_embedding import TestSparseEmbedding
flag_name = os.path.splitext(__file__)[0]
@@ -38,5 +41,12 @@ class TestParallelDygraphSparseEmdedding(TestDistBase):
log_name=flag_name)
class TestParallelDygraphSparseEmdeddingSpawn(TestDistSpawnRunner):
def test_sparse_embedding_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestSparseEmbedding, delta=1e-5)
if __name__ == "__main__":
unittest.main()

@@ -15,10 +15,13 @@
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
from parallel_dygraph_transformer import TestTransformer
flag_name = os.path.splitext(__file__)[0]
@@ -38,5 +41,12 @@ class TestParallelDygraphTransformer(TestDistBase):
log_name=flag_name)
class TestParallelDygraphTransformerSpawn(TestDistSpawnRunner):
def test_transformer_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
self.check_dist_result_with_spawn(
test_class=TestTransformer, delta=1e-5)
if __name__ == "__main__":
unittest.main()

@@ -0,0 +1,87 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import unittest
import paddle
import paddle.distributed as dist
from paddle.distributed.spawn import _get_subprocess_env_list
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper
# NOTE(chenweihang): Coverage CI is currently not able to count python3
# unittest, so the unittests here cover some cases that will only be
# executed in the python3 sub-process.
class TestInitParallelEnv(unittest.TestCase):
def test_backend_type_error(self):
with self.assertRaises(TypeError):
dist.init_parallel_env(backend=1)
def test_backend_value_error(self):
with self.assertRaises(ValueError):
dist.init_parallel_env(backend="mpi")
def test_check_env_failed(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
with self.assertRaises(ValueError):
dist.init_parallel_env()
def test_init_parallel_env_break(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
os.environ['PADDLE_TRAINERS_NUM'] = '1'
os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170'
# coverage success branch
dist.init_parallel_env()
self.assertFalse(parallel_helper._is_parallel_ctx_initialized())
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSpawnAssistMethod(unittest.TestCase):
def test_only_cluster_node_ips_error(self):
with self.assertRaises(ValueError):
options = dict()
options['cluster_node_ips'] = "127.0.0.1,127.0.0.2"
_get_subprocess_env_list(nprocs=1, options=options)
def test_nprocs_greater_than_device_num_error(self):
with self.assertRaises(RuntimeError):
_get_subprocess_env_list(nprocs=100, options=dict())
def test_selected_gpus_error(self):
with self.assertRaises(ValueError):
options = dict()
options['selected_gpus'] = "100,101"
_get_subprocess_env_list(nprocs=2, options=options)
def test_get_correct_env(self):
env_dict = _get_subprocess_env_list(nprocs=1, options=dict())[0]
self.assertEqual(env_dict['PADDLE_TRAINER_ID'], '0')
self.assertEqual(env_dict['PADDLE_TRAINERS_NUM'], '1')
if __name__ == "__main__":
unittest.main()

@@ -50,8 +50,6 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS
from ..fluid.dygraph.base import grad #DEFINE_ALIAS
from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS
from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS
from ..fluid.dygraph.parallel import prepare_context #DEFINE_ALIAS
from ..fluid.dygraph.parallel import ParallelEnv #DEFINE_ALIAS
from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS
from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS
