|
|
|
@ -466,41 +466,34 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
|
|
|
|
|
const std::vector<std::string> &var_tables,
|
|
|
|
|
const framework::Scope &scope) {
|
|
|
|
|
waiting_ = false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
var_tables.size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
|
|
|
|
|
|
|
|
|
|
auto table_name = var_tables[0];
|
|
|
|
|
if (table_name == STEP_COUNTER) return;
|
|
|
|
|
|
|
|
|
|
auto before_send = GetCurrentUS();
|
|
|
|
|
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
|
|
|
|
|
size_t splited_var_nums =
|
|
|
|
|
send_varname_to_ctx_[table_name].splited_varnames.size();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_tables.size(); i++) {
|
|
|
|
|
auto table_name = var_tables[i];
|
|
|
|
|
if (table_name == STEP_COUNTER) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
size_t splited_var_nums =
|
|
|
|
|
send_varname_to_ctx_[table_name].splited_varnames.size();
|
|
|
|
|
|
|
|
|
|
for (size_t j = 0; j < splited_var_nums; j++) {
|
|
|
|
|
if (ids_table.find(
|
|
|
|
|
send_varname_to_ctx_[table_name].splited_varnames[j]) ==
|
|
|
|
|
ids_table.end()) {
|
|
|
|
|
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
|
|
|
|
|
send_varname_to_ctx_[table_name].splited_varnames[j],
|
|
|
|
|
std::unordered_set<int64_t>()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
|
|
|
|
|
|
|
|
|
|
auto *var = scope.FindVar(var_names[i]);
|
|
|
|
|
auto var_tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
int element_number = var_tensor.numel();
|
|
|
|
|
const int64_t *var_mutable_data = var_tensor.data<int64_t>();
|
|
|
|
|
for (size_t j = 0; j < splited_var_nums; j++) {
|
|
|
|
|
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
|
|
|
|
|
send_varname_to_ctx_[table_name].splited_varnames[j],
|
|
|
|
|
std::unordered_set<int64_t>()));
|
|
|
|
|
}
|
|
|
|
|
auto *var = scope.FindVar(var_names[0]);
|
|
|
|
|
auto &rows = var->Get<framework::SelectedRows>().rows();
|
|
|
|
|
|
|
|
|
|
// insert ids which has not been record
|
|
|
|
|
for (int j = 0; j < element_number; j++) {
|
|
|
|
|
auto ep_idx = var_mutable_data[j] % splited_var_nums;
|
|
|
|
|
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
|
|
|
|
|
.insert(var_mutable_data[j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// insert ids which has not been record
|
|
|
|
|
for (size_t j = 0; j < rows.size(); j++) {
|
|
|
|
|
auto ep_idx = rows[j] % splited_var_nums;
|
|
|
|
|
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
|
|
|
|
|
.insert(rows[j]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto before_push = GetCurrentUS();
|
|
|
|
|
for (auto &iter : ids_table) {
|
|
|
|
|
auto &key = iter.first;
|
|
|
|
@ -512,8 +505,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
|
|
|
|
|
<< "'s queue";
|
|
|
|
|
}
|
|
|
|
|
auto after_send = GetCurrentUS();
|
|
|
|
|
VLOG(3) << "run send_op finish. using " << (before_push - before_send) << "; "
|
|
|
|
|
<< (after_send - before_push);
|
|
|
|
|
VLOG(3) << "run send " << table_name << " op finish. using "
|
|
|
|
|
<< (before_push - before_send) << "; " << (after_send - before_push);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GeoCommunicator::MainThread() {
|
|
|
|
|