|
|
|
@ -112,7 +112,7 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|
|
|
|
std::unique_ptr<::ps::Customer> general_customer_;
|
|
|
|
|
std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
|
|
|
|
|
std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_;
|
|
|
|
|
std::unordered_map<int, ::ps::KVPairs<T>> gathered_response_;
|
|
|
|
|
std::unordered_map<int, std::map<int, ::ps::KVPairs<T>>> gathered_response_;
|
|
|
|
|
std::mutex mutex_;
|
|
|
|
|
Slicer lookup_slicer_;
|
|
|
|
|
Slicer sparse_slicer_;
|
|
|
|
@ -337,12 +337,19 @@ int WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::S
|
|
|
|
|
int ts = general_customer_->NewRequest(::ps::kServerGroup);
|
|
|
|
|
const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
|
|
|
|
|
mutex_.lock();
|
|
|
|
|
auto &kvs = gathered_response_[ts];
|
|
|
|
|
std::map<int, ::ps::KVPairs<T>> server_kvs = gathered_response_[ts];
|
|
|
|
|
mutex_.unlock();
|
|
|
|
|
|
|
|
|
|
*vals = kvs.vals;
|
|
|
|
|
if (lens) {
|
|
|
|
|
*lens = kvs.lens;
|
|
|
|
|
vals->clear();
|
|
|
|
|
for (auto kvs : server_kvs) {
|
|
|
|
|
for (auto val : kvs.second.vals) {
|
|
|
|
|
vals->push_back(val);
|
|
|
|
|
}
|
|
|
|
|
if (lens) {
|
|
|
|
|
for (auto len : kvs.second.lens) {
|
|
|
|
|
lens->push_back(len);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mutex_.lock();
|
|
|
|
@ -464,43 +471,50 @@ void WorkerProxy<T>::SparseSlicer(int timestamp, const ::ps::KVPairs<T> &send, c
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
size_t indices_size = indice_ids.size();
|
|
|
|
|
int slice_segment_size = indices_size * segment_size;
|
|
|
|
|
T *src_grad_data = new T[slice_segment_size];
|
|
|
|
|
int *src_indice_data = new int[indices_size];
|
|
|
|
|
PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data,
|
|
|
|
|
src_indice_data);
|
|
|
|
|
|
|
|
|
|
// Reduce the sparse gradient and indice
|
|
|
|
|
T *new_grad = new T[slice_segment_size];
|
|
|
|
|
int *new_indices = new int[indices_size];
|
|
|
|
|
mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad, new_indices, indices_size});
|
|
|
|
|
Util::ReduceSparseGradient(src_grad_data, src_indice_data, indices_size, segment_size, first_dim_size,
|
|
|
|
|
outer_dim_size, &unique_sparse_grad);
|
|
|
|
|
|
|
|
|
|
// Update the length of reduce sparse gradient and indice
|
|
|
|
|
::ps::SArray<int> reduced_lens;
|
|
|
|
|
reduced_lens.CopyFrom(kvs.lens);
|
|
|
|
|
reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
|
|
|
|
|
reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
|
|
|
|
|
|
|
|
|
|
// Build the sparse value to be sent
|
|
|
|
|
size_t total_size = 0;
|
|
|
|
|
for (auto size : reduced_lens) {
|
|
|
|
|
total_size += size;
|
|
|
|
|
}
|
|
|
|
|
::ps::SArray<T> reduced_data(total_size, 0);
|
|
|
|
|
BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
|
|
|
|
|
unique_sparse_grad.indices_, &reduced_data);
|
|
|
|
|
if (indices_size > 0) {
|
|
|
|
|
int slice_segment_size = indices_size * segment_size;
|
|
|
|
|
T *src_grad_data = new T[slice_segment_size];
|
|
|
|
|
int *src_indice_data = new int[indices_size];
|
|
|
|
|
PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data,
|
|
|
|
|
src_indice_data);
|
|
|
|
|
|
|
|
|
|
// Reduce the sparse gradient and indice
|
|
|
|
|
T *new_grad = new T[slice_segment_size];
|
|
|
|
|
int *new_indices = new int[indices_size];
|
|
|
|
|
mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad, new_indices, indices_size});
|
|
|
|
|
Util::ReduceSparseGradient(src_grad_data, src_indice_data, indices_size, segment_size, first_dim_size,
|
|
|
|
|
outer_dim_size, &unique_sparse_grad);
|
|
|
|
|
|
|
|
|
|
// Update the length of reduce sparse gradient and indice
|
|
|
|
|
::ps::SArray<int> reduced_lens;
|
|
|
|
|
reduced_lens.CopyFrom(kvs.lens);
|
|
|
|
|
reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
|
|
|
|
|
reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
|
|
|
|
|
|
|
|
|
|
// Build the sparse value to be sent
|
|
|
|
|
size_t total_size = 0;
|
|
|
|
|
for (auto size : reduced_lens) {
|
|
|
|
|
total_size += size;
|
|
|
|
|
}
|
|
|
|
|
::ps::SArray<T> reduced_data(total_size, 0);
|
|
|
|
|
BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
|
|
|
|
|
unique_sparse_grad.indices_, &reduced_data);
|
|
|
|
|
|
|
|
|
|
kvs.lens = reduced_lens;
|
|
|
|
|
kvs.vals = reduced_data;
|
|
|
|
|
kvs.lens = reduced_lens;
|
|
|
|
|
kvs.vals = reduced_data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (indices_size <= 0) {
|
|
|
|
|
sliced->at(i).first = false;
|
|
|
|
|
} else {
|
|
|
|
|
sliced->at(i).first = true;
|
|
|
|
|
expected_result_count_[timestamp] += 1;
|
|
|
|
|
::ps::SArray<T> no_keys;
|
|
|
|
|
::ps::SArray<T> no_vals;
|
|
|
|
|
::ps::SArray<T> no_lens;
|
|
|
|
|
no_keys.push_back(key);
|
|
|
|
|
no_vals.push_back(-100);
|
|
|
|
|
kvs.vals = no_vals;
|
|
|
|
|
kvs.lens = no_lens;
|
|
|
|
|
}
|
|
|
|
|
sliced->at(i).first = true;
|
|
|
|
|
expected_result_count_[timestamp] += 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -554,8 +568,8 @@ void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const si
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Fill the reduced indice
|
|
|
|
|
int indice_offset = grad_offset + lengths[grad_index];
|
|
|
|
|
data_size = lengths[indice_index] * sizeof(T);
|
|
|
|
|
int indice_offset = grad_offset + data_size;
|
|
|
|
|
T *indice_data = reduced_data->data() + indice_offset;
|
|
|
|
|
T *convert = new T[lengths[indice_index]];
|
|
|
|
|
for (int i = 0; i < lengths[indice_index]; i++) {
|
|
|
|
@ -656,7 +670,7 @@ void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
|
|
|
|
|
lookup_results_[ts].push_back(kvs);
|
|
|
|
|
mutex_.unlock();
|
|
|
|
|
}
|
|
|
|
|
if (lookup_customer_->NumResponse(ts) == expected_result_count_[ts] - 1) {
|
|
|
|
|
if (lookup_customer_->NumResponse(ts) + 1 == server_num_) {
|
|
|
|
|
const auto &cb = lookup_callbacks_[ts];
|
|
|
|
|
cb();
|
|
|
|
|
lookup_callbacks_.erase(ts);
|
|
|
|
@ -676,15 +690,8 @@ void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
|
|
|
|
|
kvs.lens = msg.data[2];
|
|
|
|
|
}
|
|
|
|
|
mutex_.lock();
|
|
|
|
|
for (auto key : kvs.keys) {
|
|
|
|
|
gathered_response_[ts].keys.push_back(key);
|
|
|
|
|
}
|
|
|
|
|
for (auto val : kvs.vals) {
|
|
|
|
|
gathered_response_[ts].vals.push_back(val);
|
|
|
|
|
}
|
|
|
|
|
for (auto len : kvs.lens) {
|
|
|
|
|
gathered_response_[ts].lens.push_back(len);
|
|
|
|
|
}
|
|
|
|
|
int rsp_server_rank = ::ps::Postoffice::Get()->IDtoRank(msg.meta.sender);
|
|
|
|
|
gathered_response_[ts][rsp_server_rank] = kvs;
|
|
|
|
|
mutex_.unlock();
|
|
|
|
|
if (general_customer_->NumResponse(ts) + 1 == server_num_) {
|
|
|
|
|
const auto &cb = general_callbacks_[ts];
|
|
|
|
|