add ps cache check

pull/9957/head
limingqi107 4 years ago
parent 3258c6e76d
commit a844d52b42

@ -1813,6 +1813,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
// PS embeddingLookup cache check.
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in "
"parameter server training mode.";
}
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>()) {
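This guard makes the embeddingLookup cache and ordinary PS-mode parameters mutually exclusive: once the cache is enabled, AssignParamKey refuses to assign PS keys to anything else. Below is a hedged Python sketch of the kind of configuration it is meant to reject; set_param_ps() and the layer names are assumptions used for illustration, not part of this commit.

```python
# Hedged sketch of a configuration the new AssignParamKey guard is meant to
# reject: the embeddingLookup cache is enabled while another parameter is also
# placed in parameter server mode. set_param_ps() and the layer names are
# assumptions for illustration only.
import mindspore.nn as nn
from mindspore import context

context.set_ps_context(enable_ps=True)  # parameter server training mode

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        # vocab_cache_size > 0 enables the embeddingLookup cache (see embedding.py below).
        self.embedding = nn.EmbeddingLookup(vocab_size=100000, embedding_size=16,
                                            target='DEVICE', sparse=True,
                                            vocab_cache_size=5000)
        self.dense = nn.Dense(16, 1)

    def construct(self, indices):
        return self.dense(self.embedding(indices))

net = Net()
# Also putting an ordinary parameter into PS mode should now hit the exception above.
net.dense.weight.set_param_ps()
```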

@ -976,15 +976,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
backend->Link(runner.graph_id);
}
// PS mode does not support loop sink.
ConfigManager::GetInstance().set_iter_num(size);
// PS cache does not support loop sink.
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsRoleOfWorker()) {
if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size));
ConfigManager::GetInstance().set_iter_num(1);
} else {
#endif
ConfigManager::GetInstance().set_iter_num(size);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
}
#endif
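The restructured block defaults the iteration number to the dataset size and only switches to per-step sinking when the process is a PS worker with the cache enabled. A minimal Python sketch of that decision, with illustrative names rather than MindSpore APIs:

```python
# Minimal sketch of the revised iteration-number decision; the function and
# flags are illustrative, not MindSpore APIs.
def choose_iter_num(is_worker: bool, cache_enable: bool, size: int) -> int:
    """The PS cache cannot use loop sink, so only a worker with the cache
    enabled sinks one step at a time; every other case keeps the dataset size."""
    if is_worker and cache_enable:
        return 1  # data is fed through the PS data channel, one step per sink
    return size

assert choose_iter_num(True, True, 1000) == 1      # worker with the cache enabled
assert choose_iter_num(True, False, 1000) == 1000  # worker without the cache
assert choose_iter_num(False, True, 1000) == 1000  # server / scheduler role
```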

@ -129,11 +129,19 @@ void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const st
const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name;
MS_LOG(EXCEPTION) << "Can not find device address of " << param_name;
}
return iter->second.device_address;
}
const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) const {
auto iter = hash_tables_.find(param_name);
if (iter == hash_tables_.end()) {
MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name;
}
return iter->second.cache_vocab_size;
}
void PsCacheManager::Initialize() {
MS_LOG(INFO) << "PS cache initialize.";
if (!worker.running()) {
@ -244,19 +252,19 @@ void PsCacheManager::set_channel_name(const std::string channel_name) {
void PsCacheManager::IncreaseStep() {
if (data_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t.";
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
}
data_step_++;
set_current_graph_step();
if (graph_running_step_ > data_step_) {
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step ("
<< data_step_ << ").";
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_
<< ").";
}
}
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
if (graph_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t.";
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
}
if (graph_step_ == 0) {
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_;
@ -271,8 +279,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
}
void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
// PS embeddingLookup cache check.
if (!initialized_ps_cache_) {
MS_LOG(EXCEPTION) << "PS cache does not init.";
MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training "
"mode, current dataset mode is not sink_mode.";
}
auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
process_data_thread.detach();

@ -120,6 +120,7 @@ class PsCacheManager {
size_t cache_vocab_size, size_t embedding_size);
void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name);
const Address &QueryHashTableAddr(const std::string &param_name) const;
const size_t &QueryHashTableSize(const std::string &param_name) const;
bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; }
void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
bool initialized_ps_cache() const { return initialized_ps_cache_; }
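The new QueryHashTableSize accessor mirrors QueryHashTableAddr: both read the hash_tables_ registry that workers fill in from the Python side via _insert_hash_table_size. A small illustrative sketch of that lookup-or-throw pattern; the dataclass and dict are stand-ins, not MindSpore code.

```python
# Illustrative stand-in for the hash_tables_ registry behind QueryHashTableAddr
# and the new QueryHashTableSize; entries are registered from the Python side
# via _insert_hash_table_size. The dataclass and dict are not MindSpore code.
from dataclasses import dataclass

@dataclass
class HashTableInfo:
    device_address: int    # device-side address of the cached embedding slab
    cache_vocab_size: int  # number of rows kept in the device cache

hash_tables = {"embedding_table": HashTableInfo(0x1000, 300000)}

def query_hash_table_size(param_name: str) -> int:
    info = hash_tables.get(param_name)
    if info is None:
        raise KeyError("Can not find vocab cache size of " + param_name)
    return info.cache_vocab_size

assert query_hash_table_size("embedding_table") == 300000
```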

@ -325,7 +325,9 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
need_alloc_nodes.push_back(item);
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
bool ps_cache_check = false;
#endif
for (auto &item : need_alloc_nodes) {
auto output_size = AnfAlgo::GetOutputTensorNum(item);
for (size_t index = 0; index < output_size; index++) {
@ -339,6 +341,13 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
const std::string &param_name = item->fullname_with_scope();
if (ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(INFO) << "Parameter(" << param_name << ")"
<< " enables the embeddingLookup cache in parameter server training mode.";
// PS embeddingLookup cache check.
if (!ps_cache_check) {
CheckIfSupportPSEmbeddingCache(graph);
ps_cache_check = true;
}
const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
MS_EXCEPTION_IF_NULL(address.addr);
device_address =
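For a parameter that is a registered hash table, the runtime now runs the graph-level cache check once per graph (guarded by ps_cache_check) and reuses the device address the PS cache manager already owns instead of allocating new static memory. A runnable Python sketch of that run-once-and-reuse pattern, with stand-in names:

```python
# Runnable sketch of the run-once check plus address reuse added to
# AssignStaticMemoryInput; all names are stand-ins, not MindSpore APIs.
CACHED_TABLES = {"embedding_table": 0x1000}  # param name -> address owned by the PS cache

def check_if_support_ps_embedding_cache():
    """Stand-in for the graph-level check; the real one validates every GatherV2."""
    print("embedding cache graph check (runs once per graph)")

def allocate_static_memory(name):
    """Stand-in allocator returning a fake address."""
    return hash(name) & 0xFFFF

def assign_static_memory_input(param_names):
    ps_cache_check = False  # mirrors the new run-once flag
    addresses = {}
    for name in param_names:
        if name in CACHED_TABLES:
            if not ps_cache_check:
                check_if_support_ps_embedding_cache()
                ps_cache_check = True
            addresses[name] = CACHED_TABLES[name]  # reuse the cache-owned address
        else:
            addresses[name] = allocate_static_memory(name)
    return addresses

print(assign_static_memory_input(["embedding_table", "dense.weight"]))
```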
@ -1024,5 +1033,83 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
MS_EXCEPTION_IF_NULL(base_ptr);
return device_address;
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
size_t *first_cache_size) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &kernel : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope();
if (!ps::ps_cache_instance.IsHashTable(param_name)) {
continue;
}
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
<< input_index.first->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
}
*first_cache_input_index = input_index.first;
*first_cache_size = size;
MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from "
<< input_index.first->fullname_with_scope() << ", the cache size is " << size;
return;
}
}
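GetFirstPSEmbeddingCache scans the execution order for the first GatherV2 whose table is a cached hash table, then traces its index input back through Cast and nop nodes; the traced source must be a Parameter or GetNext, i.e. the dataset. A toy sketch of the tracing step follows; the Node type is a stand-in for AnfNode, not MindSpore code.

```python
# Toy sketch of the index-tracing step: skip Cast (and nop) nodes until the
# real producer of the GatherV2 index is found. Node is a stand-in for AnfNode.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    name: str
    producer: Optional["Node"] = None

def trace_index_source(index: "Node") -> "Node":
    while index.name == "Cast" and index.producer is not None:
        index = index.producer
    return index

source = trace_index_source(Node("Cast", producer=Node("GetNext")))
assert source.name == "GetNext"  # index comes from the dataset, so caching is allowed
```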
void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
AnfNodePtr first_cache_input_index = nullptr;
size_t first_cache_size = 0;
GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size);
MS_EXCEPTION_IF_NULL(first_cache_input_index);
for (const auto &kernel : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope();
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if (input_index.first == first_cache_input_index) {
if (!ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
"same time when one of them enables cache in parameter server training mode.";
}
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
if (size != first_cache_size) {
MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
<< ") is not the same as other embeddingLookup cache size.";
MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
}
} else if (ps::ps_cache_instance.IsHashTable(param_name)) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
<< input_index.first->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
}
}
}
#endif
} // namespace device
} // namespace mindspore
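CheckIfSupportPSEmbeddingCache then enforces a single invariant: every GatherV2 fed by the same dataset index as the first cached lookup must enable the cache with an identical cache size, and a cached lookup must not take its index from anywhere else. A simplified, runnable sketch of that rule, where a boolean stands in for the comparison against the first cache's index node:

```python
# Simplified, runnable sketch of the invariant CheckIfSupportPSEmbeddingCache
# enforces; a boolean stands in for "shares the first cache's index node".
def check_embedding_cache_consistency(lookups):
    """lookups: list of (name, index_from_dataset, cache_size or None)."""
    first_size = None
    for name, from_dataset, cache_size in lookups:
        if from_dataset:
            if cache_size is None:
                raise ValueError(name + ": every lookup fed by the dataset index must enable the cache")
            if first_size is None:
                first_size = cache_size
            elif cache_size != first_size:
                raise ValueError(name + ": cache sizes must be identical across lookups")
        elif cache_size is not None:
            raise ValueError(name + ": a cached lookup must take its index from the dataset")

# Two cached lookups sharing the dataset index with equal cache sizes pass.
check_embedding_cache_consistency([("wide_embedding", True, 300000),
                                   ("deep_embedding", True, 300000)])
```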

@ -131,6 +131,11 @@ class KernelRuntime {
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index,
size_t *first_cache_size);
void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph);
#endif
protected:
uint32_t device_id_{0};

@ -14,7 +14,6 @@
# ============================================================================
"""embedding"""
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
@ -23,8 +22,8 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
@ -195,11 +194,6 @@ class EmbeddingLookup(Cell):
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
if not sparse and target == 'CPU':
raise ValueError('When target is CPU, embedding_lookup must be sparse.')
enable_ps = context.get_ps_context("enable_ps")
if not enable_ps and vocab_cache_size > 0:
logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, "
"current mode is not parameter server trainning mode, so it will be ignored.")
vocab_cache_size = 0
if sparse:
self.gatherv2 = P.SparseGatherV2()
else:
@ -207,22 +201,14 @@ class EmbeddingLookup(Cell):
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self._process_vocab_cache(slice_mode)
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.cache_enable = self.vocab_cache_size > 0
if self.cache_enable:
if is_auto_parallel:
self.vocab_cache_size = self.vocab_cache_size * get_group_size()
self.vocab_size = self.vocab_cache_size
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
if self.cache_enable:
self.embedding_table.cache_enable = True
_set_cache_enable(True)
if _is_role_worker():
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
self._set_vocab_cache_enable(vocab_cache_size, embedding_size, vocab_size)
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.forward_unique = False
self.gather_revert = P.GatherV2()
self.unique = P.Unique().shard(((1,),))
@ -241,7 +227,8 @@ class EmbeddingLookup(Cell):
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == "table_row_slice" and is_auto_parallel:
if target == 'DEVICE' and not self.cache_enable:
full_batch = _get_full_batch()
if target == 'DEVICE' and not full_batch:
indices_shape_size = 1
self.gather_revert.shard(((1, 1), (get_group_size(),)))
self.forward_unique = True
@ -272,6 +259,39 @@ class EmbeddingLookup(Cell):
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
def _process_vocab_cache(self, slice_mode):
"""PS embeddingLookup cache check and process."""
self.cache_enable = False
if self.vocab_cache_size > 0:
if self.target == 'CPU':
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
"current target is CPU, so it will be ignored.")
return
enable_ps = _get_ps_context("enable_ps")
if not enable_ps:
logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
"mode, current mode is not parameter server trainning mode, so it will be ignored.")
return
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
device_num = get_group_size()
full_batch = _get_full_batch()
if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE):
raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
"in 'full_batch' and 'table_row_slice' parallel strategy.")
self.vocab_cache_size = self.vocab_cache_size * device_num
self.cache_enable = True
self.vocab_size = self.vocab_cache_size
def _set_vocab_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
"""Enable the PS embeddingLookup cache."""
self.embedding_table.cache_enable = True
self.embedding_table.is_param_ps = True
_set_cache_enable(True)
if _is_role_worker():
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
def construct(self, indices):
if self.target == "CPU":
out = self.embeddinglookup(self.embedding_table, indices, 0)
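Taken together, the Python-side changes mean the cache is configured entirely through EmbeddingLookup(vocab_cache_size=...) under PS training with the 'DEVICE' target, and, for (semi-)auto parallel with more than one device, full_batch plus 'table_row_slice'. A hedged end-to-end configuration sketch follows; the context calls are assumed standard MindSpore usage, and the model/dataset objects are placeholders.

```python
# Hedged end-to-end configuration sketch after this change; the context calls
# are assumed standard MindSpore usage, and the model/dataset objects are
# placeholders that are not defined here.
import mindspore.nn as nn
from mindspore import context
from mindspore.context import ParallelMode

context.set_ps_context(enable_ps=True)  # cache is only honoured in PS training mode
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                  full_batch=True)  # required once device_num > 1

embedding = nn.EmbeddingLookup(vocab_size=2000000,
                               embedding_size=80,
                               target='DEVICE',               # ignored on the CPU target
                               sparse=True,
                               slice_mode='table_row_slice',  # required slice mode
                               vocab_cache_size=300000)       # > 0 enables the cache

# Build the model around `embedding`, then train with dataset sink mode, which
# DoProcessData requires when the cache is enabled:
# model.train(epoch, train_dataset, dataset_sink_mode=True)
```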
