|
|
|
@ -131,15 +131,18 @@ void *AscendPsCache::MallocMemory(size_t size) {
|
|
|
|
|
return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AscendPsCache::MallocConstantMemory(size_t constant_value) {
|
|
|
|
|
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
|
|
|
|
|
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
|
|
|
|
MS_ERROR_IF_NULL(offset_addr_);
|
|
|
|
|
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
|
|
|
|
|
cache_vocab_size_addr_ =
|
|
|
|
|
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
|
|
|
|
MS_ERROR_IF_NULL(cache_vocab_size_addr_);
|
|
|
|
|
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
|
|
|
|
|
return true;
|
|
|
|
|
int copy_value = SizeToInt(cache_vocab_size);
|
|
|
|
|
if (!CopyHostMemToDevice(cache_vocab_size_addr_, ©_value, sizeof(int))) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return SynchronizeStream();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AscendPsCache::RecordEvent() {
|
|
|
|
|