From 6b3e1a687bf369ae30311ff2a40b99393bc15460 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Fri, 10 Jul 2020 09:27:46 +0800 Subject: [PATCH] Add worker proxy. --- mindspore/ccsrc/parallel/ps/worker_proxy.h | 311 +++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 mindspore/ccsrc/parallel/ps/worker_proxy.h diff --git a/mindspore/ccsrc/parallel/ps/worker_proxy.h b/mindspore/ccsrc/parallel/ps/worker_proxy.h new file mode 100644 index 0000000000..8ffdde84ea --- /dev/null +++ b/mindspore/ccsrc/parallel/ps/worker_proxy.h @@ -0,0 +1,311 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ + +#include +#include +#include +#include +#include +#include "ps/ps.h" +#include "parallel/ps/util.h" + +namespace mindspore { +namespace parallel { +namespace ps { +template +class WorkerProxy : public ::ps::KVWorker { + public: + using Worker = ::ps::KVWorker; + using Callback = std::function; + using SlicedKVs = std::vector>>; + using Slicer = + std::function &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; + using ::ps::SimpleApp::obj_; + explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { + using _1 = std::placeholders::_1; + using _2 = std::placeholders::_2; + using _3 = std::placeholders::_3; + lookup_customer_ = std::unique_ptr<::ps::Customer>( + new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); + lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3); + init_embedding_slicer_ = std::bind(&WorkerProxy::EmbeddingTableInitSlicer, this, _1, _2, _3); + push_slicer_ = std::bind(&WorkerProxy::PushSlicer, this, _1, _2, _3); + broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3); + } + ~WorkerProxy() override = default; + + void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); + void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *outs, int cmd = 0, const Callback &cb = nullptr, + int priority = 0); + int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int priority = 0); + void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, + int cmd = 0, int priority = 0); + + private: + template + int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, + const Callback &cb); + void LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void ProcessLookupResult(const ::ps::Message &msg); + void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, + const Slicer &slicer); + + std::unique_ptr<::ps::Customer> lookup_customer_; + std::unordered_map<::ps::Key, std::shared_ptr>> embedding_table_ranges_; + std::unordered_map>> lookup_results_; + std::mutex mutex_; + Slicer lookup_slicer_; + Slicer init_embedding_slicer_; + Slicer push_slicer_; + Slicer broadcast_slicer_; + std::unordered_map lookup_callbacks_; +}; + +template +void WorkerProxy::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { + uint64_t begin = 0; + uint64_t end = 0; + int server_num = ::ps::NumServers(); + for (int i = 0; i < server_num; i++) { + int local_row_cnt = Util::LocalShard(row_count, i, server_num); + if (i == 0) { + end = local_row_cnt - 1; + } else { + begin = end + 1; + end += local_row_cnt; + } + ::ps::Range range(begin, end); + if (embedding_table_ranges_.count(key) == 0) { + embedding_table_ranges_[key] = std::make_shared>(); + } + embedding_table_ranges_[key]->push_back(range); + } +} + +template +void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *outs, int cmd, const Callback &cb, + int priority) { + int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = lookup_ids; + kvs.lens = lens; + kvs.priority = priority; + Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); + lookup_customer_->WaitRequest(ts); +} + +template +int WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, const Callback &cb, int priority) { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + kvs.priority = priority; + Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); + return ts; +} + +template +void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, int cmd, int priority) { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + kvs.priority = priority; + Send(obj_, ts, true, false, cmd, kvs, push_slicer_); + obj_->WaitRequest(ts); +} + +template +template +int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + C *lookup_result, int cmd, const Callback &cb) { + int ts = lookup_customer_->NewRequest(::ps::kServerGroup); + const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { + mutex_.lock(); + auto &kvs = lookup_results_[ts]; + mutex_.unlock(); + + size_t total_len = 0; + const auto &s = kvs[0]; + for (size_t i = 0; i < s.lens.size(); i++) { + total_len += s.lens[i]; + } + lookup_result->resize(total_len, 0); + T *result_addr = lookup_result->data(); + + for (const auto &s : kvs) { + size_t offset = 0; + for (size_t i = 0; i < s.vals.size(); i++) { + result_addr[offset++] += s.vals[i]; + } + } + + mutex_.lock(); + lookup_results_.erase(ts); + mutex_.unlock(); + if (cb) cb(); + }; + lookup_callbacks_[ts] = callback; + return ts; +} + +template +void WorkerProxy::LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + int *data = send.lens.data(); + size_t size = send.lens.size(); + std::vector lookup_ids(data, data + size); + std::sort(lookup_ids.begin(), lookup_ids.end()); + + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + + size_t index = 0; + for (size_t i = 0; i < ranges.size(); i++) { + const ::ps::Range &range = ranges[i]; + const auto &begin = range.begin(); + const auto &end = range.end(); + auto &kvs = sliced->at(i).second; + + auto lookup_id = static_cast(lookup_ids[index]); + while (lookup_id >= begin && lookup_id <= end) { + kvs.vals.push_back(lookup_id); + if (++index >= lookup_ids.size()) { + break; + } + lookup_id = static_cast(lookup_ids[index]); + } + kvs.keys.push_back(key); + kvs.lens.push_back(kvs.vals.size()); + + if (kvs.vals.size() == 0) { + sliced->at(i).first = false; + } else { + sliced->at(i).first = true; + } + } +} + +template +void WorkerProxy::EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + for (size_t i = 0; i < ranges.size(); i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + auto server_num = ::ps::Postoffice::Get()->num_servers(); + sliced->resize(server_num); + for (int i = 0; i < server_num; i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + auto server_num = ::ps::Postoffice::Get()->num_servers(); + sliced->resize(server_num); + for (int i = 0; i < server_num; i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { + int ts = msg.meta.timestamp; + if (msg.meta.pull) { + CHECK_GE(msg.data.size(), (size_t)2); + ::ps::KVPairs kvs; + kvs.keys = msg.data[0]; + kvs.vals = msg.data[1]; + if (msg.data.size() > (size_t)2) { + kvs.lens = msg.data[2]; + } + mutex_.lock(); + lookup_results_[ts].push_back(kvs); + mutex_.unlock(); + } + if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { + const auto &cb = lookup_callbacks_[ts]; + cb(); + lookup_callbacks_.erase(ts); + } +} + +template +void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, + const ::ps::KVPairs &kvs, const Slicer &slicer) { + SlicedKVs sliced; + slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); + + for (size_t i = 0; i < sliced.size(); i++) { + const auto &s = sliced[i]; + if (!s.first) continue; + ::ps::Message msg; + msg.meta.app_id = customer->app_id(); + msg.meta.customer_id = customer->customer_id(); + msg.meta.request = true; + msg.meta.push = push; + msg.meta.pull = pull; + msg.meta.head = cmd; + msg.meta.timestamp = timestamp; + msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); + msg.meta.priority = kvs.priority; + const auto &kvs = s.second; + if (kvs.keys.size()) { + msg.AddData(kvs.keys); + msg.AddData(kvs.vals); + if (kvs.lens.size()) { + msg.AddData(kvs.lens); + } + } + ::ps::Postoffice::Get()->van()->Send(msg); + } +} +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_