solve bug in heter mode (#31531)

* heter bug

* format

* format
Thunderbrook committed 4 years ago via GitHub, on branch fix_imperative_dygraph_error
commit 3789a69923, parent 6148b87f9d

@@ -168,6 +168,7 @@ class DeviceWorker {
   virtual void CacheProgram(const ProgramDesc& main_program) {}
   virtual void ProduceTasks() {}
   virtual void GetXpuOpIndex() {}
+  virtual void Schedule(int taskid) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   virtual void SetStream(const gpuStream_t stream) {}
   virtual void SetEvent(const gpuEvent_t event) {}
@@ -62,9 +62,8 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
 void DistMultiTrainer::RegisterHeterCallback() {
   auto fleet_ptr = FleetWrapper::GetInstance();
-  fleet_ptr->RegisterHeterCallback([this](int worker, int taskid) {
-    // workers_[worker]->Schedule(taskid);
-  });
+  fleet_ptr->RegisterHeterCallback(
+      [this](int worker, int taskid) { workers_[worker]->Schedule(taskid); });
 }

 void DistMultiTrainer::InitDumpEnv() {
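
Together with the new DeviceWorker::Schedule hook above, this hunk re-enables task scheduling in heter mode: the fleet callback now dispatches each (worker, taskid) pair to the matching worker instead of a commented-out stub. A minimal Python sketch of that dispatch pattern follows; the class and variable names are illustrative stand-ins, not Paddle APIs.

```python
# Illustrative sketch of the callback-dispatch pattern restored above;
# DeviceWorker/HeterCpuWorker here are stand-ins, not the C++ classes.
class DeviceWorker:
    def schedule(self, taskid):
        pass  # default no-op, mirroring the new virtual hook


class HeterCpuWorker(DeviceWorker):
    def schedule(self, taskid):
        print("resume task", taskid)


workers = [HeterCpuWorker(), HeterCpuWorker()]


def heter_callback(worker, taskid):
    # Forward the (worker, taskid) pair to the matching worker,
    # as the re-enabled lambda in RegisterHeterCallback now does.
    workers[worker].schedule(taskid)


heter_callback(1, 42)  # prints "resume task 42"
```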

@@ -193,7 +193,6 @@ void FleetWrapper::HeterPullSparseVars(
   for (auto& t : fea_values) {
     pull_result_ptr.push_back(t.data());
   }
-  /*
   auto status = pslib_ptr_->_worker_ptr->heter_pull_sparse(
       workerid, pull_result_ptr.data(), table_id, fea_keys.data(),
       fea_keys.size(), task->taskid_);
@@ -207,7 +206,6 @@ void FleetWrapper::HeterPullSparseVars(
       exit(-1);
     }
   }
-  */
 }
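
This change un-comments the real heter_pull_sparse call and its status check, so heter pulls stop silently doing nothing. The restored code follows a wait-and-check pattern over asynchronous pull statuses; below is a hedged Python sketch of that pattern, where pull_sparse_async is a hypothetical stand-in for the pslib RPC.

```python
# Sketch of the wait-and-check pattern the restored block follows.
# `pull_sparse_async` is a hypothetical stand-in for the pslib RPC.
import sys
from concurrent.futures import ThreadPoolExecutor


def pull_sparse_async(executor, keys):
    return executor.submit(lambda: 0)  # 0 means success, as in the status check above


with ThreadPoolExecutor() as pool:
    statuses = [pull_sparse_async(pool, keys) for keys in (["k1"], ["k2"])]
    for fut in statuses:
        status = fut.result()  # like t.wait(); t.get()
        if status != 0:
            print("fleet pull sparse failed, status", status, file=sys.stderr)
            sys.exit(-1)
```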
void FleetWrapper::HeterPushSparseVars(

@@ -1039,11 +1039,17 @@ class HeterRoleMaker(GeneralRoleMaker):
             self._node_type = 1
             self._cur_endpoint = worker_endpoints[current_id]
             gloo = fluid.core.Gloo()
-            gloo.init(current_id,
-                      len(worker_endpoints),
-                      self._hdfs_path.rstrip("/") + "/trainer",
-                      self._hdfs_name, self._hdfs_ugi, self._iface,
-                      self._prefix)
+            gloo.set_rank(current_id)
+            gloo.set_size(len(worker_endpoints))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/trainer", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._node_type_comm = gloo
         elif training_role == "XPU":
             role = Role.XPU
@@ -1051,10 +1057,17 @@ class HeterRoleMaker(GeneralRoleMaker):
             self._node_type = 2
             self._cur_endpoint = xpu_endpoints[current_id]
             gloo = fluid.core.Gloo()
-            gloo.init(current_id,
-                      len(xpu_endpoints),
-                      self._hdfs_path.rstrip("/") + "/xpu", self._hdfs_name,
-                      self._hdfs_ugi, self._iface, self._prefix)
+            gloo.set_rank(current_id)
+            gloo.set_size(len(xpu_endpoints))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/xpu", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._node_type_comm = gloo
         elif training_role == "PSERVER":
             role = Role.SERVER
@@ -1070,30 +1083,47 @@ class HeterRoleMaker(GeneralRoleMaker):
             self._node_type = 0
             self._cur_endpoint = cur_endpoint
             gloo = fluid.core.Gloo()
-            gloo.init(current_id,
-                      len(eplist),
-                      self._hdfs_path.rstrip("/") + "/pserver",
-                      self._hdfs_name, self._hdfs_ugi, self._iface,
-                      self._prefix)
+            gloo.set_rank(current_id)
+            gloo.set_size(len(eplist))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/pserver", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._node_type_comm = gloo

         if training_role == "TRAINER" or training_role == "XPU":
             gloo = fluid.core.Gloo()
             heter_list = worker_endpoints + xpu_endpoints
-            gloo.init(
-                heter_list.index(self._cur_endpoint),
-                len(heter_list),
-                self._hdfs_path.rstrip("/") + "/heter", self._hdfs_name,
-                self._hdfs_ugi, self._iface, self._prefix)
+            gloo.set_rank(heter_list.index(self._cur_endpoint))
+            gloo.set_size(len(heter_list))
+            gloo.set_prefix(self._prefix)
+            gloo.set_iface(self._iface)
+            gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                     self._run_timeout_seconds)
+            gloo.set_hdfs_store(
+                self._hdfs_path.rstrip("/") + "/heter", self._hdfs_name,
+                self._hdfs_ugi)
+            gloo.init()
             self._heter_comm = gloo

         gloo = fluid.core.Gloo()
         all_list = worker_endpoints + eplist + xpu_endpoints
-        gloo.init(
-            all_list.index(self._cur_endpoint),
-            len(all_list),
-            self._hdfs_path.rstrip("/") + "/all", self._hdfs_name,
-            self._hdfs_ugi, self._iface, self._prefix)
+        gloo.set_rank(all_list.index(self._cur_endpoint))
+        gloo.set_size(len(all_list))
+        gloo.set_prefix(self._prefix)
+        gloo.set_iface(self._iface)
+        gloo.set_timeout_seconds(self._init_timeout_seconds,
+                                 self._run_timeout_seconds)
+        gloo.set_hdfs_store(
+            self._hdfs_path.rstrip("/") + "/all", self._hdfs_name,
+            self._hdfs_ugi)
+        gloo.init()
         self._all_comm = gloo

         self._trainers_num = trainers_num
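
All of the role-maker hunks make the same migration: the old all-in-one gloo.init(rank, size, hdfs_path, hdfs_name, hdfs_ugi, iface, prefix) call is replaced by the new setter-based bring-up (set_rank, set_size, set_prefix, set_iface, set_timeout_seconds, set_hdfs_store) followed by a no-argument init(). Since the patch repeats that sequence for the trainer, xpu, pserver, heter, and all communicators, a helper in the style sketched below could express it once; _init_gloo_comm is hypothetical, not part of the patch.

```python
import paddle.fluid as fluid  # Paddle-1.x-era import used by role_maker


def _init_gloo_comm(self, rank, size, sub_dir):
    """Hypothetical helper capturing the setter-based Gloo bring-up this
    patch repeats; assumes the HeterRoleMaker attributes referenced below."""
    gloo = fluid.core.Gloo()
    gloo.set_rank(rank)
    gloo.set_size(size)
    gloo.set_prefix(self._prefix)
    gloo.set_iface(self._iface)
    gloo.set_timeout_seconds(self._init_timeout_seconds,
                             self._run_timeout_seconds)
    gloo.set_hdfs_store(self._hdfs_path.rstrip("/") + "/" + sub_dir,
                        self._hdfs_name, self._hdfs_ugi)
    gloo.init()
    return gloo

# e.g. the TRAINER branch would then reduce to:
#   self._node_type_comm = self._init_gloo_comm(
#       current_id, len(worker_endpoints), "trainer")
```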
