|
|
@ -257,6 +257,7 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
std::unique_lock<std::mutex> lock(ps_->mutex());
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
const Key &key = req_data.keys[0];
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Initializing embedding table for key:" << key;
|
|
|
|
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
|
|
|
|
std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
|
|
|
|
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
|
|
|
|
std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
|
|
|
|
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
|
|
|
|
std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
|
|
|
@ -348,6 +349,8 @@ void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int &optim_
|
|
|
|
}
|
|
|
|
}
|
|
|
|
weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
|
|
|
|
weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
|
|
|
|
weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
|
|
|
|
weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key]
|
|
|
|
|
|
|
|
<< ", optimizer op name:" << weight_key_to_optim_op_[key];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
@ -355,7 +358,7 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
|
|
|
|
InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
|
|
|
|
InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
|
|
|
|
int val_idx = 0;
|
|
|
|
int val_idx = 0;
|
|
|
|
const Key &key = keys[0];
|
|
|
|
const Key &key = keys[0];
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key;
|
|
|
|
if (optim_inputs_shape_.count(key) == 0) {
|
|
|
|
if (optim_inputs_shape_.count(key) == 0) {
|
|
|
|
optim_inputs_shape_[key] = inputs_shape;
|
|
|
|
optim_inputs_shape_[key] = inputs_shape;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -413,7 +416,7 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
|
|
|
|
void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
|
|
|
|
MS_LOG(INFO) << "Initializing weight for key " << key;
|
|
|
|
MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
|
|
|
|
if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
|
|
|
|
if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
|
|
|
|
weights_[key] = weight;
|
|
|
|
weights_[key] = weight;
|
|
|
|
tokens_[key] = 0;
|
|
|
|
tokens_[key] = 0;
|
|
|
@ -432,7 +435,6 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
|
|
|
const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
|
|
|
|
MS_LOG(INFO) << "Initializing embedding table for key " << key;
|
|
|
|
|
|
|
|
std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
|
|
|
|
std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
|
|
|
|
lookup->InitKernel(shapes);
|
|
|
|
lookup->InitKernel(shapes);
|
|
|
|
embedding_lookup_ops_[key] = lookup;
|
|
|
|
embedding_lookup_ops_[key] = lookup;
|
|
|
|