the integrated communicator (#19849)

* add a base class for the Communicator
* add an AsyncCommunicator implementation for async distributed training (call pattern sketched below)
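
The hunks below replace the old Communicator::Init / GetInstance sequence with a templated, std::call_once-guarded InitInstance<T>, so callers no longer check for an existing instance themselves. A minimal sketch of the resulting call pattern; the header path is assumed, and only the wrapper function and the gradient variable name are invented:

#include "paddle/fluid/operators/distributed/communicator.h"  // assumed path

using paddle::operators::distributed::AsyncCommunicator;
using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::RpcCtxMap;

void RunCommunicator(const RpcCtxMap& send_ctx, const RpcCtxMap& recv_ctx,
                     paddle::framework::Scope* scope) {
  // InitInstance is guarded by std::call_once: repeated calls are safe and
  // always hand back the same singleton.
  auto* comm =
      Communicator::InitInstance<AsyncCommunicator>(send_ctx, recv_ctx, scope);
  if (!comm->IsRunning()) comm->Start();
  // Gradients are handed to the communicator instead of being sent inline
  // (see the SendOp hunk below).
  comm->Send("fc_0.w_0@GRAD", *scope);  // hypothetical variable name
  comm->Stop();
}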

@@ -86,13 +86,10 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
if (operators::distributed::Communicator::GetInstance() == nullptr) {
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
auto *instance = operators::distributed::Communicator::InitInstance<
operators::distributed::AsyncCommunicator>(send_varname_to_ctx,
recv_varname_to_ctx, scope);
if (!instance->IsRunning()) instance->Start();
}
#endif
}

File diff suppressed because it is too large.

@@ -161,27 +161,97 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator {
public:
Communicator(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope);
Communicator() {}
virtual ~Communicator() {}
~Communicator();
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
void Start();
void Stop();
virtual void Send(const std::string& var_name,
const framework::Scope& scope) = 0;
virtual void Recv() = 0;
bool IsRunning() { return running_; }
virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) = 0;
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
virtual void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) = 0;
private:
// recv all parameter
static Communicator* GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_;
}
template <typename T>
static Communicator* InitInstance(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>,
send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(send_varname_to_ctx, recv_varname_to_ctx,
recv_scope);
}
}
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope);
return communicator_.get();
}
template <typename T>
static void InitWithProgram(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(program, recv_scope);
}
}
protected:
bool running_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() {}
~AsyncCommunicator();
void Start() override;
void Stop() override;
void Send(const std::string& var_name,
const framework::Scope& scope) override;
void Recv() override;
void RecvAll();
void RecvNonIndependent();
void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) override;
void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) override;
void SendThread();
void RecvThread();
bool running_ = false;
private:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
@@ -194,26 +264,6 @@ class Communicator {
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
// the following code initializes the communicator
public:
static void Init(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
if (communicator_ == nullptr) {
communicator_.reset(new Communicator(send_varname_to_ctx,
recv_varname_to_ctx, recv_scope));
}
}
static void Init(const paddle::framework::ProgramDesc& program,
Scope* param_scope);
static Communicator* GetInstance();
static std::shared_ptr<Communicator> GetInstantcePtr();
private:
static std::shared_ptr<Communicator> communicator_;
};
} // namespace distributed
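
With the singleton machinery living in the base class, adding a new communication strategy reduces to implementing the pure-virtual surface above. A hypothetical subclass, not part of this commit, assuming it is declared next to Communicator in the same namespace:

class HalfAsyncCommunicator : public Communicator {  // hypothetical name
 public:
  HalfAsyncCommunicator() {}
  ~HalfAsyncCommunicator() {}
  void Start() override { running_ = true; }
  void Stop() override { running_ = false; }
  // Queue the gradient for a merged send (strategy-specific).
  void Send(const std::string& var_name,
            const framework::Scope& scope) override {}
  // Pull the parameters back from the pservers.
  void Recv() override {}
  void InitImpl(const RpcCtxMap& send_varname_to_ctx,
                const RpcCtxMap& recv_varname_to_ctx,
                Scope* recv_scope) override {}
  void InitImpl(const paddle::framework::ProgramDesc& program,
                Scope* recv_scope) override {}
};

// Wiring it in mirrors the AsyncCommunicator call sites elsewhere in this diff:
//   Communicator::InitInstance<HalfAsyncCommunicator>(program, scope);

Because InitInstance takes the concrete type as a template parameter and guards construction with std::call_once, each new strategy gets thread-safe, initialize-exactly-once semantics for free instead of re-implementing the old Init/GetInstance nullptr check.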

@@ -48,14 +48,7 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
if (distributed::Communicator::GetInstance() == nullptr) {
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections, trainer_id);
send_functor(rpc_ctx, scope, true);
} else {
distributed::Communicator::GetInstance()->Send(ins[0], scope);
}
distributed::Communicator::GetInstance()->Send(ins[0], scope);
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();

@@ -26,6 +26,7 @@ namespace py = pybind11;
using paddle::framework::ProgramDesc;
using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::AsyncCommunicator;
using paddle::framework::Scope;
namespace paddle {
@@ -36,7 +37,7 @@ void BindCommunicator(py::module* m) {
py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
"DistCommunicator")
.def(py::init([](const ProgramDesc& program, Scope* param_scope) {
Communicator::Init(program, param_scope);
Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
return Communicator::GetInstantcePtr();
}))
.def("stop", &Communicator::Stop)

@@ -75,11 +75,14 @@ class TestDistRunnerBase(object):
sync_mode,
dc_asgd=False,
current_endpoint=None,
nccl_comm_num=1):
nccl_comm_num=1,
hogwild_mode=False):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd
config.sync_mode = sync_mode
config.runtime_split_send_recv = hogwild_mode
if nccl_comm_num > 1:
config.nccl_comm_num = nccl_comm_num
# config.runtime_split_send_recv = True
@@ -89,6 +92,7 @@ class TestDistRunnerBase(object):
program=main_program,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode,
current_endpoint=current_endpoint)
return t
@@ -96,9 +100,15 @@ class TestDistRunnerBase(object):
self.lr = args.lr
self.get_model(batch_size=args.batch_size)
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode, args.dc_asgd)
t = self.get_transpiler(
trainer_id=args.trainer_id,
main_program=fluid.default_main_program(),
pserver_endpoints=args.endpoints,
trainers=args.trainers,
sync_mode=args.sync_mode,
dc_asgd=args.dc_asgd,
hogwild_mode=args.hogwild)
pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog)
@@ -120,7 +130,7 @@ class TestDistRunnerBase(object):
dist_strategy = DistributedStrategy()
dist_strategy.exec_strategy = exec_strategy
dist_strategy.fuse_memory_size = 1 #MB
dist_strategy.fuse_memory_size = 1 # MB
dist_strategy.fuse_laryer_size = 1
if args.use_local_sgd:
dist_strategy.use_local_sgd = True
@@ -130,11 +140,11 @@ class TestDistRunnerBase(object):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
print_to_err("gpu_fleet", "fleet.node_num:")
#"fleet.node_id:", fleet.node_id(),
#"fleet.trainer_num:", fleet.worker_num())
# "fleet.node_id:", fleet.node_id(),
# "fleet.trainer_num:", fleet.worker_num())
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy)
self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy)
trainer_prog = fleet._origin_program
dist_prog = fleet.main_program
@@ -196,10 +206,15 @@ class TestDistRunnerBase(object):
print_to_err(
type(self).__name__,
"begin to run transpile on trainer with pserver mode")
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(),
args.endpoints, args.trainers,
args.sync_mode, args.dc_asgd)
t = self.get_transpiler(
trainer_id=args.trainer_id,
main_program=fluid.default_main_program(),
pserver_endpoints=args.endpoints,
trainers=args.trainers,
sync_mode=args.sync_mode,
dc_asgd=args.dc_asgd,
hogwild_mode=args.hogwild)
trainer_prog = t.get_trainer_program()
print_to_err(
type(self).__name__,
@@ -251,6 +266,9 @@ class TestDistRunnerBase(object):
build_stra.enable_inplace = False
build_stra.memory_optimize = False
if args.hogwild:
build_stra.async_mode = True
if args.enable_backward_deps:
build_stra.enable_backward_optimizer_op_deps = True
@@ -411,6 +429,7 @@ def runtime_main(test_class):
parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument('--hogwild', action='store_true')
parser.add_argument(
'--use_reader_alloc', action='store_true', required=False)
parser.add_argument('--batch_size', required=False, type=int, default=2)
@@ -467,6 +486,7 @@ class TestDistBase(unittest.TestCase):
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self._sync_mode = True
self._hogwild_mode = False
self._enforce_place = None
self._use_reduce = False
self._dc_asgd = False # must use with async mode
@@ -630,6 +650,9 @@ class TestDistBase(unittest.TestCase):
if self._sync_mode:
tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode"
if self._hogwild_mode:
tr0_cmd += " --hogwild"
tr1_cmd += " --hogwild"
if self._use_reduce:
tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce"
@@ -703,8 +726,8 @@ class TestDistBase(unittest.TestCase):
tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f"
tr_cmd = tr_cmd % \
(self._python_interp, model, self._ps_endpoints,
trainer_id, ep, update_method, self._lr)
(self._python_interp, model, self._ps_endpoints,
trainer_id, ep, update_method, self._lr)
if self._use_reduce:
tr_cmd += " --use_reduce"
@@ -825,9 +848,9 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_v"] = "10"
required_envs["GLOG_logtostderr"] = "1"
local_losses\
local_losses \
= self._run_local(model_file, required_envs,
check_error_log)
check_error_log)
if self._nccl2_mode:
if self._nccl2_reduce_layer:
tr0_losses, tr1_losses = self._run_cluster_nccl2(

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
@@ -29,14 +30,13 @@ def skip_ci(func):
return __func__
@skip_ci
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
self.check_with_place("dist_ctr.py", delta=1e-2, check_error_log=False)
@skip_ci
@@ -54,5 +54,40 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
need_envs=need_envs)
class TestDistCTR2x2_ASYNC(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
}
self.check_with_place(
"dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs)
class TestDistCTR2x2_ASYNC2(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
"FLAGS_communicator_independent_recv_thread": "0"
}
self.check_with_place(
"dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs)
if __name__ == "__main__":
unittest.main()
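
The FLAGS_communicator_* entries in the need_envs dicts above throttle the communicator during the async tests. A sketch of how such knobs are typically declared on the C++ side with gflags; the default values and help strings here are illustrative assumptions, not the repo's exact definitions:

#include "gflags/gflags.h"

// Assumed defaults and help text, for illustration only.
DEFINE_int32(communicator_send_queue_size, 20,
             "capacity of the per-variable send queue");
DEFINE_int32(communicator_max_merge_var_num, 20,
             "pending gradients of one variable merged into a single send");
DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
             "sends allowed before parameters must be fetched again");
DEFINE_bool(communicator_independent_recv_thread, true,
            "run recv in its own thread instead of interleaving with send");

Setting these to small values in the tests (queue size 2, merge count 2, and so on) forces frequent merge/send/recv cycles so the async path is actually exercised; TestDistCTR2x2_ASYNC2 additionally sets FLAGS_communicator_independent_recv_thread to 0 to cover the non-independent recv path.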

@@ -30,7 +30,6 @@ def skip_ci(func):
return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False

@@ -33,7 +33,7 @@ class TestDistSimnetBowDense2x2(TestDistBase):
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=False,
check_error_log=True,
need_envs=need_envs)
