!10649 fix ascend ps cache loss invalid

From: @limingqi107
Reviewed-by: @cristoval,@chujinjin
Signed-off-by: @cristoval
pull/10649/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 6fa83590c1

@ -131,15 +131,18 @@ void *AscendPsCache::MallocMemory(size_t size) {
return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
}
bool AscendPsCache::MallocConstantMemory(size_t constant_value) {
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
MS_ERROR_IF_NULL(offset_addr_);
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
cache_vocab_size_addr_ =
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
MS_ERROR_IF_NULL(cache_vocab_size_addr_);
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
return true;
int copy_value = SizeToInt(cache_vocab_size);
if (!CopyHostMemToDevice(cache_vocab_size_addr_, &copy_value, sizeof(int))) {
return false;
}
return SynchronizeStream();
}
bool AscendPsCache::RecordEvent() {

@ -51,7 +51,7 @@ class AscendPsCache : public PsCacheBasic {
~AscendPsCache() override = default;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
bool MallocConstantMemory(size_t constant_value) override;
bool MallocConstantMemory(size_t cache_vocab_size) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;

@ -34,7 +34,7 @@ class PsCacheBasic {
virtual ~PsCacheBasic() = default;
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
virtual void *MallocMemory(size_t size) = 0;
virtual bool MallocConstantMemory(size_t constant_value) { return true; }
virtual bool MallocConstantMemory(size_t cache_vocab_size) { return true; }
virtual bool RecordEvent() = 0;
virtual bool SynchronizeEvent() = 0;
virtual bool SynchronizeStream() = 0;

@ -674,6 +674,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_indices_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
return true;
}

@ -171,7 +171,10 @@ class EmbeddingLookup(Cell):
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
vocab_cache_size (int): Cache size of the dictionary of embeddings.
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
parameter server training mode with the 'DEVICE' target. The moment parameter of the corresponding
optimizer will also be set to the cache size. Note that the cache consumes 'DEVICE'
memory, so it is recommended to set a reasonable value to avoid running out of memory.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.

Loading…
Cancel
Save