@@ -24,51 +24,35 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
class TrainerRuntimeConfig(object):
    def __init__(self):
        self.max_merge_var_num = os.getenv(
            "FLAGS_communicator_max_merge_var_num", "20")
        self.send_queue_size = os.getenv("FLAGS_communicator_send_queue_size",
                                         "20")
        self.independent_recv_thread = os.getenv(
            "FLAGS_communicator_independent_recv_thread", "1")
        self.min_send_grad_num_before_recv = os.getenv(
            "FLAGS_communicator_min_send_grad_num_before_recv", "20")
        self.thread_pool_size = os.getenv("FLAGS_communicator_thread_pool_size",
                                          "5")
        self.send_wait_times = os.getenv("FLAGS_communicator_send_wait_times",
                                         "5")
        self.fake_rpc = os.getenv("FLAGS_communicator_fake_rpc", "0")
        self.merge_sparse_grad = os.getenv(
            "FLAGS_communicator_merge_sparse_grad", "1")
        self.is_sgd_optimizer = os.getenv("FLAGS_communicator_is_sgd_optimizer",
                                          "1")

        self.runtime_configs = {}
        # not used
        self._rpc_deadline = os.getenv("FLAGS_rpc_deadline", "180000")
        self._rpc_retry_times = os.getenv("FLAGS_rpc_retry_times", "3")
        self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline",
                                                         "180000")
        self.runtime_configs['rpc_retry_times'] = os.getenv(
            "FLAGS_rpc_retry_times", "3")
    def get_communicator_flags(self):
        _communicator_flags = dict()
        _communicator_flags["communicator_max_merge_var_num"] = str(
            self.max_merge_var_num)
        _communicator_flags["communicator_send_queue_size"] = str(
            self.send_queue_size)
        _communicator_flags["communicator_independent_recv_thread"] = str(
            self.independent_recv_thread)
        _communicator_flags["communicator_min_send_grad_num_before_recv"] = str(
            self.min_send_grad_num_before_recv)
        _communicator_flags["communicator_thread_pool_size"] = str(
            self.thread_pool_size)
        _communicator_flags["communicator_send_wait_times"] = str(
            self.send_wait_times)
        _communicator_flags["communicator_is_sgd_optimizer"] = str(
            self.is_sgd_optimizer)
        return _communicator_flags
        return self.runtime_configs
    def __repr__(self):
        _str = "please check that TrainerRuntimeConfig is as expected:\n"
        _communicator_flags = self.get_communicator_flags()
        for key in _communicator_flags:
            _str += "{}: {}\n".format(key, _communicator_flags[key])
        raw0, raw1, length = 45, 5, 50
        h_format = "{:^45s}{:<5s}\n"
        l_format = "{:<45s}{:<5s}\n"

        border = "".join(["="] * length)
        line = "".join(["-"] * length)

        draws = ""
        draws += border + "\n"
        draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
        draws += line + "\n"

        for k, v in self.get_communicator_flags().items():
            draws += l_format.format(k, v)

        draws += border

        _str = "\n{}\n".format(draws)
        return _str
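
# Editor's sketch, not part of the patch: what the new table-style
# __repr__ prints, reproduced with the stdlib only. The two flag entries
# are illustrative; any runtime_configs dict renders the same way.
import os

flags = {
    "communicator_thread_pool_size": os.getenv(
        "FLAGS_communicator_thread_pool_size", "10"),
    "communicator_send_wait_times": os.getenv(
        "FLAGS_communicator_send_wait_times", "5"),
}
length = 50
h_format = "{:^45s}{:<5s}\n"
l_format = "{:<45s}{:<5s}\n"
draws = "=" * length + "\n"
draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
draws += "-" * length + "\n"
for k, v in flags.items():
    draws += l_format.format(k, v)
draws += "=" * length
print("\n{}\n".format(draws))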
@@ -77,9 +61,11 @@ class DistributedStrategy(object):
        self._program_config = DistributeTranspilerConfig()
        self._trainer_runtime_config = TrainerRuntimeConfig()
        self._server_runtime_config = ServerRuntimeConfig()
        num_threads = int(os.getenv("CPU_NUM", "1"))

        self._execute_strategy = fluid.ExecutionStrategy()
        self._build_strategy = fluid.BuildStrategy()
        num_threads = int(os.getenv("CPU_NUM", "1"))

        self._execute_strategy.num_threads = num_threads
        if num_threads > 1:
            self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
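
# Editor's sketch, not part of the patch: CPU_NUM is read once in the
# constructor above, so it must be exported before the strategy object
# is built for the Reduce branch to trigger. The value "4" is
# hypothetical.
import os

os.environ["CPU_NUM"] = "4"
num_threads = int(os.getenv("CPU_NUM", "1"))
assert num_threads > 1  # the condition that selects ReduceStrategy.Reduce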
@@ -110,9 +96,9 @@ class DistributedStrategy(object):
        if isinstance(config, TrainerRuntimeConfig):
            self._trainer_runtime_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._trainer_runtime_config, key):
                    setattr(self._trainer_runtime_config, key, config[key])
            for key, Value in config.items():
                if key in self._trainer_runtime_config.runtime_configs:
                    self._trainer_runtime_config.runtime_configs[key] = Value
                else:
                    raise ValueError(
                        "TrainerRuntimeConfig doesn't have key: {}".format(key))
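
# Editor's sketch, not part of the patch: exercising the dict branch
# above. The enclosing setter is assumed to be
# DistributedStrategy.set_trainer_runtime_config, and the import path is
# assumed to be the file under patch.
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import SyncStrategy

strategy = SyncStrategy()
# A key present in runtime_configs is accepted and overwritten:
strategy.set_trainer_runtime_config({"communicator_send_queue_size": "4"})
# An unknown key would raise the ValueError from the else-branch:
# strategy.set_trainer_runtime_config({"no_such_flag": "1"})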
@@ -182,6 +168,21 @@ class SyncStrategy(DistributedStrategy):
        self._program_config.runtime_split_send_recv = False
        self._build_strategy.async_mode = False

        num_threads = os.getenv("CPU_NUM", "1")

        self._trainer_runtime_config.runtime_configs[
            'communicator_max_merge_var_num'] = os.getenv(
                "FLAGS_communicator_max_merge_var_num", num_threads)
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_wait_times'] = os.getenv(
                "FLAGS_communicator_send_wait_times", "5")
        self._trainer_runtime_config.runtime_configs[
            'communicator_thread_pool_size'] = os.getenv(
                "FLAGS_communicator_thread_pool_size", "10")
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_queue_size'] = os.getenv(
                "FLAGS_communicator_send_queue_size", num_threads)
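
# Editor's sketch, not part of the patch: the precedence rule every
# assignment above follows. An exported FLAGS_* value wins over the
# hard-coded default, which itself may be derived from CPU_NUM.
import os

os.environ["FLAGS_communicator_send_queue_size"] = "8"  # hypothetical
num_threads = os.getenv("CPU_NUM", "1")
queue_size = os.getenv("FLAGS_communicator_send_queue_size", num_threads)
assert queue_size == "8"  # the env value, not the CPU_NUM-derived default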


class AsyncStrategy(DistributedStrategy):
    def __init__(self):
@@ -190,6 +191,30 @@ class AsyncStrategy(DistributedStrategy):
        self._program_config.runtime_split_send_recv = True
        self._build_strategy.async_mode = True

        num_threads = os.getenv("CPU_NUM", "1")

        self._trainer_runtime_config.runtime_configs[
            'communicator_max_merge_var_num'] = os.getenv(
                "FLAGS_communicator_max_merge_var_num", num_threads)
        self._trainer_runtime_config.runtime_configs[
            'communicator_independent_recv_thread'] = os.getenv(
                "FLAGS_communicator_independent_recv_thread", "0")
        self._trainer_runtime_config.runtime_configs[
            'communicator_min_send_grad_num_before_recv'] = os.getenv(
                "FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
        self._trainer_runtime_config.runtime_configs[
            'communicator_thread_pool_size'] = os.getenv(
                "FLAGS_communicator_thread_pool_size", "10")
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_wait_times'] = os.getenv(
                "FLAGS_communicator_send_wait_times", "5")
        self._trainer_runtime_config.runtime_configs[
            'communicator_is_sgd_optimizer'] = os.getenv(
                "FLAGS_communicator_is_sgd_optimizer", "1")
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_queue_size'] = os.getenv(
                "FLAGS_communicator_send_queue_size", num_threads)
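
# Editor's sketch, not part of the patch: the defaults AsyncStrategy
# seeds when neither the FLAGS_* variables nor CPU_NUM are exported
# (CPU_NUM then falls back to "1"), mirroring the assignments above.
expected_async_defaults = {
    "communicator_max_merge_var_num": "1",
    "communicator_independent_recv_thread": "0",
    "communicator_min_send_grad_num_before_recv": "1",
    "communicator_thread_pool_size": "10",
    "communicator_send_wait_times": "5",
    "communicator_is_sgd_optimizer": "1",
    "communicator_send_queue_size": "1",
}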


class HalfAsyncStrategy(DistributedStrategy):
    def __init__(self):
@@ -200,15 +225,37 @@ class HalfAsyncStrategy(DistributedStrategy):
        self._build_strategy.async_mode = True
        self._execute_strategy.use_thread_barrier = True

        num_threads = os.getenv("CPU_NUM", "1")

        self._trainer_runtime_config.runtime_configs[
            'communicator_max_merge_var_num'] = os.getenv(
                "FLAGS_communicator_max_merge_var_num", num_threads)
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_wait_times'] = os.getenv(
                "FLAGS_communicator_send_wait_times", "5")
        self._trainer_runtime_config.runtime_configs[
            'communicator_thread_pool_size'] = os.getenv(
                "FLAGS_communicator_thread_pool_size", "10")
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_queue_size'] = os.getenv(
                "FLAGS_communicator_send_queue_size", num_threads)
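
# Editor's sketch, not part of the patch: printing the trainer config of
# a freshly built strategy goes through TrainerRuntimeConfig.__repr__,
# i.e. the table shown earlier. Import path assumed as above.
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import HalfAsyncStrategy

half_async = HalfAsyncStrategy()
print(half_async._trainer_runtime_config)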


class GeoStrategy(DistributedStrategy):
    def __init__(self, update_frequency=100):
        super(GeoStrategy, self).__init__()
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._build_strategy.async_mode = True
        self._program_config.geo_sgd_mode = True
        self._program_config.geo_sgd_need_push_nums = update_frequency
        self._build_strategy.async_mode = True

        self._trainer_runtime_config.runtime_configs[
            'communicator_thread_pool_size'] = os.getenv(
                "FLAGS_communicator_thread_pool_size", "10")
        self._trainer_runtime_config.runtime_configs[
            'communicator_send_wait_times'] = os.getenv(
                "FLAGS_communicator_send_wait_times", "5")
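
# Editor's sketch, not part of the patch: update_frequency feeds
# geo_sgd_need_push_nums, i.e. how many local updates accumulate before
# a delta is pushed to the parameter servers. Import path assumed as
# above.
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import GeoStrategy

geo = GeoStrategy(update_frequency=400)
assert geo._program_config.geo_sgd_need_push_nums == 400
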
class StrategyFactory(object):