fleet base initial implementation and the API (#25442)
refactor fleet api under paddle.fleet; update DistributedStrategy; fix_copy_if_different
parent 214c6fcdee
commit e657d7062d
File diff suppressed because it is too large
@@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import socket
from contextlib import closing
from six import string_types


def wait_server_ready(endpoints):
    """
    Wait until parameter servers are ready, using connect_ex to detect
    port readiness.

    Args:
        endpoints (list): endpoint string list, like:
            ["127.0.0.1:8080", "127.0.0.1:8081"]

    Examples:
        .. code-block:: python

            wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"])
    """
    assert not isinstance(endpoints, str)
    while True:
        all_ok = True
        not_ready_endpoints = []
        for ep in endpoints:
            ip_port = ep.split(":")
            with closing(socket.socket(socket.AF_INET,
                                       socket.SOCK_STREAM)) as sock:
                sock.settimeout(2)
                result = sock.connect_ex((ip_port[0], int(ip_port[1])))
                if result != 0:
                    all_ok = False
                    not_ready_endpoints.append(ep)
        if not all_ok:
            sys.stderr.write("server not ready, wait 3 sec to retry...\n")
            sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) +
                             "\n")
            sys.stderr.flush()
            time.sleep(3)
        else:
            break
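As a quick illustration (a sketch, not part of the change itself), the helper can be exercised locally by binding a throwaway listener first; the import path assumes the module lands under paddle.fleet.base.private_helper_function, which is how the unit test near the end of this diff reaches it:

import socket
import threading

def _serve_once(port):
    # bind a listener so the endpoint becomes connectable
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("127.0.0.1", port))
    sock.listen(1)
    conn, _ = sock.accept()
    conn.close()
    sock.close()

thr = threading.Thread(target=_serve_once, args=(9292, ))
thr.start()

from paddle.fleet.base.private_helper_function import wait_server_ready
wait_server_ready(["127.0.0.1:9292"])  # returns once the port accepts connections
thr.join()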
@@ -0,0 +1,27 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..runtime.collective_runtime import CollectiveRuntime


class RuntimeFactory(object):
    def __init__(self):
        pass

    def _create_runtime(self, final_dist_strategy, role_maker, opt_ops,
                        params_grads):
        if role_maker._is_collective:
            collective_runtime = CollectiveRuntime()
            collective_runtime._set_basic_info(final_dist_strategy, role_maker,
                                               opt_ops, params_grads)
            return collective_runtime
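A rough sketch of how the factory might be driven (illustrative names only; in the real flow the role maker and strategy come from fleet.init and the strategy compiler, so the stand-in class below is purely hypothetical):

class _FakeRoleMaker(object):
    # the factory only inspects this flag
    _is_collective = True

runtime = RuntimeFactory()._create_runtime(
    final_dist_strategy=None,   # placeholder; normally the compiled strategy
    role_maker=_FakeRoleMaker(),
    opt_ops=[],
    params_grads=[])
# runtime is now a CollectiveRuntime with the basic info recorded on it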
@@ -0,0 +1,69 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def maximum_path_len_algo(optimizer_list):
    max_idx = 0
    max_len = 0
    candidates = []
    for idx, opt in enumerate(optimizer_list):
        local_buffer = [opt]
        for opt_inner in optimizer_list:
            if opt._can_update(opt_inner):
                local_buffer.append(opt_inner)
        if len(local_buffer) > max_len:
            max_idx = idx
            max_len = len(local_buffer)
        candidates.append(local_buffer)
    if len(candidates) == 0:
        return None
    for idx, opt in enumerate(candidates[max_idx][:-1]):
        opt._update_inner_optimizer(candidates[max_idx][idx + 1])
    return candidates[max_idx][0]


class StrategyCompilerBase(object):
    def __init__(self):
        pass


class StrategyCompiler(StrategyCompilerBase):
    """
    StrategyCompiler is responsible for the combination of meta optimizers.
    Generally, a user can define several distributed strategies that
    can generate several meta optimizers. The combination of these
    meta optimizers should have the right order to apply the optimizers'
    minimize function.
    This class is responsible for the executable distributed optimizer
    generation.
    """

    def __init__(self):
        super(StrategyCompiler, self).__init__()

    def generate_optimizer(self, loss, role_maker, optimizer,
                           user_defined_strategy, meta_optimizer_list,
                           graph_optimizer_list):
        if len(meta_optimizer_list) == 0 and len(graph_optimizer_list) == 0:
            return optimizer, None
        else:
            # currently, we use a heuristic algorithm to select
            # the meta optimizer combination
            meta_optimizer = maximum_path_len_algo(meta_optimizer_list)
            graph_optimizer = maximum_path_len_algo(graph_optimizer_list)
            # should design a distributed strategy update interface:
            # when we have finally decided the combination of meta_optimizer
            # and graph_optimizer, the corresponding distributed strategy
            # should be updated.
            return meta_optimizer, graph_optimizer, None
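The heuristic above is easiest to see with toy optimizers. The sketch below (purely illustrative; the classes are not part of the diff) shows how the longest compatible chain wins and gets wired together through _update_inner_optimizer:

class _ToyMetaOptimizer(object):
    # minimal stand-in exposing the two hooks maximum_path_len_algo relies on
    def __init__(self, name, can_update_names):
        self.name = name
        self._can_update_names = can_update_names
        self.inner_opt = None

    def _can_update(self, other):
        return other.name in self._can_update_names

    def _update_inner_optimizer(self, optimizer):
        self.inner_opt = optimizer


a = _ToyMetaOptimizer("A", ["B", "C"])  # A can wrap B and C
b = _ToyMetaOptimizer("B", ["C"])       # B can wrap C
c = _ToyMetaOptimizer("C", [])          # C wraps nothing

head = maximum_path_len_algo([a, b, c])
# A has the longest compatible chain (A -> B -> C), so it becomes the head
# and the chain is linked up through _update_inner_optimizer.
assert head is a and a.inner_opt is b and b.inner_opt is c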
@@ -0,0 +1,194 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import paddle
from paddle.fluid.framework import core
from paddle.fluid import compiler
from .meta_optimizer_base import MetaOptimizerBase
from ..base.private_helper_function import wait_server_ready


def get_build_strategy(dist_strategy):
    build_strategy = paddle.BuildStrategy()
    build_strategy.enable_sequential_execution = \
        dist_strategy.sequential_execution
    build_strategy.remove_unnecessary_lock = True
    build_strategy.fuse_elewise_add_act_ops = \
        dist_strategy.fuse_elewise_add_act_ops
    build_strategy.fuse_bn_act_ops = \
        dist_strategy.fuse_bn_act_ops
    build_strategy.enable_auto_fusion = \
        dist_strategy.enable_auto_fusion
    build_strategy.fuse_relu_depthwise_conv = \
        dist_strategy.fuse_relu_depthwise_conv
    build_strategy.fuse_broadcast_ops = \
        dist_strategy.fuse_broadcast_ops
    build_strategy.sync_batch_norm = \
        dist_strategy.sync_batch_norm
    return build_strategy


def get_execution_strategy(dist_strategy):
    execution_strategy = paddle.ExecutionStrategy()
    execution_strategy.num_threads = \
        dist_strategy.num_threads
    execution_strategy.num_iteration_per_drop_scope = \
        dist_strategy.num_iteration_per_drop_scope
    execution_strategy.num_iteration_per_run = \
        dist_strategy.num_iteration_per_run
    execution_strategy.use_thread_barrier = \
        dist_strategy.use_thread_barrier
    return execution_strategy


class GraphExecutionOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(GraphExecutionOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []

    def _is_graph_out(self):
        return True

    def _can_apply(self):
        """
        Basically, this is PE, and almost all programs can be executed here
        """
        return True

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        pass

    # should fix the variable
    def _setup_nccl_op(self, startup_program, main_program):
        trainer_endpoints = self.role_maker.get_trainer_endpoints()
        trainers = trainer_endpoints
        trainer_id = self.role_maker.worker_index()
        current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id]
        trainer_endpoints_env = ",".join(trainer_endpoints)
        trainers_num = self.role_maker.worker_num()
        trainer_endpoints.remove(current_endpoint)
        if trainer_id == 0:
            wait_server_ready(trainer_endpoints)
        nccl_id_var = startup_program.global_block().create_var(
            name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
        for i in range(1, self.user_defined_strategy.nccl_comm_num):
            startup_program.global_block().create_var(
                name="NCCLID_{}".format(i),
                persistable=True,
                type=core.VarDesc.VarType.RAW)

        if self.user_defined_strategy.hierachical_allreduce:
            for i in range(0, self.user_defined_strategy.nccl_comm_num):
                startup_program.global_block().create_var(
                    name="Hierarchical_inter_NCCLID_{}".format(i),
                    persistable=True,
                    type=core.VarDesc.VarType.RAW)
                startup_program.global_block().create_var(
                    name="Hierarchical_exter_NCCLID_{}".format(i),
                    persistable=True,
                    type=core.VarDesc.VarType.RAW)

        startup_program.global_block().append_op(
            type="gen_nccl_id",
            inputs={},
            outputs={"NCCLID": nccl_id_var},
            attrs={
                "trainers": trainers,
                "trainer_id": trainer_id,
                "nccl_comm_num": self.user_defined_strategy.nccl_comm_num,
                "use_hierarchical_allreduce":
                self.user_defined_strategy.hierachical_allreduce,
                "hierarchical_allreduce_inter_ranks":
                self.user_defined_strategy.hierachical_allreduce_inter_ranks
            })

    def _try_to_compile(self, startup_program, main_program, loss):
        build_strategy = get_build_strategy(self.user_defined_strategy)
        exe_strategy = get_execution_strategy(self.user_defined_strategy)
        node_num = self.role_maker.worker_num()
        if self.role_maker._is_collective:
            assert node_num >= 1, "nccl2 node_num must >= 1, now:{}".format(
                node_num)

        if node_num <= 1:
            # local mode
            if self.user_defined_strategy.nccl_comm_num > 1:
                logging.warn("set nccl_comm_num=1 since you only have 1 node.")
            self.user_defined_strategy.nccl_comm_num = 1

            if self.user_defined_strategy.hierachical_allreduce:
                logging.warn(
                    "set hierachical_allreduce=False since you only have 1 node."
                )
                self.user_defined_strategy.hierachical_allreduce = False

        sync_allreduce = self.user_defined_strategy.sync_nccl_allreduce
        if sync_allreduce:
            exe_strategy.num_threads = self.user_defined_strategy.nccl_comm_num + 1
            if self.user_defined_strategy.hierachical_allreduce:
                exe_strategy.num_threads = 2 * self.user_defined_strategy.nccl_comm_num + 1
            if exe_strategy.num_threads > 4:
                logging.warn(
                    "if you use hierachical_allreduce or "
                    "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0"
                )

        # TODO(guru4elephant): should be an independent optimizer
        sync_batch_norm = self.user_defined_strategy.sync_batch_norm
        if sync_batch_norm:
            self.user_defined_strategy.nccl_comm_num = 1
            self.user_defined_strategy.hierachical_allreduce = False
            exe_strategy.num_threads = 1
            logging.warn(
                "use sync_batch_norm will hang when set num_threads > 1, so "
                "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False."
            )

        # TODO(guru4elephant): should be an independent optimizer
        self._setup_nccl_op(startup_program, main_program)

        build_strategy.num_trainers = self.role_maker.worker_num()
        build_strategy.trainer_id = self.role_maker.worker_index()
        build_strategy.trainers_endpoints = self.role_maker.get_trainer_endpoints()
        build_strategy.enable_backward_optimizer_op_deps = True

        self._compiled_program = compiler.CompiledProgram(main_program)

        self._compiled_program.with_data_parallel(
            loss_name=loss.name,
            build_strategy=build_strategy,
            exec_strategy=exe_strategy,
            share_vars_from=None)

        return self._compiled_program

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        if startup_program is None:
            startup_program = paddle.default_startup_program()
        compiled_program = self._try_to_compile(startup_program,
                                                loss.block.program, loss)
        loss.block.program.graph = compiled_program

        # just return self.optimizer_ops and self.param_grads
        return None, None
@@ -0,0 +1,56 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["MetaOptimizerBase"]


class MetaOptimizerBase(object):
    def __init__(self, optimizer):
        pass

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        self.loss = loss
        self.role_maker = role_maker
        self.user_defined_optimizer = user_defined_optimizer
        self.user_defined_strategy = user_defined_strategy

    def _update_inner_optimizer(self, optimizer):
        self.inner_opt = optimizer

    def _can_apply(self):
        return False

    def _is_graph_out(self):
        return False

    def _can_update(self, optimizer):
        if str(optimizer.__class__.__name__) in self.meta_optimizers_white_list:
            return True

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        raise NotImplementedError("meta optimizer not implemented")

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        optimize_ops, params_grads = self.minimize_impl(
            loss, startup_program, parameter_list, no_grad_set)
        return optimize_ops, params_grads
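For orientation, a hypothetical subclass (illustrative only, not part of this change) sketches the contract the rest of the diff relies on: declare which meta optimizers may wrap this one via meta_optimizers_white_list, report applicability from the user-defined strategy in _can_apply, and do the actual work in minimize_impl; compare RecomputeOptimizer below for a real instance:

class ExampleMetaOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(ExampleMetaOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # other meta optimizers allowed to wrap this one
        self.meta_optimizers_white_list = []

    def _can_apply(self):
        # typically gated on a switch in self.user_defined_strategy
        return True

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        # delegate to the wrapped user-defined optimizer
        return self.inner_opt.minimize(loss, startup_program, parameter_list,
                                       no_grad_set)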
@@ -0,0 +1,59 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.optimizer import RecomputeOptimizer as RO
from .meta_optimizer_base import MetaOptimizerBase

__all__ = ["RecomputeOptimizer"]


class RecomputeOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(RecomputeOptimizer, self).__init__(optimizer)
        #self.inner_opt = RO(optimizer)
        self.inner_opt = optimizer
        self.wrapped_opt = RO(optimizer)
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        super(RecomputeOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)
        self.wrapped_opt._set_checkpoints([])

    def _can_apply(self):
        if self.user_defined_strategy.recompute:
            if len(self.user_defined_strategy.recompute_checkpoints) == 0:
                return False
            else:
                return True

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        return self.wrapped_opt.backward(loss, startup_program, parameter_list,
                                         no_grad_set, callbacks)

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        optimize_ops, params_grads = \
            self.wrapped_opt.minimize(loss, startup_program,
                                      parameter_list, no_grad_set)
        return optimize_ops, params_grads
@@ -0,0 +1,48 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .runtime_base import RuntimeBase
import logging


class CollectiveRuntime(RuntimeBase):
    def __init__(self):
        super(CollectiveRuntime, self).__init__()

    def _init_worker(self):
        logging.warn(
            "You should not call 'init_worker' method for collective mode.")
        pass

    def _run_worker(self):
        logging.warn(
            "You should not call 'run_worker' method for collective mode.")
        pass

    def _init_server(self):
        logging.warn(
            "You should not call 'init_server' method for collective mode.")
        pass

    def _run_server(self):
        logging.warn(
            "You should not call 'run_server' method for collective mode.")
        pass

    def _stop_worker(self):
        logging.warn(
            "You should not call 'stop_worker' method for collective mode.")
        pass

    # save inference model should be added here
@@ -0,0 +1,38 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = []


class RuntimeBase(object):
    def __init__(self):
        pass

    def _set_basic_info(self, loss, role_maker, optimizer, strategy):
        self.loss = loss
        self.role_maker = role_maker
        self.optimizer = optimizer
        self.strategy = strategy

    def _run_worker(self):
        pass

    def _init_server(self):
        pass

    def _run_server(self):
        pass

    def _stop_worker(self):
        pass
@@ -0,0 +1,177 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os


class TestFleetBase(unittest.TestCase):
    def setUp(self):
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
            "127.0.0.1:36001,127.0.0.2:36001"

    def test_init(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)

    def test_is_first_worker(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_first_worker():
            print("test fleet first worker done.")

    def test_worker_index(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        print(fleet.worker_index())

    def test_worker_num(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        print(fleet.worker_num())

    def test_is_worker(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_worker():
            print("test fleet is worker")

    def test_worker_endpoints(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        print(fleet.worker_endpoints(to_string=True))

    def test_server_num(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_server():
            print("fleet server num: {}".format(fleet.server_num()))

    def test_server_index(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_server():
            print("fleet server index: {}".format(fleet.server_index()))

    def test_server_endpoints(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_server():
            print("fleet server endpoints: {}".format(
                fleet.server_endpoints(to_string=True)))

    def test_is_server(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_server():
            print("test fleet is server")

    def test_util(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        self.assertEqual(fleet.util, None)

    def test_barrier_worker(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_worker():
            fleet.barrier_worker()

    def test_init_worker(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_worker():
            fleet.init_worker()

    def test_run_server(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_worker():
            fleet.run_worker()

    def test_stop_worker(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        if fleet.is_worker():
            fleet.stop_worker()

    def test_distributed_optimizer(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        strategy = fleet.DistributedStrategy()
        optimizer = paddle.optimizer.SGD(learning_rate=0.001)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

    def test_minimize(self):
        import paddle
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker

        input_x = paddle.fluid.layers.data(
            name="x", shape=[32], dtype='float32')
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
        cost = paddle.fluid.layers.cross_entropy(
            input=prediction, label=input_y)
        avg_cost = paddle.fluid.layers.mean(x=cost)

        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        strategy = fleet.DistributedStrategy()
        optimizer = paddle.optimizer.SGD(learning_rate=0.001)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,76 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os


class TestFleetMetaOptimizer(unittest.TestCase):
    def setUp(self):
        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
            "127.0.0.1:36001,127.0.0.2:36001"

    def test_graph_execution_optimizer(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        input_x = paddle.fluid.layers.data(
            name="x", shape=[32], dtype='float32')
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
        cost = paddle.fluid.layers.cross_entropy(
            input=prediction, label=input_y)
        avg_cost = paddle.fluid.layers.mean(x=cost)

        strategy = paddle.fleet.DistributedStrategy()

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)

    def test_recompute_optimizer(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        input_x = paddle.fluid.layers.data(
            name="x", shape=[32], dtype='float32')
        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')

        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
        cost = paddle.fluid.layers.cross_entropy(
            input=prediction, label=input_y)
        avg_cost = paddle.fluid.layers.mean(x=cost)

        strategy = paddle.fleet.DistributedStrategy()
        strategy.recompute = True
        strategy.recompute_checkpoints = [fc_2]

        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
        optimizer.minimize(avg_cost)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,47 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import paddle
import socket
import threading


class TestFleetPrivateFunction(unittest.TestCase):
    def test_wait_port(self):
        def init_server(port):
            import time
            time.sleep(5)
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind(("127.0.0.1", port))
            sock.listen(10)
            while True:
                c, addr = sock.accept()
                c.send(b"0")
                c.close()
                break

        thr = threading.Thread(target=init_server, args=(9292, ))
        thr.start()

        import paddle.fleet as fleet
        ep = ["127.0.0.1:9292"]
        fleet.base.private_helper_function.wait_server_ready(ep)

        thr.join()


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,40 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os


class TestFleetRuntime(unittest.TestCase):
    def test_fleet_runtime_base(self):
        import paddle.fleet.runtime
        base = paddle.fleet.runtime.runtime_base.RuntimeBase()
        base._run_worker()
        base._init_server()
        base._run_server()
        base._stop_worker()

    def test_fleet_collective_runtime(self):
        import paddle.fleet.runtime
        collective_runtime = paddle.fleet.runtime.CollectiveRuntime()
        collective_runtime._init_worker()
        collective_runtime._run_worker()
        collective_runtime._init_worker()
        collective_runtime._run_server()
        collective_runtime._stop_worker()


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,68 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os


class TestFleetUtil(unittest.TestCase):
    def test_util_base(self):
        import paddle.fleet as fleet
        util = fleet.UtilBase()
        strategy = fleet.DistributedStrategy()
        util._set_strategy(strategy)
        role_maker = None  # should be fleet.PaddleCloudRoleMaker()
        util._set_role_maker(role_maker)

    def test_util_factory(self):
        import paddle.fleet as fleet
        factory = fleet.base.util_factory.UtilFactory()
        strategy = fleet.DistributedStrategy()
        role_maker = None  # should be fleet.PaddleCloudRoleMaker()
        optimize_ops = []
        params_grads = []
        util = factory._create_util(strategy, role_maker, optimize_ops,
                                    params_grads)
        self.assertEqual(util.role_maker, None)

    def test_get_util(self):
        import paddle.fleet as fleet
        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        default_util = fleet.util
        self.assertEqual(default_util, None)

    def test_set_user_defined_util(self):
        import paddle.fleet as fleet

        class UserDefinedUtil(fleet.UtilBase):
            def __init__(self):
                super(UserDefinedUtil, self).__init__()

            def get_user_id(self):
                return 10

        import paddle.fluid.incubate.fleet.base.role_maker as role_maker
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
        fleet.init(role)
        my_util = UserDefinedUtil()
        fleet.util = my_util
        user_id = fleet.util.get_user_id()
        self.assertEqual(user_id, 10)


if __name__ == "__main__":
    unittest.main()