|
|
|
@ -28,7 +28,9 @@
|
|
|
|
|
#include <thread>
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <list>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
|
#include "backend/session/session_basic.h"
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
@ -52,6 +54,7 @@ namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
namespace ps {
|
|
|
|
|
using mindspore::kernel::ps::PServerKernel;
|
|
|
|
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ParameterServer {
|
|
|
|
|
public:
|
|
|
|
@ -126,6 +129,8 @@ class ParameterServer {
|
|
|
|
|
void ResetGradAccumCount();
|
|
|
|
|
const CNodePtr GetCNode(const std::string &name) const;
|
|
|
|
|
std::mutex &mutex();
|
|
|
|
|
void GetEmbeddingTableParamPtr();
|
|
|
|
|
void SyncEmbeddingTables();
|
|
|
|
|
|
|
|
|
|
size_t pserver_num_;
|
|
|
|
|
size_t worker_num_;
|
|
|
|
@ -154,6 +159,7 @@ class ParameterServer {
|
|
|
|
|
std::condition_variable apply_grads_cv_;
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<std::thread> thread_;
|
|
|
|
|
std::map<Key, ParameterPtr> embedding_tables_;
|
|
|
|
|
|
|
|
|
|
friend class ServerHandler;
|
|
|
|
|
};
|
|
|
|
@ -329,6 +335,7 @@ bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
|
|
|
|
|
InitOptimInfoBuilders();
|
|
|
|
|
ps_->set_request_handle(*handler_);
|
|
|
|
|
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
|
|
|
|
|
GetEmbeddingTableParamPtr();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -464,6 +471,7 @@ template <typename T>
|
|
|
|
|
void ParameterServer<T>::Finalize() {
|
|
|
|
|
running_ = false;
|
|
|
|
|
apply_grads_cv_.notify_one();
|
|
|
|
|
SyncEmbeddingTables();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -647,6 +655,53 @@ inline std::mutex &ParameterServer<T>::mutex() {
|
|
|
|
|
return mutex_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::GetEmbeddingTableParamPtr() {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph_);
|
|
|
|
|
auto cnodes = func_graph_->GetOrderedCnodes();
|
|
|
|
|
Key count = 0;
|
|
|
|
|
for (auto cnode : cnodes) {
|
|
|
|
|
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
|
|
|
|
if (cnode_name == kEmbeddingLookupOpName) {
|
|
|
|
|
auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(embedding_table);
|
|
|
|
|
MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
|
|
|
|
|
embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
|
|
|
|
|
count++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::SyncEmbeddingTables() {
|
|
|
|
|
for (auto embedding_table : embedding_tables_) {
|
|
|
|
|
Key key = embedding_table.first;
|
|
|
|
|
if (embedding_lookup_ops_.count(key) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
|
|
|
|
|
}
|
|
|
|
|
auto lookup = embedding_lookup_ops_[key];
|
|
|
|
|
const std::vector<size_t> &input_shapes = lookup->input_sizes();
|
|
|
|
|
std::vector<int> new_tensor_shape(input_shapes.begin(), input_shapes.end());
|
|
|
|
|
|
|
|
|
|
tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
|
|
|
|
|
float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
|
|
|
|
|
size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
|
|
|
|
|
size_t embedding_table_size = weights_[key]->size() * sizeof(float);
|
|
|
|
|
if (new_tensor_size != embedding_table_size) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size
|
|
|
|
|
<< ", embedding_table size:" << embedding_table_size;
|
|
|
|
|
}
|
|
|
|
|
int ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size);
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto paramter_tensor_ptr = embedding_table.second->default_param();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(paramter_tensor_ptr);
|
|
|
|
|
paramter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
::ps::Start(0);
|
|
|
|
@ -657,7 +712,6 @@ void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
Init(func_graph);
|
|
|
|
|
thread_->join();
|
|
|
|
|
::ps::Finalize(0, true);
|
|
|
|
|
exit(1);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ps
|
|
|
|
|
} // namespace parallel
|
|
|
|
|