|
|
|
@ -89,7 +89,7 @@ void Worker<T>::Run() {
|
|
|
|
|
if (!::ps::IsWorker()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The role is not worker.";
|
|
|
|
|
}
|
|
|
|
|
kv_worker_ = std::make_shared<WorkerProxy<T>>(0, 0, 1);
|
|
|
|
|
kv_worker_ = std::make_shared<WorkerProxy<T>>(0, 0, 1, 2);
|
|
|
|
|
running_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -121,7 +121,7 @@ void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) {
|
|
|
|
|
while (!kv_worker_->IsReadyForPull(key)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
kv_worker_->Wait(kv_worker_->ZPull({key}, &variables));
|
|
|
|
|
kv_worker_->PullData({key}, &variables);
|
|
|
|
|
auto ret = memcpy_s(dev_addr, size, variables.data(), size);
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
@ -149,7 +149,7 @@ void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_ad
|
|
|
|
|
::ps::SArray<::ps::Key> key(keys);
|
|
|
|
|
::ps::SArray<int> lens;
|
|
|
|
|
lens.push_back(addr.size());
|
|
|
|
|
kv_worker_->Wait(kv_worker_->ZPush(key, addr, lens, kInitWeightsCmd));
|
|
|
|
|
kv_worker_->PushData(key, addr, lens, kInitWeightsCmd);
|
|
|
|
|
init_keys_[key[0]] = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -269,7 +269,6 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
// Initialize parameters and optimizer kernels of Parameter Server.
|
|
|
|
|
void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, tensor::TensorPtr tensor) {
|
|
|
|
|
void *param_data = tensor->data_c();
|
|
|
|
|
size_t param_size = LongToSize(tensor->data().nbytes());
|
|
|
|
@ -290,6 +289,7 @@ void Worker<T>::InitPSParamAndOptim(const std::string ¶m_name, tensor::Tenso
|
|
|
|
|
if (!init) {
|
|
|
|
|
MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name
|
|
|
|
|
<< ", whether init in server: " << init_in_server;
|
|
|
|
|
kv_worker_->AddKeyToServerId(param_key);
|
|
|
|
|
if (!init_in_server) {
|
|
|
|
|
InitPSParamData({param_key}, param_data, param_size);
|
|
|
|
|
}
|
|
|
|
|