|
|
|
@ -15,11 +15,13 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "ps/worker.h"
|
|
|
|
|
#include "pipeline/jit/pipeline.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace ps {
|
|
|
|
|
void Worker::Run() {
|
|
|
|
|
std::lock_guard<std::mutex> lock(running_mutex_);
|
|
|
|
|
|
|
|
|
|
core::ClusterMetadata::instance()->Init(
|
|
|
|
|
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
|
|
|
|
|
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
|
|
|
|
@ -33,6 +35,14 @@ void Worker::Run() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Initialize();
|
|
|
|
|
worker_node_.set_event_callback([&](const core::NodeEvent &event) {
|
|
|
|
|
if ((event == core::NodeEvent::CLUSTER_TIMEOUT) ||
|
|
|
|
|
(event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) {
|
|
|
|
|
MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!";
|
|
|
|
|
Finalize();
|
|
|
|
|
exit(0);
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
|
|
|
|
|
worker_node_.Start();
|
|
|
|
|
MS_LOG(INFO) << "Worker connected successfully.";
|
|
|
|
@ -86,7 +96,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "The total size is:" << total_size;
|
|
|
|
|
|
|
|
|
|
while (!IsReadyForPush(keys[0])) {
|
|
|
|
|
while (running_ && (!IsReadyForPush(keys[0]))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> sizes_int;
|
|
|
|
@ -109,7 +119,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
|
|
|
|
|
void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(dev_addr);
|
|
|
|
|
std::vector<float> variables(size / sizeof(float), 0);
|
|
|
|
|
while (!IsReadyForPull(key)) {
|
|
|
|
|
while (running_ && (!IsReadyForPull(key))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
PullData({key}, &variables, nullptr, kPullCmd);
|
|
|
|
@ -214,7 +224,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &
|
|
|
|
|
|
|
|
|
|
std::string kv_data = embedding_table_meta.SerializeAsString();
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -280,7 +290,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
|
|
|
|
|
rank_ids.push_back(i);
|
|
|
|
|
std::string kv_data = messages.at(i).second.SerializeAsString();
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -303,7 +313,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
|
|
|
|
|
for (auto j = 0; j < message.values_size(); j++) {
|
|
|
|
|
values->push_back(message.values(j));
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "The embedding resp:" << values;
|
|
|
|
|
MS_LOG(DEBUG) << "The embedding resp:" << *values;
|
|
|
|
|
for (auto k = 0; k < message.keys_size(); k++) {
|
|
|
|
|
const Key &key = message.keys(k);
|
|
|
|
|
float *addr = values->data() + value_offset;
|
|
|
|
@ -358,7 +368,7 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto
|
|
|
|
|
rank_ids.push_back(i);
|
|
|
|
|
std::string kv_data = messages.at(i).second.SerializeAsString();
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -378,7 +388,7 @@ void Worker::Finalize() {
|
|
|
|
|
kvs.add_keys(0);
|
|
|
|
|
kvs.add_values(0.0f);
|
|
|
|
|
std::string kv_data = kvs.SerializeAsString();
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -619,7 +629,7 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va
|
|
|
|
|
SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
|
|
|
|
|
} else {
|
|
|
|
|
std::string kv_data = kvs.SerializeAsString();
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -920,7 +930,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
|
|
|
|
|
rank_ids.push_back(i);
|
|
|
|
|
std::string kv_data = messages.at(i).second.SerializeAsString();
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -945,7 +955,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
|
|
|
|
|
rank_ids.push_back(i);
|
|
|
|
|
std::string kv_data = messages.at(i).second.SerializeAsString();
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
|
|
|
|
|
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
|