ps cache support sparse

pull/11735/head
lizhenyu 4 years ago
parent 424e68a803
commit f17534af08

@ -53,7 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimReduceAny->name(), {1});
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
Register(kSparseGatherV2, {2});
Register(kSparseGatherV2OpName, {2});
Register(kUnsortedSegmentProdOpName, {2});
Register(kSimpleMeanGradOpName, {1});
Register(kMeanGradOpName, {1});
@ -109,7 +109,7 @@ bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_nam
ConstInputToAttrInfoRegister *reg) const {
if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) {
*reg = op_input_to_attr_map_.at(op_name);
MS_LOG(DEBUG) << op_name << " const2attr find in registery.";
MS_LOG(DEBUG) << op_name << " const2attr find in registry.";
return true;
}
return false;

@ -31,15 +31,22 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
// almost all ops are defined in two main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
const std::string functional_op_module = FUNCTIONAL_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
}
py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
return inner_ops_module;
}
if (py::hasattr(mod, common::SafeCStr(op_name))) {
return ops_module;
}
return inner_ops_module;
if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << functional_op_module
<< " don't have op:" << op_name;
}
return functional_op_module;
}
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
@ -141,7 +148,7 @@ Status GenerateGraph::Init(const CNodePtr &cnode) {
}
AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode
CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to create anfnode
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_scope(scope_);
if (inputs.size() < 2) {

@ -24,8 +24,10 @@
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#include "utils/ms_context.h"
#endif
namespace mindspore {
@ -158,6 +160,15 @@ Status GatherV2PInfo::GetAttrs() {
if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
dynamic_shape_indices_ = true;
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
dynamic_shape_indices_ = true;
}
#endif
return SUCCESS;
}
@ -531,7 +542,7 @@ Status GatherV2PInfo::InferBias() {
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
bias_ = 0;
bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
return SUCCESS;
}
#endif
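
With the PS cache enabled, the shard bias above is no longer hard-coded to 0; it becomes this worker's lower bound in the global cache-index space. Judging from the SetLocalIdRank change later in this diff, that bound works out to rank_id * vocab_cache_size. A minimal standalone sketch with hypothetical values (not the project's API):

#include <cstdint>
#include <iostream>

// Standalone sketch: the cache-index lower bound a worker would report,
// assuming each worker owns a contiguous block of vocab_cache_size rows.
int64_t CacheIndicesLowerBound(int rank_id, int vocab_cache_size) {
  return static_cast<int64_t>(rank_id) * vocab_cache_size;
}

int main() {
  // Hypothetical example: rank 1 with 5000 device-cache rows -> bias_ = 5000.
  std::cout << "bias_ = " << CacheIndicesLowerBound(1, 5000) << std::endl;
  return 0;
}
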

@ -68,6 +68,7 @@ constexpr char REDUCE_OP_MAX[] = "max";
constexpr char REDUCE_OP_MIN[] = "min";
constexpr char OP_PATH[] = "mindspore.ops.operations";
constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional";
constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
constexpr char GET_OP_FUNCTION[] = "_get_python_op";
constexpr char KEEP_DIMS[] = "keep_dims";

@ -23,9 +23,13 @@
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
namespace mindspore {
namespace parallel {
@ -186,5 +190,63 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) {
}
return SUCCESS;
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph();
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
}
auto bias = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
auto slice_size = SizeToLong(ps::PsCacheManager::GetInstance().vocab_cache_size());
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias)});
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size - 1)});
auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
auto unique = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node()});
auto tuple_getitem_0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(0)});
auto tuple_getitem_1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), unique, CreatInt64Imm(1)});
auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), tuple_getitem_1});
auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), tuple_getitem_1, cast});
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
OperatorAttrs attrs = {attr_op};
AnfNodePtr reduce_op;
reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
auto make_tuple = gen_g.PushBack({gen_g.NewOpInst(MAKE_TUPLE), tuple_getitem_0, reduce_op});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 1), std::make_pair(unique, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, make_tuple));
return SUCCESS;
}
#endif
ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
auto inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs";
}
const auto &primitive = GetValueNode<PrimitivePtr>(inputs[0]);
const auto &attr = primitive->GetAttr("cache_enable");
if (attr == nullptr) {
return nullptr;
}
auto need_mask = GetValue<bool>(attr);
if (!need_mask) {
return nullptr;
}
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
}
return replace_graph_;
}
#endif
return nullptr;
}
} // namespace parallel
} // namespace mindspore
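
What the ComputeReplaceGraph above builds, elementwise: an id contributes on this worker only if (id - bias) survives the ReLU/Minimum clipping unchanged, i.e. the id lies inside [bias, bias + slice_size). Below is a standalone sketch of that mask applied to the second Unique output; the helper name, the example values, and the assumption that this output is the per-element reverse index are illustrative only, and the AllReduce that follows in the real graph simply sums the masked result across workers.

#include <algorithm>
#include <cstdint>
#include <vector>

// Not MindSpore code: mirrors the Sub -> ReLU -> Minimum -> Equal -> Cast -> Mul chain.
std::vector<int64_t> MaskReverseIndex(const std::vector<int32_t> &ids, const std::vector<int64_t> &reverse_index,
                                      int32_t bias, int32_t slice_size) {
  std::vector<int64_t> masked(ids.size(), 0);
  for (size_t i = 0; i < ids.size(); ++i) {
    int32_t sub = ids[i] - bias;                       // Sub
    int32_t relu = std::max(sub, 0);                   // ReLU
    int32_t clipped = std::min(relu, slice_size - 1);  // Minimum
    bool in_range = (sub == clipped);                  // Equal, then Cast to the value dtype
    masked[i] = in_range ? reverse_index[i] : 0;       // Mul
  }
  return masked;
}

int main() {
  // Hypothetical: bias 5000, slice 5000; id 7200 is in range, 42 and 12000 are not.
  auto masked = MaskReverseIndex({42, 7200, 12000}, {0, 1, 2}, 5000, 5000);
  return masked.size() == 3 ? 0 : 1;
}
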

@ -39,6 +39,7 @@ class UniqueInfo : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int64_t stage_id) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
@ -50,8 +51,12 @@ class UniqueInfo : public OperatorInfo {
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferAsLossDivisor() override { return SUCCESS; }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
Status ComputeReplaceGraph(const CNodePtr &cnode);
#endif
private:
std::string replace_op_name_ = UNIQUE;
int64_t dev_num_ = 1;
};
} // namespace parallel

@ -321,7 +321,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.")
.def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
.def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.");
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
.def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.");
(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())

@ -773,12 +773,14 @@ void ParameterServer<T>::GetEmbeddingTableParamPtr() {
for (auto cnode : cnodes) {
MS_EXCEPTION_IF_NULL(cnode);
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName) {
if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
MS_EXCEPTION_IF_NULL(embedding_table);
MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
count++;
if (embedding_table->isa<Parameter>()) {
MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
count++;
}
}
}
}

@ -35,11 +35,11 @@ void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t c
if (vocab_size_ == 0) {
vocab_size_ = vocab_size;
}
if (cache_vocab_size_ == 0) {
cache_vocab_size_ = cache_vocab_size;
if (vocab_cache_size_ == 0) {
vocab_cache_size_ = cache_vocab_size;
}
if (host_cache_vocab_size_ == 0) {
host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor;
if (host_vocab_cache_size_ == 0) {
host_vocab_cache_size_ = cache_vocab_size * kHostCacheScaleFactor;
}
}
@ -148,8 +148,8 @@ void PsCacheManager::Initialize() {
Util::SetInternalEnvVar();
worker.Run();
}
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
AddEmbeddingTable();
AllocMemForHashTable();
SetLocalIdRank();
@ -220,13 +220,13 @@ void PsCacheManager::AllocMemForHashTable() {
for (auto &item : hash_tables_) {
size_t embedding_size = item.second.embedding_size;
auto &device_address = item.second.device_address;
device_address.size = cache_vocab_size_ * embedding_size * sizeof(float);
device_address.size = vocab_cache_size_ * embedding_size * sizeof(float);
auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
MS_EXCEPTION_IF_NULL(addr);
device_address.addr = addr;
auto &host_address = item.second.host_address;
auto host_address_ptr = new float[host_cache_vocab_size_ * embedding_size];
auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size];
MS_EXCEPTION_IF_NULL(host_address_ptr);
host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>());
MS_EXCEPTION_IF_NULL(host_address);
@ -239,21 +239,28 @@ void PsCacheManager::AllocMemForHashTable() {
embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) {
if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
MS_LOG(EXCEPTION) << "MallocConstantMemory failed.";
}
}
void PsCacheManager::SetLocalIdRank() {
auto worker_num = ::ps::NumWorkers();
auto worker_id = ::ps::MyRank();
auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num));
range_bound_.first = local_shard_size * worker_id;
range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_);
MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first
<< ", rank id end:" << range_bound_.second;
auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num));
vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_);
emb_table_slice_bounds_.first = local_shard_size * rank_id_;
emb_table_slice_bounds_.second = std::min(emb_table_slice_bounds_.first + local_shard_size, SizeToInt(vocab_size_));
cache_indices_bounds_.first = SizeToInt(vocab_cache_size_) * rank_id_;
cache_indices_bounds_.second = cache_indices_bounds_.first + SizeToInt(vocab_cache_size_);
MS_LOG(INFO) << "Worker num:" << worker_num << ", rank id:" << rank_id_
<< ", id begin:" << emb_table_slice_bounds_.first << ", id end:" << emb_table_slice_bounds_.second
<< ", cache indices begin: " << cache_indices_bounds_.first
<< ", cache indices end: " << cache_indices_bounds_.second
<< ", vocab_cache_size_diff: " << vocab_cache_size_diff_;
}
int PsCacheManager::cache_indices_lower_bound() const { return cache_indices_bounds_.first; }
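
A standalone sketch of the partitioning arithmetic SetLocalIdRank now performs, with hypothetical numbers (vocab_size, worker_num, vocab_cache_size and rank_id below are example values only):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>

int main() {
  // Hypothetical example: 80000-row table, 4 workers, 5000 device-cache rows each, rank 1.
  int vocab_size = 80000;
  int worker_num = 4;
  int vocab_cache_size = 5000;
  int rank_id = 1;

  // Each worker owns a contiguous slice of the embedding table and a contiguous
  // block of the global cache-index space, mirroring SetLocalIdRank above.
  int local_shard_size = static_cast<int>(std::ceil(static_cast<float>(vocab_size) / worker_num));
  int vocab_cache_size_diff = local_shard_size - vocab_cache_size;
  std::pair<int, int> emb_table_slice_bounds{local_shard_size * rank_id,
                                             std::min(local_shard_size * (rank_id + 1), vocab_size)};
  std::pair<int, int> cache_indices_bounds{vocab_cache_size * rank_id, vocab_cache_size * (rank_id + 1)};

  // Prints: table slice [20000, 40000), cache indices [5000, 10000), diff 15000.
  std::cout << "table slice [" << emb_table_slice_bounds.first << ", " << emb_table_slice_bounds.second << ")\n"
            << "cache indices [" << cache_indices_bounds.first << ", " << cache_indices_bounds.second << ")\n"
            << "vocab_cache_size_diff " << vocab_cache_size_diff << std::endl;
  return 0;
}
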
std::string PsCacheManager::channel_name() {
std::lock_guard<std::mutex> locker(channel_mutex_);
return channel_name_;
@ -398,8 +405,8 @@ bool PsCacheManager::ProcessData() {
return true;
}
bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
bool *in_device, size_t *hash_hit_count) {
bool PsCacheManager::CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
bool *in_device, bool *out_range, size_t *hash_hit_count) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
MS_ERROR_IF_NULL(in_device);
@ -410,9 +417,19 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc
const auto &hash_id_to_index = device_hash_map->hash_id_to_index();
for (size_t i = 0; i < batch_ids_len; ++i) {
if (batch_ids[i] < emb_table_slice_bounds_.first) {
hash_index[i] = batch_ids[i] - vocab_cache_size_diff_;
out_range[i] = true;
continue;
}
if (batch_ids[i] >= emb_table_slice_bounds_.second) {
hash_index[i] = batch_ids[i] + cache_indices_bounds_.second;
out_range[i] = true;
continue;
}
auto iter = hash_id_to_index.find(batch_ids[i]);
if (iter != hash_id_to_index.end()) {
hash_index[i] = iter->second;
hash_index[i] = iter->second + cache_indices_bounds_.first;
if (device_hash_map->hash_step(iter->second) != data_step_) {
++(*hash_hit_count);
device_hash_map->set_hash_step(iter->second, data_step_);
@ -423,11 +440,12 @@ bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batc
return true;
}
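
A standalone sketch of the per-id remapping that CheckCacheHitOrOutRangeTask now applies. The device hash map is replaced by a stand-in lambda, the bounds and ids reuse the hypothetical rank-1 numbers from the sketch above, and ids that are in range but miss the cache are handled later by ParseData:

#include <functional>
#include <iostream>
#include <utility>

// Not MindSpore code: ids outside this worker's table slice are flagged out_range
// and mapped outside its cache block; cache hits are shifted by the cache lower bound.
void RemapId(int id, std::pair<int, int> emb_table_slice_bounds, std::pair<int, int> cache_indices_bounds,
             int vocab_cache_size_diff, const std::function<int(int)> &lookup_cache_slot) {
  int hash_index = 0;
  bool out_range = false;
  if (id < emb_table_slice_bounds.first) {
    hash_index = id - vocab_cache_size_diff;
    out_range = true;
  } else if (id >= emb_table_slice_bounds.second) {
    hash_index = id + cache_indices_bounds.second;
    out_range = true;
  } else {
    hash_index = lookup_cache_slot(id) + cache_indices_bounds.first;
  }
  std::cout << "id " << id << " -> hash_index " << hash_index << (out_range ? " (out of range)" : "") << std::endl;
}

int main() {
  std::pair<int, int> table_slice{20000, 40000};
  std::pair<int, int> cache_bounds{5000, 10000};
  int cache_size_diff = 15000;
  auto fake_lookup = [](int id) { return id % 5000; };  // stand-in for the device hash map
  const int ids[] = {100, 20042, 79999};
  for (int id : ids) {
    RemapId(id, table_slice, cache_bounds, cache_size_diff, fake_lookup);
  }
  return 0;
}
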
bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
bool *in_device) {
bool PsCacheManager::CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index,
bool *in_device, bool *out_range) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
MS_ERROR_IF_NULL(in_device);
MS_ERROR_IF_NULL(out_range);
size_t thread_num = batch_ids_len / kMinIdsPerThread + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
@ -441,8 +459,9 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id
break;
}
size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0);
threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens,
hash_index + task_offset, in_device + task_offset, hash_hit_count + i);
threads[i] =
std::thread(&PsCacheManager::CheckCacheHitOrOutRangeTask, this, batch_ids + task_offset, task_proc_lens,
hash_index + task_offset, in_device + task_offset, out_range + task_offset, hash_hit_count + i);
task_offset += task_proc_lens;
}
if (task_offset != batch_ids_len) {
@ -477,27 +496,26 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
MS_ERROR_IF_NULL(hash_index);
statistics_info_.batch_id_count_ = batch_ids_len;
std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]);
std::unique_ptr<bool[]> out_range(new bool[batch_ids_len]);
if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
MS_LOG(EXCEPTION) << "Data in device memset failed.";
MS_LOG(EXCEPTION) << "Initialize in_device array failed.";
}
if (memset_s(out_range.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) {
MS_LOG(EXCEPTION) << "Initialize out_range array failed.";
}
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
RETURN_IF_FALSE(CheckCacheHitOrOutRange(batch_ids, batch_ids_len, hash_index, in_device.get(), out_range.get()));
RETURN_IF_FALSE(ResetEmbeddingHashMap());
for (size_t i = 0; i < batch_ids_len; i++) {
if (in_device[i]) {
if (in_device[i] || out_range[i]) {
continue;
}
bool need_swap_host_to_device = true;
bool need_swap_device_to_host = true;
auto id = batch_ids[i];
if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) {
hash_index[i] = -1;
continue;
}
int index = INVALID_INDEX_VALUE;
RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index));
hash_index[i] = index;
RETURN_IF_FALSE(ParseDeviceData(batch_ids[i], &need_swap_device_to_host, &need_swap_host_to_device, &index));
hash_index[i] = index + cache_indices_bounds_.first;
if (need_swap_host_to_device) {
RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
RETURN_IF_FALSE(ParseHostDataHostToDevice(batch_ids[i]));
}
if (need_swap_device_to_host) {
RETURN_IF_FALSE(ParseHostDataDeviceToHost());
@ -667,7 +685,7 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr) {
size_t first_dim_size = host_cache_vocab_size_;
size_t first_dim_size = host_vocab_cache_size_;
size_t outer_dim_size = embedding_size;
size_t thread_num = indices_lens / 10000 + 1;
@ -697,7 +715,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
float *insert_data, float *hash_table_addr) {
size_t first_dim_size = host_cache_vocab_size_;
size_t first_dim_size = host_vocab_cache_size_;
size_t thread_num = insert_indices_size / 10000 + 1;
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num;
std::thread threads[kMaxThreadNum];

@ -125,7 +125,10 @@ class PsCacheManager {
const size_t &QueryHashTableSize(const std::string &param_name) const;
bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
void set_rank_id(int rank_id) { rank_id_ = rank_id; }
bool initialized_ps_cache() const { return initialized_ps_cache_; }
size_t vocab_cache_size() const { return vocab_cache_size_; }
int cache_indices_lower_bound() const;
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
void SyncEmbeddingTable();
@ -170,10 +173,12 @@ class PsCacheManager {
void DumpStatisticsInfo(size_t each_print_step = 1000);
bool SyncHostEmbeddingTable();
bool SyncDeviceEmbeddingTable();
bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
size_t *hash_hit_count);
bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
bool CheckCacheHitOrOutRangeTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
bool *out_range, size_t *hash_hit_count);
bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
bool *out_range);
bool ResetEmbeddingHashMap();
bool initialized_ps_cache_{false};
std::string channel_name_;
std::mutex channel_mutex_;
@ -190,11 +195,14 @@ class PsCacheManager {
std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
size_t vocab_size_{0};
size_t cache_vocab_size_{0};
size_t host_cache_vocab_size_{0};
size_t vocab_cache_size_{0};
size_t host_vocab_cache_size_{0};
size_t batch_elements_{0};
PsCacheStatisticsInfo statistics_info_;
std::pair<size_t, size_t> range_bound_;
std::pair<int, int> emb_table_slice_bounds_;
std::pair<int, int> cache_indices_bounds_;
int vocab_cache_size_diff_{0};
int rank_id_{0};
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};

@ -129,5 +129,11 @@ void PSContext::set_cache_enable(bool cache_enable) const {
PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
#endif
}
void PSContext::set_rank_id(int rank_id) const {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps_cache_instance.set_rank_id(rank_id);
#endif
}
} // namespace ps
} // namespace mindspore

@ -52,6 +52,7 @@ class PSContext {
void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
void set_cache_enable(bool cache_enable) const;
void set_rank_id(int rank_id) const;
private:
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}

@ -391,7 +391,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
InnerSetContext();
if (graph->is_dynamic_shape()) {
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode.";
}
if (DumpJsonParser::GetInstance().async_dump_enabled()) {

@ -851,7 +851,7 @@ void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_addre
MS_LOG(WARNING) << "Unexpected device address status: " << status;
break;
default:
MS_LOG(EXCEPTION) << "Invaild device address status: " << status;
MS_LOG(EXCEPTION) << "Invalid device address status: " << status;
}
}
@ -1092,6 +1092,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
auto cnode = kernel->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Can not free the input addr of communication op when enable multi stream
if (AnfAlgo::IsCommunicationOp(kernel)) {
return;
}
@ -1106,7 +1107,9 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
}
auto kernel_with_index = GetPrevNodeOutput(kernel, i);
if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
// Maintain output addr of fused communication op to improve training performance
if (AnfAlgo::IsCommunicationOp(kernel_with_index.first) &&
AnfAlgo::GetInputTensorNum(kernel_with_index.first) > 1) {
continue;
}

@ -1049,7 +1049,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
MS_EXCEPTION_IF_NULL(graph);
for (const auto &kernel : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
@ -1061,13 +1062,15 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
continue;
}
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
auto input_index_node_name = AnfAlgo::GetCNodeName(input_index.first);
if (input_index.first->isa<CNode>() && (input_index_node_name != kGetNextOpName)) {
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
(full_batch && (input_index_node_name != kMinimumOpName))) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
<< ") cache is from " << input_index.first->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
@ -1082,6 +1085,28 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
}
}
void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true);
while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
MS_EXCEPTION_IF_NULL(pre_node.first);
}
if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode";
}
pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
MS_EXCEPTION_IF_NULL(pre_node.first);
}
if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
"value can not be changed before delivering to kernel[Unique] in parameter server cache mode.";
}
}
void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
AnfNodePtr first_cache_input_index = nullptr;
@ -1090,16 +1115,23 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
MS_EXCEPTION_IF_NULL(first_cache_input_index);
for (const auto &kernel : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
auto kernel_name = AnfAlgo::GetCNodeName(kernel);
if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
if (!input_param.first->isa<Parameter>()) {
continue;
}
auto param_name = input_param.first->fullname_with_scope();
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
CheckSparsePSEmbeddingCache(kernel);
}
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if (input_index.first == first_cache_input_index) {

@ -138,6 +138,7 @@ class KernelRuntime {
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
size_t *first_cache_size);
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
void CheckSparsePSEmbeddingCache(const CNodePtr &node);
#endif
protected:

@ -83,7 +83,7 @@ constexpr auto kScatterNdOpName = "ScatterNd";
constexpr auto kStridedSliceAssignOpName = "StridedSliceAssign";
constexpr auto kStridedSliceOpName = "StridedSlice";
constexpr auto kStridedSliceGradOpName = "StridedSliceGrad";
constexpr auto kSparseGatherV2 = "SparseGatherV2";
constexpr auto kSparseGatherV2OpName = "SparseGatherV2";
constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd";
constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
constexpr auto kFlattenGradOpName = "FlattenGrad";

@ -73,6 +73,13 @@ inline size_t FloatToSize(float u) {
}
inline float IntToFloat(int32_t v) { return static_cast<float>(v); }
inline int FloatToInt(float u) {
if (u > static_cast<float>((std::numeric_limits<int>::max)())) {
MS_LOG(EXCEPTION) << "The float value(" << u << ") exceeds the maximum value of int.";
}
return static_cast<int>(u);
}
inline float SizeToFloat(size_t v) { return static_cast<float>(v); }
inline double LongToDouble(int64_t v) { return static_cast<double>(v); }

@ -20,10 +20,12 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode, get_context
from mindspore.communication.management import get_group_size, get_rank
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
from mindspore import context
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
@ -227,8 +229,6 @@ class EmbeddingLookup(Cell):
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
if self.cache_enable and enable_ps:
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.gather_revert = P.GatherV2()
@ -238,6 +238,10 @@ class EmbeddingLookup(Cell):
self.shape = P.Shape()
if is_auto_parallel:
self.unique = P.Unique().shard(((1,),))
if self.cache_enable and enable_ps:
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
if is_auto_parallel:
self.unique.add_prim_attr('cache_enable', True)
indices_shape_size = 2
if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes:
@ -252,7 +256,7 @@ class EmbeddingLookup(Cell):
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == "table_row_slice" and is_auto_parallel:
full_batch = _get_full_batch()
if target == 'DEVICE' and not full_batch:
if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
indices_shape_size = 1
self.gather_revert.shard(((1, 1), (get_group_size(),)))
self.forward_unique = True
@ -293,7 +297,7 @@ class EmbeddingLookup(Cell):
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
if not self.sparse:
raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
if get_context("device_target") != 'Ascend':
if context.get_context("device_target") != 'Ascend':
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
logger.info("EmbeddingLookup cache enable takes effect.")
@ -320,21 +324,29 @@ class EmbeddingLookup(Cell):
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
device_num = get_group_size()
rank_size = get_group_size()
rank_id = get_rank()
full_batch = _get_full_batch()
if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"):
if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
"in 'full_batch' and 'table_row_slice' parallel strategy.")
self.vocab_cache_size = self.vocab_cache_size * device_num
self.vocab_cache_size = self.vocab_cache_size * rank_size
_set_rank_id(rank_id)
self.cache_enable = True
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
if context.get_context("enable_sparse") != self.sparse:
raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
"kernels and equal the value of 'enable_sparse' in context setting in "
"parameter server cache mode")
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""
self.embedding_table.cache_enable = True
self.embedding_table.is_param_ps = True
_set_cache_enable(True)
if self.sparse:
self.forward_unique = True
if _is_role_worker():
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)

@ -28,14 +28,15 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
"Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool",
"Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter, cache_enable):
"""Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
success = True
indices = gradient.indices
values = gradient.values
if ps_parameter:
if ps_parameter and not cache_enable:
op_shape = P.Shape()
shapes = (op_shape(params), op_shape(m), op_shape(v),
op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
@ -75,12 +76,12 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
"Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter, cache_enable):
"""Apply lazy adam optimizer to the weight parameter using Tensor."""
success = True
if ps_parameter:
if ps_parameter and not cache_enable:
op_shape = P.Shape()
success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
(op_shape(params), op_shape(moment1), op_shape(moment2))), params))
@ -245,12 +246,14 @@ class LazyAdam(Optimizer):
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
else:
success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
self.use_locking, self.use_nesterov, self._is_device,
self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
self.cache_enable)
return success
@Optimizer.target.setter

@ -142,3 +142,6 @@ def _set_cache_enable(cache_enable):
os.environ['GOTO_NUM_THREADS'] = '2'
os.environ['OMP_NUM_THREADS'] = '2'
ps_context().set_cache_enable(cache_enable)
def _set_rank_id(rank_id):
ps_context().set_rank_id(rank_id)

@ -190,7 +190,10 @@ def _get_python_op(op_name, op_path, instance_name, arglist):
"""Get python operator."""
module = __import__(op_path, fromlist=["None"])
cls = getattr(module, op_name)
op = cls(*arglist)
if op_path != "mindspore.ops.functional":
op = cls(*arglist)
else:
op = cls
op.set_prim_instance_name(instance_name)
return op

@ -17,7 +17,8 @@
#bash run_parameter_server_train_cluster.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET
# LOCAL_WORKER_NUM LOCAL_SERVER_NUM SERVER_NUM
# SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE VOCAB_CACHE_SIZE
# SCHED_HOST SCHED_PORT ROLE RANK_TABLE_FILE
# VOCAB_CACHE_SIZE SPARSE
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
@ -37,11 +38,16 @@ export MS_SCHED_PORT=$9
export MS_ROLE=${10}
export RANK_TABLE_FILE=${11}
export VOCAB_CACHE_SIZE=${12}
export SPARSE=${13}
if [[ ! -n "${12}" ]]; then
export VOCAB_CACHE_SIZE=0
fi
if [[ ! -n "${13}" ]]; then
export SPARSE=0
fi
echo "=====Role is $MS_ROLE======"
if [[ "$MS_ROLE" == "MS_SCHED" ]]; then
@ -73,7 +79,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then
mpirun --allow-run-as-root -n $LOCAL_WORKER_NUM --output-filename log_output --merge-stderr-to-stdout \
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 &
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 &
else
for((i=0;i<$LOCAL_WORKER_NUM;i++));
do
@ -84,7 +90,7 @@ if [[ "$MS_ROLE" == "MS_WORKER" ]]; then
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 &
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 &
done
fi
fi

@ -17,7 +17,7 @@
#bash run_parameter_server_train_distribute.sh RANK_SIZE EPOCHS DEVICE_TARGET DATASET
# SERVER_NUM SCHED_HOST SCHED_PORT RANK_TABLE_FILE
# VOCAB_CACHE_SIZE
# VOCAB_CACHE_SIZE SPARSE
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
@ -33,11 +33,16 @@ export MS_SCHED_HOST=$6
export MS_SCHED_PORT=$7
export RANK_TABLE_FILE=$8
export VOCAB_CACHE_SIZE=$9
export SPARSE=${10}
if [[ ! -n "$9" ]]; then
export VOCAB_CACHE_SIZE=0
fi
if [[ ! -n "${10}" ]]; then
export SPARSE=0
fi
export MS_ROLE=MS_SCHED
rm -rf ${execute_path}/sched/
mkdir ${execute_path}/sched/
@ -65,7 +70,7 @@ if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 &
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 &
else
for((i=0;i<$MS_WORKER_NUM;i++));
do
@ -76,7 +81,7 @@ else
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker_$i.log 2>&1 &
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker_$i.log 2>&1 &
done
fi

@ -16,7 +16,7 @@
#bash run_parameter_server_train_standalone.sh EPOCHS DEVICE_TARGET DATASET SERVER_NUM SCHED_HOST
# SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE
# SCHED_PORT DEVICE_ID VOCAB_CACHE_SIZE SPARSE
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
@ -31,11 +31,16 @@ export MS_SCHED_HOST=$5
export MS_SCHED_PORT=$6
DEVICE_ID=$7
export VOCAB_CACHE_SIZE=$8
export SPARSE=$9
if [[ ! -n "$8" ]]; then
export VOCAB_CACHE_SIZE=0
fi
if [[ ! -n "$9" ]]; then
export SPARSE=0
fi
# Set device id
if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
if [[ ! -n "$DEVICE_ID" ]]; then
@ -76,4 +81,4 @@ mkdir ${execute_path}/worker/
cd ${execute_path}/worker/ || exit
python -s ${self_path}/../train_and_eval_parameter_server_standalone.py --device_target=$DEVICE_TARGET \
--epochs=$EPOCH_SIZE --data_path=$DATASET --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --dropout_flag=1 >worker.log 2>&1 &
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=1 >worker.log 2>&1 &

Some files were not shown because too many files have changed in this diff.