update by comments

fix_gru_py
typhoonzero 7 years ago
parent 928418a9ac
commit 7b0c0273f4

@@ -42,7 +42,7 @@ class ParallelExecutor {
                    const std::vector<Scope*>& local_scopes,
                    bool allow_op_delay, bool use_default_grad_scale,
                    bool balance_parameter_opt_between_cards,
-                   size_t num_trainers = 0, size_t trainer_id = 0);
+                   size_t num_trainers = 1, size_t trainer_id = 0);

  ~ParallelExecutor();

@@ -75,29 +75,29 @@ class GenNCCLIdOp : public framework::OperatorBase {
     // NOTE: Can not use unique_ptr here because the default
     // deleter will call GRPC Server's base class's dtor and
     // that will cause a wired crash.
-    rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
+    detail::AsyncGRPCServer rpc_service(endpoint, true);
     framework::ProgramDesc empty_program;
     framework::Executor executor(dev_ctx.GetPlace());
-    rpc_service_->SetScope(scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    rpc_service_->SetProgram(&empty_program);
-    rpc_service_->SetExecutor(&executor);
+    rpc_service.SetScope(scope);
+    rpc_service.SetDevCtx(&dev_ctx);
+    rpc_service.SetProgram(&empty_program);
+    rpc_service.SetExecutor(&executor);
     std::thread server_thread(
-        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_));
-    rpc_service_->SetCond(0);
+        std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+    rpc_service.SetCond(0);
     VLOG(3) << "start getting nccl id from trainer 0...";
-    auto recv = rpc_service_->Get();
+    auto recv = rpc_service.Get();
     VLOG(3) << "got nccl id and stop server...";
-    rpc_service_->ShutDown();
+    rpc_service.ShutDown();
     VLOG(3) << "rpc server stopped";
     // TODO(wuyi): reinit nccl communicators
     server_thread.join();
-    delete rpc_service_;
   }

- protected:
-  mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
+  // protected:
+  // mutable detail::AsyncGRPCServer* rpc_service_ = nullptr;
};

 class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
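
The hunk above replaces the heap-allocated rpc_service_ member with a stack-local AsyncGRPCServer: the op binds a temporary server, runs it on a background thread, blocks in Get() until trainer 0 delivers the NCCL unique id, then shuts the server down and joins the thread, so the manual delete and the protected member go away. Below is a rough stand-in sketch of that receive-side handshake, using plain Python sockets and random bytes in place of the real gRPC server and ncclGetUniqueId(); every name in it is illustrative, not Paddle API, and which trainer plays which role is inferred only from the VLOG messages.

# Illustrative stand-in only: plain TCP instead of detail::AsyncGRPCServer and
# random bytes instead of a real ncclUniqueId. It mirrors the control flow in
# the hunk: bind a temporary server, serve it on a background thread, block
# until the id arrives, then shut down and join.
import socket
import threading
import uuid

ENDPOINT = ("127.0.0.1", 6173)     # hypothetical trainer endpoint

received = {}

# Receive side: what a trainer other than 0 presumably does while waiting.
listener = socket.socket()
listener.bind(ENDPOINT)
listener.listen(1)

def wait_for_nccl_id():
    conn, _ = listener.accept()    # block until trainer 0 connects
    received["nccl_id"] = conn.recv(128)
    conn.close()

server_thread = threading.Thread(target=wait_for_nccl_id)
server_thread.start()

# Send side: what trainer 0 presumably does for each peer after generating
# the unique id.
nccl_id = uuid.uuid4().bytes       # stand-in for ncclGetUniqueId()
sender = socket.socket()
sender.connect(ENDPOINT)
sender.sendall(nccl_id)
sender.close()

server_thread.join()               # mirrors server_thread.join() in the hunk
listener.close()
print("received a", len(received["nccl_id"]), "byte id")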

@@ -78,7 +78,7 @@ struct NCCLContextMap {
   explicit NCCLContextMap(const std::vector<platform::Place> &places,
                           ncclUniqueId *nccl_id = nullptr,
-                          size_t num_trainers = 0, size_t trainer_id = 0) {
+                          size_t num_trainers = 1, size_t trainer_id = 0) {
     PADDLE_ENFORCE(!places.empty());
     order_.reserve(places.size());
     for (auto &p : places) {
@@ -100,7 +100,7 @@ struct NCCLContextMap {
       PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
           comms.get(), static_cast<int>(order_.size()), order_.data()));
     } else {
-      PADDLE_ENFORCE_GT(num_trainers, 0);
+      PADDLE_ENFORCE_GT(num_trainers, 1);
       // TODO(wuyi): need to ensure each node have same number of GPUs
       {
         int nranks = num_trainers * order_.size();
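
The context line int nranks = num_trainers * order_.size(); defines a flat rank space of nodes times GPUs per node, which is also why the multi-node branch now enforces num_trainers > 1. Below is a small sketch of that arithmetic; the per-device formula trainer_id * gpus_per_node + local_gpu_index is an assumption inferred from the hunk, not quoted from it.

# Assumed rank layout for multi-node NCCL init: nranks spans every GPU on
# every trainer, and each device's global rank is derived from its node id
# and its local index (the formula here is an inference, not Paddle code).
def nccl_ranks(num_trainers, gpus_per_node):
    nranks = num_trainers * gpus_per_node   # mirrors num_trainers * order_.size()
    layout = {}
    for trainer_id in range(num_trainers):
        for gpu in range(gpus_per_node):
            layout[(trainer_id, gpu)] = trainer_id * gpus_per_node + gpu
    return nranks, layout

nranks, layout = nccl_ranks(num_trainers=2, gpus_per_node=4)
print(nranks)           # 8 ranks in total across both nodes
print(layout[(1, 2)])   # trainer 1, local GPU 2 -> global rank 6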

@@ -32,7 +32,7 @@ class ParallelExecutor(object):
                  share_vars_from=None,
                  use_default_grad_scale=True,
                  balance_parameter_opt_between_cards=False,
-                 num_trainers=0,
+                 num_trainers=1,
                  trainer_id=0):
         """
         ParallelExecutor can run program in parallel.
@@ -57,7 +57,7 @@ class ParallelExecutor(object):
             balance_parameter_opt_between_cards(bool, default True): Whether
                 updating different gradients on different cards. Currently, it
                 is not recommended.
-            num_trainers(int, default 0): If greater than 0, NCCL will be
+            num_trainers(int, default 1): If greater than 1, NCCL will be
                 initialized with multpile rank of nodes, each node should have
                 same number of GPUs. Distributed training will be enabled then.
             trainer_id(int, default 0): Must use together with num_trainers.
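
With the new default of 1, a single-node job can leave both arguments alone; a two-node job passes num_trainers=2 and a distinct trainer_id on each node. A minimal usage sketch follows: the use_cuda/loss_name arguments and the toy network are assumptions about the fluid API of that era, only num_trainers and trainer_id belong to this change, and a real multi-node run would still need the distributed transpiler to arrange the NCCL id exchange first.

import paddle.fluid as fluid

# Toy program just to have a loss variable to point at; not part of this diff.
x = fluid.layers.data(name='x', shape=[8], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(y)
fluid.Executor(fluid.CUDAPlace(0)).run(fluid.default_startup_program())

pe = fluid.ParallelExecutor(
    use_cuda=True,
    loss_name=loss.name,
    num_trainers=2,   # two nodes, each with the same number of GPUs
    trainer_id=0)     # 0 on the first node, 1 on the second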
