Add interface to launch parallel dygraph by multiprocessing (#26044)
* add dygraph parallel run interface
* polish implementation & unify env property name
* add print config arg
* refactor init_parallel_env function
* Compatible with multiprocessing and launch modes
* set default trainer start port
* support run in python 2
* polish python2 support code
* remove python2 support
* refine launch import
* polish some design details
* refactor api implementation & path
* use new method _set_expected_place
* add spawn unittest framework & mnist test
* add more unittests & doc
* fix unittest failed
* polish english doc
* self review and polish details
* refactor code by reviewer's comments
* fix unittest failed
* fix parallel_env unittest
* fix several typos
* fix error introduced when fixing typos
* add non-public note for start_processes
* polish details by xiaoguang's comment
* verify correctly when spawn nprocs=-1
* refactor spawn & init_parallel_env design
* polish doc details
* open spawn unittests
* try to fix doc compile error
* try to fix unknown doc format error
* add skip unittest when not gpu
parent 3390c7e260
commit 31f422ae5e
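For reference, a minimal sketch of launching training with the new multiprocessing entry point (pieced together from the `init_parallel_env` docstring and the spawn call in the unit tests below; the exact keyword defaults are assumptions):

    import paddle.distributed as dist

    def train():
        dist.init_parallel_env()
        # ... build the model, wrap it with paddle.DataParallel, run steps ...

    if __name__ == '__main__':
        # launch `train` in 2 processes; per the commit notes, nprocs=-1 is
        # also accepted (presumably meaning "use all visible devices")
        dist.spawn(train, nprocs=2)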
@@ -0,0 +1,184 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import six
import warnings

from paddle import compat as cpt

# deprecated module import
from paddle.fluid import core
from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph.parallel import ParallelEnv

__all__ = ["init_parallel_env"]

ParallelStrategy = core.ParallelStrategy


def init_parallel_env(backend='nccl'):
    """
    Initialize the parallel training environment in dynamic graph mode.

    Args:
        backend(str, optional): The backend for communication between multiple devices.
            Now only ``nccl`` is supported. Default value is ``nccl`` .

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. enable dynamic mode
                paddle.disable_static()

                # 2. initialize parallel environment
                dist.init_parallel_env()

                # 3. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                loss = dp_layer.scale_loss(loss)
                loss.backward()
                dp_layer.apply_collective_grads()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                dist.spawn(train)
    """

    # 1. input check
    if not isinstance(backend, six.string_types):
        raise TypeError("input `backend` type error, expected type is str, "
                        "but received type is %s." % type(backend))
    if cpt.to_text(backend) != 'nccl':
        raise ValueError(
            "backend `%s` is not supported, now only supports `nccl` backend." %
            backend)

    # 2. check env
    def _check_var_exists(var_name):
        var = os.environ.get(var_name, None)
        if var is None:
            raise ValueError("paddle.distributed initialize error, "
                             "environment variable %s is needed, but not set." %
                             var_name)

    _check_var_exists("FLAGS_selected_gpus")
    _check_var_exists("PADDLE_TRAINER_ID")
    _check_var_exists("PADDLE_CURRENT_ENDPOINT")
    _check_var_exists("PADDLE_TRAINERS_NUM")
    _check_var_exists("PADDLE_TRAINER_ENDPOINTS")
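    # NOTE: for illustration, the single-process unit test added in this change
    # sets these variables to values like FLAGS_selected_gpus=0,
    # PADDLE_TRAINER_ID=0, PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170,
    # PADDLE_TRAINERS_NUM=1 and PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170;
    # normally they are exported by the launch/spawn helpers rather than by
    # the user.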

    # 3. init ParallelStrategy
    strategy = ParallelStrategy()
    if cpt.to_text(backend) == 'nccl':
        if parallel_helper._is_parallel_ctx_initialized():
            warnings.warn("The parallel environment has been initialized.")
        strategy.nranks = ParallelEnv().world_size
        strategy.local_rank = ParallelEnv().rank
        strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
        strategy.current_endpoint = ParallelEnv().current_endpoint
        if strategy.nranks < 2:
            return
        # NOTE(chenweihang): [ why config global place here? ]
        # dygraph mode will be the default mode, and users will not call
        # `dygraph.guard` or `enable_dygraph` directly; if they want to
        # switch the default place, they need to call a function to change
        # it. Here we just set the correct place for users.
        place = core.CUDAPlace(ParallelEnv().device_id)
        _set_expected_place(place)

        # init nccl context
        parallel_helper._set_parallel_ctx(
            core.NCCLParallelContext(strategy, place))
        parallel_helper._init_parallel_ctx()


def get_rank():
    """
    Returns the rank of the current trainer.

    Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
    The default value is 0.

    Returns:
        (int) The rank of the current trainer.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.distributed as dist

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            print("The rank is %d" % dist.get_rank())
            # The rank is 0
    """
    return ParallelEnv().rank


def get_world_size():
    """
    Returns the number of trainers (the number of processes participating in the current job).

    Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
    The default value is 1.

    Returns:
        (int) The number of trainers.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.distributed as dist

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            print("The world_size is %d" % dist.get_world_size())
            # The world_size is 4
    """
    return ParallelEnv().world_size
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, division

import numpy as np
import unittest

import paddle

# used by model.run_trainer in test_dist_base
from test_dist_base import RUN_STEP


# NOTE: compatible with the TestParallelDyGraphRunnerBase args
class SpawnAssistTestArgs(object):
    update_method = "local"
    trainer_id = 0


class TestDistSpawnRunner(unittest.TestCase):
    def setUp(self):
        # NOTE(chenweihang): keep consistent with
        # TestDistBase.check_with_place
        self.nprocs = 2

    def _run(self, model, args):
        args.update_method = "local"
        return model.run_trainer_with_spawn(args)

    def _run_parallel(self, model, args):
        args.update_method = "nccl2"
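        # NOTE: spawn(..., join=True) is expected to return a context object
        # whose return_queues hold the value returned by each subprocess
        # (here, the per-step losses from run_trainer_with_spawn).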
        context = paddle.distributed.spawn(
            func=model.run_trainer_with_spawn,
            args=(args, ),
            nprocs=self.nprocs,
            join=True)
        result_list = []
        for res_queue in context.return_queues:
            result_list.append(res_queue.get())
        return result_list

    def check_dist_result_with_spawn(self, test_class, delta=1e-3):
        # 0. prepare model and args
        model = test_class()
        args = SpawnAssistTestArgs()

        # 1. calc single-card loss
        losses = self._run(model, args)

        # 2. calc multi-card loss (nccl mode)
        dist_losses_list = self._run_parallel(model, args)

        # 3. compare losses
        for step_id in range(RUN_STEP):
            loss = losses[step_id]
            dist_loss_sum = None
            for dist_losses in dist_losses_list:
                if dist_loss_sum is None:
                    dist_loss_sum = np.array(dist_losses[step_id])
                else:
                    dist_loss_sum += np.array(dist_losses[step_id])
            dist_loss = dist_loss_sum / self.nprocs
            self.assertAlmostEqual(
                loss,
                dist_loss,
                delta=delta,
                msg="The results of single-card execution and multi-card execution are inconsistent. "
                "single-card loss is:\n{}\nmulti-card average loss is:\n{}\n".
                format(loss, dist_loss))
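

# A possible concrete use of this runner (a sketch only -- the class and model
# names below are hypothetical and may not match the MNIST spawn test actually
# added in this change):
#
#   class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
#       def test_mnist_with_spawn(self):
#           self.check_dist_result_with_spawn(
#               test_class=TestMnistModel, delta=1e-5)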
@@ -0,0 +1,87 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import numpy as np
import unittest

import paddle
import paddle.distributed as dist
from paddle.distributed.spawn import _get_subprocess_env_list

from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper

# NOTE(chenweihang): Coverage CI is currently not able to count python3
# unittests, so the unittests here cover some cases that will only be
# executed in the python3 sub-process.


class TestInitParallelEnv(unittest.TestCase):
    def test_backend_type_error(self):
        with self.assertRaises(TypeError):
            dist.init_parallel_env(backend=1)

    def test_backend_value_error(self):
        with self.assertRaises(ValueError):
            dist.init_parallel_env(backend="mpi")

    def test_check_env_failed(self):
        os.environ['FLAGS_selected_gpus'] = '0'
        os.environ['PADDLE_TRAINER_ID'] = '0'
        os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
        os.environ['PADDLE_TRAINERS_NUM'] = '1'
        with self.assertRaises(ValueError):
            dist.init_parallel_env()

    def test_init_parallel_env_break(self):
        os.environ['FLAGS_selected_gpus'] = '0'
        os.environ['PADDLE_TRAINER_ID'] = '0'
        os.environ['PADDLE_CURRENT_ENDPOINT'] = '127.0.0.1:6170'
        os.environ['PADDLE_TRAINERS_NUM'] = '1'
        os.environ['PADDLE_TRAINER_ENDPOINTS'] = '127.0.0.1:6170'
        # coverage success branch
        dist.init_parallel_env()
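        # NOTE: with PADDLE_TRAINERS_NUM=1 the strategy has nranks < 2, so
        # init_parallel_env() returns before creating the NCCL context;
        # the parallel context should therefore remain uninitialized.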
        self.assertFalse(parallel_helper._is_parallel_ctx_initialized())


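# NOTE: the 'cluster_node_ips' and 'selected_gpus' keys exercised below are
# assumed to mirror keyword options that paddle.distributed.spawn forwards to
# _get_subprocess_env_list; this is an inference from these tests, not a
# documented guarantee.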
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSpawnAssistMethod(unittest.TestCase):
    def test_only_cluster_node_ips_error(self):
        with self.assertRaises(ValueError):
            options = dict()
            options['cluster_node_ips'] = "127.0.0.1,127.0.0.2"
            _get_subprocess_env_list(nprocs=1, options=options)

    def test_nprocs_greater_than_device_num_error(self):
        with self.assertRaises(RuntimeError):
            _get_subprocess_env_list(nprocs=100, options=dict())

    def test_selected_gpus_error(self):
        with self.assertRaises(ValueError):
            options = dict()
            options['selected_gpus'] = "100,101"
            _get_subprocess_env_list(nprocs=2, options=options)

    def test_get_correct_env(self):
        env_dict = _get_subprocess_env_list(nprocs=1, options=dict())[0]
        self.assertEqual(env_dict['PADDLE_TRAINER_ID'], '0')
        self.assertEqual(env_dict['PADDLE_TRAINERS_NUM'], '1')


if __name__ == "__main__":
    unittest.main()