|
|
|
@ -22,6 +22,7 @@
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include "ps/ps.h"
|
|
|
|
|
#include "frontend/parallel/ps/util.h"
|
|
|
|
|
|
|
|
|
@ -34,24 +35,23 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|
|
|
|
using Worker = ::ps::KVWorker<T>;
|
|
|
|
|
using Callback = std::function<void()>;
|
|
|
|
|
using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
|
|
|
|
|
using Slicer =
|
|
|
|
|
std::function<void(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>;
|
|
|
|
|
using Slicer = std::function<void(int ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
|
|
|
|
|
SlicedKVs *sliced)>;
|
|
|
|
|
using ::ps::SimpleApp::obj_;
|
|
|
|
|
explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) {
|
|
|
|
|
using _1 = std::placeholders::_1;
|
|
|
|
|
using _2 = std::placeholders::_2;
|
|
|
|
|
using _3 = std::placeholders::_3;
|
|
|
|
|
using std::placeholders::_1;
|
|
|
|
|
using std::placeholders::_2;
|
|
|
|
|
using std::placeholders::_3;
|
|
|
|
|
using std::placeholders::_4;
|
|
|
|
|
lookup_customer_ = std::unique_ptr<::ps::Customer>(
|
|
|
|
|
new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
|
|
|
|
|
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3);
|
|
|
|
|
init_embedding_slicer_ = std::bind(&WorkerProxy<T>::EmbeddingTableInitSlicer, this, _1, _2, _3);
|
|
|
|
|
push_slicer_ = std::bind(&WorkerProxy<T>::PushSlicer, this, _1, _2, _3);
|
|
|
|
|
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3);
|
|
|
|
|
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4);
|
|
|
|
|
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4);
|
|
|
|
|
}
|
|
|
|
|
~WorkerProxy() override = default;
|
|
|
|
|
|
|
|
|
|
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
|
|
|
|
|
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
|
|
|
|
|
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
|
|
|
|
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd = 0, const Callback &cb = nullptr,
|
|
|
|
|
int priority = 0);
|
|
|
|
|
int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
|
|
|
|
@ -61,15 +61,11 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
template <typename C>
|
|
|
|
|
int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, C *vals, int cmd,
|
|
|
|
|
int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int cmd,
|
|
|
|
|
const Callback &cb);
|
|
|
|
|
void LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
void LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
|
|
|
|
|
void EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
|
|
|
|
|
void PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
|
|
|
|
|
void BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
void BroadcastSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
|
|
|
|
|
void ProcessLookupResult(const ::ps::Message &msg);
|
|
|
|
|
void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs<T> &kvs,
|
|
|
|
@ -80,10 +76,9 @@ class WorkerProxy : public ::ps::KVWorker<T> {
|
|
|
|
|
std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_;
|
|
|
|
|
std::mutex mutex_;
|
|
|
|
|
Slicer lookup_slicer_;
|
|
|
|
|
Slicer init_embedding_slicer_;
|
|
|
|
|
Slicer push_slicer_;
|
|
|
|
|
Slicer broadcast_slicer_;
|
|
|
|
|
std::unordered_map<int, Callback> lookup_callbacks_;
|
|
|
|
|
std::unordered_map<int, int> expected_result_count_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -108,17 +103,21 @@ void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_c
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
|
|
|
|
|
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
|
|
|
|
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd, const Callback &cb,
|
|
|
|
|
int priority) {
|
|
|
|
|
int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
|
|
|
|
|
::ps::KVPairs<T> kvs;
|
|
|
|
|
kvs.keys = keys;
|
|
|
|
|
kvs.vals = lookup_ids;
|
|
|
|
|
kvs.lens = lens;
|
|
|
|
|
kvs.lens = lookup_ids;
|
|
|
|
|
kvs.priority = priority;
|
|
|
|
|
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_);
|
|
|
|
|
expected_result_count_[ts] = 0;
|
|
|
|
|
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
|
|
|
|
|
int server_num = ::ps::NumServers();
|
|
|
|
|
int expect_rt_count = expected_result_count_[ts];
|
|
|
|
|
lookup_customer_->AddResponse(ts, server_num - expect_rt_count);
|
|
|
|
|
lookup_customer_->WaitRequest(ts);
|
|
|
|
|
expected_result_count_.erase(ts);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -130,7 +129,7 @@ int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, cons
|
|
|
|
|
kvs.vals = vals;
|
|
|
|
|
kvs.lens = lens;
|
|
|
|
|
kvs.priority = priority;
|
|
|
|
|
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_);
|
|
|
|
|
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
|
|
|
|
|
return ts;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -143,13 +142,13 @@ void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::S
|
|
|
|
|
kvs.vals = vals;
|
|
|
|
|
kvs.lens = lens;
|
|
|
|
|
kvs.priority = priority;
|
|
|
|
|
Send(obj_, ts, true, false, cmd, kvs, push_slicer_);
|
|
|
|
|
Send(obj_, ts, true, false, cmd, kvs, broadcast_slicer_);
|
|
|
|
|
obj_->WaitRequest(ts);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename C>
|
|
|
|
|
int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
|
|
|
|
|
int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
|
|
|
|
|
C *lookup_result, int cmd, const Callback &cb) {
|
|
|
|
|
int ts = lookup_customer_->NewRequest(::ps::kServerGroup);
|
|
|
|
|
const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
|
|
|
|
@ -158,18 +157,28 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
|
|
|
|
|
mutex_.unlock();
|
|
|
|
|
|
|
|
|
|
size_t total_len = 0;
|
|
|
|
|
const auto &s = kvs[0];
|
|
|
|
|
for (size_t i = 0; i < s.lens.size(); i++) {
|
|
|
|
|
total_len += s.lens[i];
|
|
|
|
|
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int>>> id_addr_map;
|
|
|
|
|
for (const auto &s : kvs) {
|
|
|
|
|
int offset = 0;
|
|
|
|
|
int len = s.vals.size() / s.keys.size();
|
|
|
|
|
for (size_t i = 0; i < s.keys.size(); i++) {
|
|
|
|
|
const Key &key = s.keys[i];
|
|
|
|
|
T *addr = s.vals.data() + offset;
|
|
|
|
|
offset += len;
|
|
|
|
|
total_len += len;
|
|
|
|
|
id_addr_map[key] = std::make_shared<std::pair<T *, int>>(std::make_pair(addr, len));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
lookup_result->resize(total_len, 0);
|
|
|
|
|
T *result_addr = lookup_result->data();
|
|
|
|
|
|
|
|
|
|
for (const auto &s : kvs) {
|
|
|
|
|
size_t offset = 0;
|
|
|
|
|
for (size_t i = 0; i < s.vals.size(); i++) {
|
|
|
|
|
result_addr[offset++] += s.vals[i];
|
|
|
|
|
T *result_addr = lookup_result->data();
|
|
|
|
|
int offset = 0;
|
|
|
|
|
for (size_t i = 0; i < lookup_ids.size(); i++) {
|
|
|
|
|
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
|
|
|
|
|
auto ret = memcpy_s(result_addr + offset, pair->second, pair->first, pair->second);
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
|
}
|
|
|
|
|
offset += pair->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mutex_.lock();
|
|
|
|
@ -182,31 +191,30 @@ int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
|
|
|
|
|
int *data = send.lens.data();
|
|
|
|
|
size_t size = send.lens.size();
|
|
|
|
|
std::vector<int> lookup_ids(data, data + size);
|
|
|
|
|
std::sort(lookup_ids.begin(), lookup_ids.end());
|
|
|
|
|
int *lookup_ids = send.lens.data();
|
|
|
|
|
size_t id_size = send.lens.size();
|
|
|
|
|
|
|
|
|
|
const Key &key = send.keys[0];
|
|
|
|
|
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
|
|
|
|
|
sliced->resize(ranges.size());
|
|
|
|
|
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
for (size_t i = 0; i < ranges.size(); i++) {
|
|
|
|
|
const ::ps::Range &range = ranges[i];
|
|
|
|
|
const auto &begin = range.begin();
|
|
|
|
|
const auto &end = range.end();
|
|
|
|
|
std::unordered_set<int> unique_ids;
|
|
|
|
|
auto &kvs = sliced->at(i).second;
|
|
|
|
|
|
|
|
|
|
auto lookup_id = static_cast<uint64_t>(lookup_ids[index]);
|
|
|
|
|
while (lookup_id >= begin && lookup_id <= end) {
|
|
|
|
|
kvs.vals.push_back(lookup_id);
|
|
|
|
|
if (++index >= lookup_ids.size()) {
|
|
|
|
|
break;
|
|
|
|
|
for (size_t j = 0; j < id_size; j++) {
|
|
|
|
|
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
|
|
|
|
|
if (lookup_id >= begin && lookup_id <= end) {
|
|
|
|
|
unique_ids.insert(lookup_id);
|
|
|
|
|
}
|
|
|
|
|
lookup_id = static_cast<uint64_t>(lookup_ids[index]);
|
|
|
|
|
}
|
|
|
|
|
for (const auto &lookup_id : unique_ids) {
|
|
|
|
|
kvs.vals.push_back(lookup_id);
|
|
|
|
|
}
|
|
|
|
|
kvs.keys.push_back(key);
|
|
|
|
|
kvs.lens.push_back(kvs.vals.size());
|
|
|
|
@ -215,35 +223,13 @@ void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vec
|
|
|
|
|
sliced->at(i).first = false;
|
|
|
|
|
} else {
|
|
|
|
|
sliced->at(i).first = true;
|
|
|
|
|
expected_result_count_[timestamp] += 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void WorkerProxy<T>::EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
|
|
|
|
|
const Key &key = send.keys[0];
|
|
|
|
|
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
|
|
|
|
|
sliced->resize(ranges.size());
|
|
|
|
|
for (size_t i = 0; i < ranges.size(); i++) {
|
|
|
|
|
sliced->at(i).first = true;
|
|
|
|
|
sliced->at(i).second = send;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void WorkerProxy<T>::PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
|
|
|
|
|
auto server_num = ::ps::Postoffice::Get()->num_servers();
|
|
|
|
|
sliced->resize(server_num);
|
|
|
|
|
for (int i = 0; i < server_num; i++) {
|
|
|
|
|
sliced->at(i).first = true;
|
|
|
|
|
sliced->at(i).second = send;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void WorkerProxy<T>::BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
void WorkerProxy<T>::BroadcastSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
|
|
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
|
|
|
|
|
auto server_num = ::ps::Postoffice::Get()->num_servers();
|
|
|
|
|
sliced->resize(server_num);
|
|
|
|
@ -268,7 +254,7 @@ void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
|
|
|
|
|
lookup_results_[ts].push_back(kvs);
|
|
|
|
|
mutex_.unlock();
|
|
|
|
|
}
|
|
|
|
|
if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) {
|
|
|
|
|
if (lookup_customer_->NumResponse(ts) == expected_result_count_[ts] - 1) {
|
|
|
|
|
const auto &cb = lookup_callbacks_[ts];
|
|
|
|
|
cb();
|
|
|
|
|
lookup_callbacks_.erase(ts);
|
|
|
|
@ -279,7 +265,7 @@ template <typename T>
|
|
|
|
|
void WorkerProxy<T>::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd,
|
|
|
|
|
const ::ps::KVPairs<T> &kvs, const Slicer &slicer) {
|
|
|
|
|
SlicedKVs sliced;
|
|
|
|
|
slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced);
|
|
|
|
|
slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < sliced.size(); i++) {
|
|
|
|
|
const auto &s = sliced[i];
|
|
|
|
|