add ps cache check

pull/9957/head
limingqi107 4 years ago
parent 3258c6e76d
commit a844d52b42

@@ -1813,6 +1813,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co
 void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  // PS embeddingLookup cache check.
+  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+    MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in "
+                         "parameter server training mode.";
+  }
   std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
   for (auto &node : node_list) {
     if (node != nullptr && node->isa<CNode>()) {

@@ -976,15 +976,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
     backend->Link(runner.graph_id);
   }
-  // PS mode does not support loop sink.
+  ConfigManager::GetInstance().set_iter_num(size);
+  // PS cache does not support loop sink.
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  if (ps::Util::IsRoleOfWorker()) {
+  if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
     ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
     ConfigManager::GetInstance().set_iter_num(1);
-  } else {
-#endif
-    ConfigManager::GetInstance().set_iter_num(size);
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   }
 #endif

@@ -129,11 +129,19 @@ void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const st
 const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const {
   auto iter = hash_tables_.find(param_name);
   if (iter == hash_tables_.end()) {
-    MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name;
+    MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
   }
   return iter->second.device_address;
 }
 
+const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) const {
+  auto iter = hash_tables_.find(param_name);
+  if (iter == hash_tables_.end()) {
+    MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
+  }
+  return iter->second.cache_vocab_size;
+}
+
 void PsCacheManager::Initialize() {
   MS_LOG(INFO) << "PS cache initialize.";
   if (!worker.running()) {
@@ -244,19 +252,19 @@ void PsCacheManager::set_channel_name(const std::string channel_name) {
 void PsCacheManager::IncreaseStep() {
   if (data_step_ >= UINT64_MAX) {
-    MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
+    MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
   }
   data_step_++;
   set_current_graph_step();
   if (graph_running_step_ > data_step_) {
-    MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step ("
-                      << data_step_ << ").";
+    MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceeds the data step (" << data_step_
+                      << ").";
   }
 }
 
 void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
   if (graph_step_ >= UINT64_MAX) {
-    MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
+    MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
   }
   if (graph_step_ == 0) {
     MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
@@ -271,8 +279,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
 }
 
 void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
+  // PS embeddingLookup cache check.
   if (!initialized_ps_cache_) {
-    MS_LOG(EXCEPTION) << "PS cache does not init.";
+    MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training "
+                         "mode, current dataset mode is not sink_mode.";
   }
   auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
   process_data_thread.detach();
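
For context (not part of the diff), a minimal sketch of the constraint behind the new message: the PS cache is only initialized on the dataset sink path, so a cached embedding has to be trained with data sinking on. The names network, loss, opt and train_dataset below are placeholders for a user-defined model and dataset.

# Hedged sketch, not from this commit; network/loss/opt/train_dataset are assumed to exist.
from mindspore.train import Model

model = Model(network, loss_fn=loss, optimizer=opt)
# With vocab_cache_size > 0, non-sink training hits the exception above,
# so the dataset must be consumed with dataset_sink_mode=True.
model.train(1, train_dataset, dataset_sink_mode=True)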

@@ -120,6 +120,7 @@ class PsCacheManager {
                            size_t cache_vocab_size, size_t embedding_size);
   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name);
   const Address &QueryHashTableAddr(const std::string &param_name) const;
+  const size_t &QueryHashTableSize(const std::string &param_name) const;
   bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
   void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
   bool initialized_ps_cache() const { return initialized_ps_cache_; }

@@ -325,7 +325,9 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
     }
     need_alloc_nodes.push_back(item);
   }
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  bool ps_cache_check = false;
+#endif
   for (auto &item : need_alloc_nodes) {
     auto output_size = AnfAlgo::GetOutputTensorNum(item);
     for (size_t index = 0; index < output_size; index++) {
@@ -339,6 +341,13 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
       const std::string &param_name = item->fullname_with_scope();
       if (ps::ps_cache_instance.IsHashTable(param_name)) {
+        MS_LOG(INFO) << "Parameter(" << param_name << ")"
+                     << " enables the embeddingLookup cache in parameter server training mode.";
+        // PS embeddingLookup cache check.
+        if (!ps_cache_check) {
+          CheckIfSupportPSEmbeddingCache(graph);
+          ps_cache_check = true;
+        }
         const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
         MS_EXCEPTION_IF_NULL(address.addr);
         device_address =
@@ -1024,5 +1033,83 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
   MS_EXCEPTION_IF_NULL(base_ptr);
   return device_address;
 }
+
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
+                                             size_t *first_cache_size) {
+  MS_EXCEPTION_IF_NULL(graph);
+  for (const auto &kernel : graph->execution_order()) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
+      continue;
+    }
+    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
+    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
+    MS_EXCEPTION_IF_NULL(input_param.first);
+    MS_EXCEPTION_IF_NULL(input_index.first);
+    auto param_name = input_param.first->fullname_with_scope();
+    if (!ps::ps_cache_instance.IsHashTable(param_name)) {
+      continue;
+    }
+    auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
+    while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
+      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
+      MS_EXCEPTION_IF_NULL(input_index.first);
+    }
+    if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
+      MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
+                    << input_index.first->fullname_with_scope();
+      MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
+                           "parameter server training mode.";
+    }
+    *first_cache_input_index = input_index.first;
+    *first_cache_size = size;
+    MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from "
+                 << input_index.first->fullname_with_scope() << ", the cache size is " << size;
+    return;
+  }
+}
+
+void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  AnfNodePtr first_cache_input_index = nullptr;
+  size_t first_cache_size = 0;
+  GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size);
+  MS_EXCEPTION_IF_NULL(first_cache_input_index);
+  for (const auto &kernel : graph->execution_order()) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
+      continue;
+    }
+    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
+    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
+    MS_EXCEPTION_IF_NULL(input_param.first);
+    MS_EXCEPTION_IF_NULL(input_index.first);
+    auto param_name = input_param.first->fullname_with_scope();
+    while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
+      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
+      MS_EXCEPTION_IF_NULL(input_index.first);
+    }
+    if (input_index.first == first_cache_input_index) {
+      if (!ps::ps_cache_instance.IsHashTable(param_name)) {
+        MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
+        MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
+                             "same time when one of them enables cache in parameter server training mode.";
+      }
+      auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
+      if (size != first_cache_size) {
+        MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
+                      << ") is not the same as other embeddingLookup cache size.";
+        MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
+      }
+    } else if (ps::ps_cache_instance.IsHashTable(param_name)) {
+      MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
+                    << input_index.first->fullname_with_scope();
+      MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
+                           "parameter server training mode.";
+    }
+  }
+}
+#endif
 } // namespace device
 } // namespace mindspore

@@ -131,6 +131,11 @@ class KernelRuntime {
   void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
   DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
+                                size_t *first_cache_size);
+  void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
+#endif
 
  protected:
   uint32_t device_id_{0};

@@ -14,7 +14,6 @@
 # ============================================================================
 """embedding"""
 import mindspore.common.dtype as mstype
-import mindspore.context as context
 from mindspore import log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
@@ -23,8 +22,8 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_parallel_mode
-from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker
+from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
+from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
 from mindspore._checkparam import Rel
 from mindspore._checkparam import Validator as validator
 from mindspore.ops.primitive import constexpr
@@ -195,11 +194,6 @@ class EmbeddingLookup(Cell):
                              + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
         if not sparse and target == 'CPU':
             raise ValueError('When target is CPU, embedding_lookup must be sparse.')
-        enable_ps = context.get_ps_context("enable_ps")
-        if not enable_ps and vocab_cache_size > 0:
-            logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, "
-                           "current mode is not parameter server trainning mode, so it will be ignored.")
-            vocab_cache_size = 0
         if sparse:
             self.gatherv2 = P.SparseGatherV2()
         else:
@@ -207,22 +201,14 @@ class EmbeddingLookup(Cell):
             self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
         self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
         self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
+        self._process_vocab_cache(slice_mode)
         self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
-        parallel_mode = _get_parallel_mode()
-        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-        self.cache_enable = self.vocab_cache_size > 0
-        if self.cache_enable:
-            if is_auto_parallel:
-                self.vocab_cache_size = self.vocab_cache_size * get_group_size()
-            self.vocab_size = self.vocab_cache_size
         self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
                                          name='embedding_table')
         if self.cache_enable:
-            self.embedding_table.cache_enable = True
-            _set_cache_enable(True)
-            if _is_role_worker():
-                _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
+            self._set_vocab_cache_enable(vocab_cache_size, embedding_size, vocab_size)
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.forward_unique = False
         self.gather_revert = P.GatherV2()
         self.unique = P.Unique().shard(((1,),))
@@ -241,7 +227,8 @@ class EmbeddingLookup(Cell):
             self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
             self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
         elif slice_mode == "table_row_slice" and is_auto_parallel:
-            if target == 'DEVICE' and not self.cache_enable:
+            full_batch = _get_full_batch()
+            if target == 'DEVICE' and not full_batch:
                 indices_shape_size = 1
                 self.gather_revert.shard(((1, 1), (get_group_size(),)))
                 self.forward_unique = True
@@ -272,6 +259,39 @@ class EmbeddingLookup(Cell):
             self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
             self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
+    def _process_vocab_cache(self, slice_mode):
+        """PS embeddingLookup cache check and process."""
+        self.cache_enable = False
+        if self.vocab_cache_size > 0:
+            if self.target == 'CPU':
+                logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
+                               "current target is CPU, so it will be ignored.")
+                return
+            enable_ps = _get_ps_context("enable_ps")
+            if not enable_ps:
+                logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server training "
+                               "mode, current mode is not parameter server training mode, so it will be ignored.")
+                return
+            parallel_mode = _get_parallel_mode()
+            is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+            if is_auto_parallel:
+                device_num = get_group_size()
+                full_batch = _get_full_batch()
+                if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE):
+                    raise ValueError("The embeddingLookup cache of parameter server parallel can only be used "
+                                     "in 'full_batch' and 'table_row_slice' parallel strategy.")
+                self.vocab_cache_size = self.vocab_cache_size * device_num
+            self.cache_enable = True
+            self.vocab_size = self.vocab_cache_size
+
+    def _set_vocab_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
+        """PS embeddingLookup cache enable set."""
+        self.embedding_table.cache_enable = True
+        self.embedding_table.is_param_ps = True
+        _set_cache_enable(True)
+        if _is_role_worker():
+            _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
+
     def construct(self, indices):
         if self.target == "CPU":
             out = self.embeddinglookup(self.embedding_table, indices, 0)
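
For context (not part of the diff), a minimal sketch of how the cache path above is turned on from user code, assuming a parameter-server environment (MS_ROLE and related settings) is already configured; the sizes are placeholders:

# Hedged usage sketch, not from this commit. Without enable_ps=True,
# _process_vocab_cache() logs a warning and ignores vocab_cache_size.
import mindspore.context as context
import mindspore.nn as nn

context.set_ps_context(enable_ps=True)
# A positive vocab_cache_size with target='DEVICE' enables the embeddingLookup
# cache that the checks in this commit validate.
embedding = nn.EmbeddingLookup(vocab_size=100000,
                               embedding_size=16,
                               target='DEVICE',
                               sparse=True,
                               vocab_cache_size=3000)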
