# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import paddle.fluid as fluid
import paddle.fluid.io as io
import paddle.fluid.transpiler.distribute_transpiler as dist_transpiler
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.framework import Program

from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer

from paddle.fluid import compiler
from paddle.distributed.fs_wrapper import LocalFS, BDFS

import os
import sys
import six
import json
import re
import shutil


class LambConfig(object):
    def __init__(self):
        pass


class DistFCConfig(object):
    def __init__(self):
        pass


class TrainStatus(object):
    def __init__(self, epoch_no=-1):
        # completed epoch
        self._epoch_no = epoch_no

    def next(self):
        return self._epoch_no + 1

    def __eq__(self, t):
        return self._epoch_no == t._epoch_no

    def __ne__(self, t):
        return not self == t


class Collective(Fleet):
    def __init__(self):
        super(Collective, self).__init__(Mode.COLLECTIVE)
        self._local_ip = 0

        self.startup_program = None
        self._origin_program = None
        self._transpiled_program = None
        self.main_program = None
        self._checkoint_prefix = "__paddle_fleet_checkpoint__"
        self._param_file_name = "_paddle_fleet_param__"

    def init_worker(self):
        logging.warn(
            "You should not call 'init_worker' method for collective mode.")

    def run_worker(self, main_programs=None, scopes=None):
        logging.warn(
            "You should not call 'run_worker' method for collective mode.")

    def init_server(self, model_dir=None):
        logging.warn(
            "You should not call 'init_server' method for collective mode.")

    def run_server(self):
        logging.warn(
            "You should not call 'run_server' method for collective mode.")

    def stop_worker(self):
        logging.warn(
            "You should not call 'stop_worker' method for collective mode.")

    def distributed_optimizer(self, optimizer, strategy=None):
        self._optimizer = \
            CollectiveOptimizer(optimizer, strategy)
        return self._optimizer

    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names=None,
                             target_vars=None,
                             main_program=None,
                             export_for_deployment=True):
        """
        Prune the given `main_program` to build a new program especially for
        inference, and then save it and all related parameters to the given
        `dirname` with the `executor`.
        """
        assert isinstance(executor, Executor), \
            "In fleet.save_inference_model() function, executor must be of" \
            " Executor type."

        if main_program is None:
            main_program = self._origin_program
        assert isinstance(main_program, Program), \
            "In fleet.save_inference_model() function, main_program " \
            "must be of Program type."

        io.save_inference_model(dirname, feeded_var_names, target_vars,
                                executor, main_program, None, None,
                                export_for_deployment)

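    # Illustrative usage sketch (not executed here): after training, a worker
    # can save an inference program by naming the feed and fetch variables.
    # The names "image" and `pred` below are hypothetical placeholders.
    #
    #     exe = fluid.Executor(fluid.CUDAPlace(0))
    #     fleet.save_inference_model(
    #         executor=exe,
    #         dirname="./inference_model",
    #         feeded_var_names=["image"],
    #         target_vars=[pred])
    #
    # `main_program` defaults to the original (un-transpiled) program captured
    # in minimize(), which is usually what you want for inference.
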
    def save_persistables(self,
                          executor,
                          dirname,
                          main_program=None,
                          filename=None):
        """
        This function filters out all variables with `persistable==True` from
        the given `main_program` and then saves these variables to the folder
        `dirname` or the file `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` to None; if you would like to save all variables
        in a single file, use `filename` to specify the file name.
        """
        assert isinstance(executor, Executor), \
            "In fleet.save_persistables() function, executor must be of" \
            " Executor type."

        if main_program is None:
            main_program = self._origin_program

        assert isinstance(main_program, Program), \
            "In fleet.save_persistables() function, main_program " \
            "must be of Program type."

        io.save_persistables(executor, dirname, main_program, filename=filename)

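    # Illustrative sketch (not executed here) of the two layouts this method
    # can produce, assuming a hypothetical "./checkpoint" directory:
    #
    #     # one file per persistable variable under ./checkpoint/
    #     fleet.save_persistables(exe, "./checkpoint")
    #
    #     # all persistable variables packed into the single file ./checkpoint/params
    #     fleet.save_persistables(exe, "./checkpoint", filename="params")
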
    def _save_train_status(self, path, train_status):
        d = {}
        d["epoch_no"] = train_status._epoch_no

        file_name = "{}/fleet_train_status".format(path)
        with open(file_name, 'w') as f:
            json.dump(d, f)

    def _load_train_status(self, path):
        file_name = "{}/fleet_train_status".format(path)

        r = TrainStatus()
        if not os.path.isfile(file_name):
            return r

        d = {}
        with open(file_name, 'r') as f:
            d = json.load(f)

        assert "epoch_no" in d, "Can't find epoch_no in dict from train_status file:{}".format(
            d)
        r._epoch_no = d["epoch_no"]
        assert r._epoch_no >= 0, "Data in checkpoint file is not valid:{}".format(
            d)

        return r

    def _get_last_checkpoint_no(self, root_path, fs):
        """
        Only scan the first directory depth under `root_path` and return the
        largest checkpoint number found there, or -1 if there is none.
        """
        max_no = -1
        dirs = fs.list_dirs(root_path)
        for dir in dirs:
            g = dir.split(".")
            if len(g) != 2:
                continue

            if g[0] != self._checkoint_prefix:
                continue

            try:
                n = int(g[1])
                if n > max_no:
                    max_no = n
            except ValueError:
                continue

        return max_no

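    # On-disk layout produced by save_check_point(), derived from the format
    # strings used below, e.g. for checkpoints 0 and 1 under `root_path`:
    #
    #     root_path/
    #         __paddle_fleet_checkpoint__.0/
    #             _paddle_fleet_param__      # all persistables in one file
    #             fleet_train_status         # JSON such as {"epoch_no": 0}
    #         __paddle_fleet_checkpoint__.1/
    #             ...
    #
    # _get_last_checkpoint_no() picks the largest numeric suffix, and
    # clean_redundant_check_points() removes every checkpoint whose number n
    # satisfies n <= max_no - checkpoint_num.
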
    def clean_redundant_check_points(self,
                                     root_path,
                                     fs=LocalFS(),
                                     checkpoint_num=1):
        max_no = self._get_last_checkpoint_no(root_path, fs)
        if max_no < 0:
            return

        if checkpoint_num < 1:
            checkpoint_num = 1

        dirs = fs.list_dirs(root_path)
        for dir in dirs:
            g = dir.split(".")
            if len(g) != 2:
                continue

            if g[0] != self._checkoint_prefix:
                continue

            try:
                n = int(g[1])
                if n <= max_no - checkpoint_num:
                    path = "{}/{}.{}".format(root_path,
                                             self._checkoint_prefix, n)
                    fs.rmr(path)
            except Exception as e:
                print(e)
                continue

    def save_check_point(self,
                         executor,
                         path,
                         train_status,
                         main_program=None,
                         fs=LocalFS(),
                         local_cache_path=".cache",
                         remain_all_checkpoint=True):
        """
        This function saves the persistable variables and the current epoch
        number to `path`.
        """

        if main_program is None:
            main_program = self._transpiled_program

        if not fs.stat(path):
            fs.mkdir(path)

        max_no = self._get_last_checkpoint_no(path, fs=fs)
        if max_no < 0:
            max_no = -1

        real_path = "{}/{}.{}".format(path, self._checkoint_prefix, max_no + 1)
        tmp_path = "{}.tmp".format(real_path)
        saved_path = tmp_path

        local_fs = LocalFS()

        cache_path = None
        if fs.need_upload_download():
            cache_path = "{}/{}.{}.saved_cache".format(
                local_cache_path, self._checkoint_prefix, max_no + 1)
            if not local_fs.stat(cache_path):
                local_fs.mkdir(cache_path)
            saved_path = cache_path

        self.save_persistables(
            executor=executor,
            dirname=saved_path,
            main_program=main_program,
            filename=self._param_file_name)
        self._save_train_status(path=saved_path, train_status=train_status)

        if fs.need_upload_download():
            fs.delete(tmp_path)
            fs.upload(cache_path, tmp_path)
        fs.mv(tmp_path, real_path)

        if not remain_all_checkpoint:
            self.clean_redundant_check_points(path)

    def load_check_point(self,
                         executor,
                         path,
                         trainer_id,
                         main_program=None,
                         fs=LocalFS(),
                         local_cache_path=".cache",
                         ignore_empty=True):
        """
        This function loads the persistable variables and the current epoch
        number from `path`.
        """
        max_no = self._get_last_checkpoint_no(path, fs)

        if not ignore_empty:
            assert max_no >= 0, "Can't find checkpoint"

        if max_no < 0:
            return None

        local_fs = LocalFS()
        if fs.need_upload_download():
            cache_path = "{}/{}.{}.load_cache.{}".format(
                local_cache_path, self._checkoint_prefix, max_no, trainer_id)
            if local_fs.stat(cache_path):
                local_fs.delete(cache_path)

        real_path = "{}/{}.{}".format(path, self._checkoint_prefix, max_no)
        load_path = real_path
        if fs.need_upload_download():
            fs.download(real_path, cache_path)
            load_path = cache_path

        if main_program is None:
            main_program = self._transpiled_program

        io.load_persistables(
            executor=executor,
            dirname=load_path,
            main_program=main_program,
            filename=self._param_file_name)

        return self._load_train_status(load_path)


fleet = Collective()

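# Illustrative sketch (not executed here) of a fault-tolerant training loop
# built on save_check_point()/load_check_point(). `exe`, `train_one_epoch`,
# `num_epochs` and the "./fleet_checkpoint" path are hypothetical and not part
# of this module; error handling is omitted.
#
#     status = fleet.load_check_point(
#         exe, "./fleet_checkpoint", trainer_id=fleet.worker_index())
#     start_epoch = status.next() if status else 0
#     for epoch in range(start_epoch, num_epochs):
#         train_one_epoch(exe, fleet.main_program)
#         fleet.save_check_point(
#             exe, "./fleet_checkpoint", TrainStatus(epoch),
#             remain_all_checkpoint=False)

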
class DistributedStrategy(fluid.BuildStrategy):
    """
    The distributed training strategy for collective mode. It extends
    `fluid.BuildStrategy` with collective-specific switches such as local SGD,
    distributed FC, the number of NCCL communicators and forward recompute.
    """

    def __init__(self):
        super(DistributedStrategy, self).__init__()
        self.use_local_sgd = False
        self.use_dist_fc = False

        self.dist_fc_config = None  # DistFCConfig
        self.mode = "nccl2"  # or collective
        self.collective_mode = None  # local_sgd or grad_allreduce
        self.nccl_comm_num = 1
        self.forward_recompute = False
        self.recompute_checkpoints = []

        self.exec_strategy = fluid.ExecutionStrategy()

        # configurations below are used for unit test
        self._ut4grad_allreduce = False

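# Illustrative sketch (not executed here): configuring a strategy before
# passing it to fleet.distributed_optimizer(). `layer1_out` and `layer2_out`
# are hypothetical variables from the user's network used as recompute
# checkpoints.
#
#     dist_strategy = DistributedStrategy()
#     dist_strategy.nccl_comm_num = 2
#     dist_strategy.forward_recompute = True
#     dist_strategy.recompute_checkpoints = [layer1_out, layer2_out]

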
class CollectiveOpBasedOptimizer(DistributedOptimizer):
    """
    Collective Operator Base Class For Distributed Optimizer
    The class is invisible to a user
    """

    def __init__(self, optimizer, strategy=None):
        assert isinstance(
            strategy,
            DistributedStrategy), "strategy must be DistributedStrategy"
        super(CollectiveOpBasedOptimizer, self).__init__(optimizer, strategy)

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        return self._optimizer.backward(loss, startup_program, parameter_list,
                                        no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        return self._optimizer.apply_gradients(params_grads)


class CollectiveOptimizer(DistributedOptimizer):
    """
    CollectiveOptimizer is a wrapper for a paddle.fluid optimizer. A user
    should pass a paddle.fluid optimizer to it; only the minimize() function
    is implemented. CollectiveOptimizer is the starting point for a user who
    wants to run distributed training. The optimized information will be
    stored in the Fleet() instance, which holds the global information about
    the current distributed training job.
    """

    def __init__(self, optimizer, strategy=None):
        if strategy is None:
            strategy = DistributedStrategy()
        super(CollectiveOptimizer, self).__init__(optimizer, strategy)
        if strategy.forward_recompute:
            self.forward_recompute = True
            self.recompute_checkpoints = strategy.recompute_checkpoints
        else:
            self.forward_recompute = False
        self.print_config = False

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        return self._optimizer.backward(loss, startup_program, parameter_list,
                                        no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        return self._optimizer.apply_gradients(params_grads)

    def _check_condition(self, name, **kwargs):
        for k, v in six.iteritems(kwargs):
            if v is True:
                assert False, "you can't use %s and %s together" % (name, k)

    def _check_collective_mode(self, main_program, optimizer, strategy):
        """
        Check the conflict conditions.
        """
        if strategy.use_local_sgd:
            strategy.mode = "collective"
            strategy.collective_mode = "local_sgd"
            self._check_condition(
                "use_local_sgd",
                use_dgc=main_program._enable_dgc,
                use_dist_fc=strategy.use_dist_fc,
                use_lamb=main_program._use_lamb)

        if strategy.use_dist_fc:
            self._check_condition(
                "use_dist_fc",
                use_dgc=main_program._enable_dgc,
                use_local_sgd=strategy.use_local_sgd,
                use_lamb=main_program._use_lamb)
            assert strategy.dist_fc_config is not None, \
                "DistributedStrategy.dist_fc_config should be set"

        if strategy._ut4grad_allreduce:
            strategy.mode = "collective"
            strategy.collective_mode = "grad_allreduce"
            self._check_condition(
                "_ut4grad_allreduce",
                use_dgc=main_program._enable_dgc,
                use_lamb=main_program._use_lamb)

        if self._strategy.collective_mode == "local_sgd" \
                or self._strategy.collective_mode == "grad_allreduce":
            assert self._strategy.mode == "collective", \
                "local_sgd and grad_allreduce can only be used under collective mode"

    def _transpile(self, startup_program, main_program):
        """
        Transpile the programs into distributed programs and add the variables.
        """
        worker_endpoints = fleet.worker_endpoints()
        trainer_id = fleet.worker_index()
        current_endpoint = fleet.worker_endpoints()[trainer_id]
        worker_endpoints_env = ','.join(worker_endpoints)
        trainers_num = fleet.worker_num()

        if self.print_config:
            print("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
                  trainer_id:{}".format(worker_endpoints, trainers_num,
                                        current_endpoint, trainer_id))

        # call transpiler
        config = dist_transpiler.DistributeTranspilerConfig()
        config.mode = self._strategy.mode
        config.collective_mode = self._strategy.collective_mode

        config.nccl_comm_num = self._strategy.nccl_comm_num
        config.use_hierarchical_allreduce = self._strategy.use_hierarchical_allreduce
        config.hierarchical_allreduce_inter_nranks = self._strategy.hierarchical_allreduce_inter_nranks

        t = dist_transpiler.DistributeTranspiler(config=config)
        t.transpile(
            trainer_id=trainer_id,
            trainers=worker_endpoints_env,
            startup_program=startup_program,
            program=main_program,
            current_endpoint=current_endpoint)

    def _get_node_ips_from_endpoints(self, endpoints):
        ss = set()
        ips = []
        for ep in endpoints:
            ip = ep.split(":")[0].strip()
            if ip not in ss:
                ss.add(ip)
                ips.append(ip)
            else:
                continue

        return ips

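    # For example (hypothetical endpoints), two trainers on 10.0.0.1 and one
    # on 10.0.0.2:
    #
    #     ["10.0.0.1:6170", "10.0.0.1:6171", "10.0.0.2:6170"]
    #
    # deduplicates to ["10.0.0.1", "10.0.0.2"], so _node_num() reports 2 nodes.
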
    def _node_num(self):
        worker_endpoints = fleet.worker_endpoints()
        node_ips = self._get_node_ips_from_endpoints(worker_endpoints)
        node_num = len(node_ips)

        return node_num

    def _try_to_compile(self, startup_program, main_program):
        node_num = self._node_num()
        assert node_num >= 1, "nccl2 node_num must >= 1, now:{}".format(
            node_num)

        exec_strategy = self._strategy.exec_strategy

        if node_num <= 1:
            if self._strategy.nccl_comm_num > 1:
                logging.warn("set nccl_comm_num=1 since you only have 1 node.")
            self._strategy.nccl_comm_num = 1

            if self._strategy.use_hierarchical_allreduce:
                logging.warn(
                    "set use_hierarchical_allreduce=False since you only have 1 node."
                )
            self._strategy.use_hierarchical_allreduce = False

        sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce")
        if sync_allreduce is None or sync_allreduce == "1":
            exec_strategy.num_threads = self._strategy.nccl_comm_num + 1
            if self._strategy.use_hierarchical_allreduce:
                exec_strategy.num_threads = 2 * self._strategy.nccl_comm_num + 1
            if exec_strategy.num_threads > 4:
                logging.warn(
                    "if you use use_hierarchical_allreduce or "
                    "multiple nccl comms, please export FLAGS_sync_nccl_allreduce=0"
                )

        # NOTE: sync_batch_norm will hang when multiple num_threads are used
        sync_batch_norm = self._strategy.sync_batch_norm
        if sync_batch_norm is not None and sync_batch_norm is True:
            self._strategy.nccl_comm_num = 1
            self._strategy.use_hierarchical_allreduce = False
            exec_strategy.num_threads = 1
            logging.warn(
                "sync_batch_norm will hang when num_threads > 1, so "
                "set num_threads=1, nccl_comm_num=1, use_hierarchical_allreduce=False."
            )

        if self.print_config:
            print("node_num:", node_num, "num_threads:",
                  exec_strategy.num_threads, "use_hierarchical_allreduce:",
                  self._strategy.use_hierarchical_allreduce, "nccl_comm_num:",
                  self._strategy.nccl_comm_num, "FLAGS_sync_nccl_allreduce:",
                  sync_allreduce)

        self._transpile(startup_program, main_program)

        if self._strategy.mode == "collective":
            return main_program

        self._strategy.num_trainers = fleet.worker_num()
        self._strategy.trainer_id = fleet.worker_index()
        self._strategy.trainers_endpoints = fleet.worker_endpoints()
        self._strategy.enable_backward_optimizer_op_deps = True

        self._compiled_program = compiler.CompiledProgram(main_program)

        self._compiled_program.with_data_parallel(
            loss_name=self._loss.name,
            build_strategy=self._strategy,
            exec_strategy=self._strategy.exec_strategy,
            share_vars_from=None)

        return self._compiled_program

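    # Worked example of the num_threads logic above, assuming the
    # FLAGS_sync_nccl_allreduce environment variable is unset (which counts as
    # synchronous allreduce): with nccl_comm_num=2 and
    # use_hierarchical_allreduce=True, num_threads = 2 * 2 + 1 = 5 > 4, so the
    # warning suggesting FLAGS_sync_nccl_allreduce=0 is emitted.
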
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Minimize a program through the loss.

        Args:
            loss (Variable|Variable List): loss variable or loss variable list to run optimization.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables that should be ignored.

        Returns:
            tuple: (optimize_ops, params_grads), which are the list of appended
            operators and the list of (param, grad) Variable pairs for optimization.

        Note that in parameter server mode, a worker will not get anything about
        optimize_ops, because the optimizer algorithms run on the pserver side.
        We will make this usable in the pserver process, but currently the
        optimization part is written into Fleet(). A user does not need to care
        about how to start up a pserver node.
        """
        main_program = loss.block.program
        if startup_program is None:
            startup_program = fluid.default_startup_program()
        fleet.startup_program = startup_program

        self._loss = loss

        self._check_collective_mode(main_program, self._optimizer,
                                    self._strategy)

        if self.forward_recompute:
            assert (isinstance(self.recompute_checkpoints, list) and
                    len(self.recompute_checkpoints) > 0)
            self._optimizer = \
                fluid.optimizer.RecomputeOptimizer(self._optimizer)
            self._optimizer._set_checkpoints(self.recompute_checkpoints)

        optimize_ops, param_grads = self._optimizer.minimize(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set)

        fleet._origin_program = main_program.clone(for_test=False)
        fleet._transpiled_program = main_program
        fleet.main_program = self._try_to_compile(startup_program, main_program)

        return optimize_ops, param_grads
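

# Illustrative end-to-end sketch (not executed here), assuming the job is
# launched with paddle.distributed.launch so the fleet environment variables
# are set. `build_net()` and the feed details are hypothetical and omitted.
#
#     from paddle.fluid.incubate.fleet.base import role_maker
#
#     role = role_maker.PaddleCloudRoleMaker(is_collective=True)
#     fleet.init(role)
#
#     loss = build_net()
#     optimizer = fluid.optimizer.SGD(learning_rate=0.01)
#     optimizer = fleet.distributed_optimizer(
#         optimizer, strategy=DistributedStrategy())
#     optimizer.minimize(loss)
#
#     place = fluid.CUDAPlace(int(os.environ.get("FLAGS_selected_gpus", 0)))
#     exe = fluid.Executor(place)
#     exe.run(fleet.startup_program)
#     exe.run(fleet.main_program, fetch_list=[loss.name])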