fix repeat rtmalloc device mem

pull/1468/head
zhou_lili 4 years ago
parent b964b15ee4
commit e8fcd806f6

@ -84,7 +84,7 @@ Status CalInputsHostMemSize(const std::vector<DataBuffer> &inputs,
inputs_size.emplace_back(index, input_size); inputs_size.emplace_back(index, input_size);
GE_CHK_STATUS_RET(CheckInt64AddOverflow(total_size, input_size), "Total size is beyond the INT64_MAX."); GE_CHK_STATUS_RET(CheckInt64AddOverflow(total_size, input_size), "Total size is beyond the INT64_MAX.");
total_size += input_size; total_size += input_size;
GELOGD("The %zu input mem type is host, tensor size is %ld.", index, input_size); GELOGD("The %zu input mem type is host, the tensor size is %ld.", index, input_size);
} }
index++; index++;
} }
@ -99,20 +99,16 @@ Status UpdateInputsBufferAddr(StreamResource *stream_resource, rtStream_t stream
const std::vector<std::pair<size_t, uint64_t>> &inputs_size, const std::vector<std::pair<size_t, uint64_t>> &inputs_size,
std::vector<DataBuffer> &update_buffers) { std::vector<DataBuffer> &update_buffers) {
GE_CHECK_NOTNULL(stream_resource); GE_CHECK_NOTNULL(stream_resource);
if (stream_resource->Init() != SUCCESS) {
GELOGE(FAILED, "[Malloc][Memory]Failed to malloc device buffer.");
return FAILED;
}
auto dst_addr = reinterpret_cast<uint8_t *>(stream_resource->GetDeviceBufferAddr()); auto dst_addr = reinterpret_cast<uint8_t *>(stream_resource->GetDeviceBufferAddr());
// copy host mem from input_buffer to device mem of dst_addr // copy host mem from input_buffer to device mem of dst_addr
for (const auto &input_size : inputs_size) { for (const auto &input_size : inputs_size) {
size_t index = input_size.first; auto index = input_size.first;
auto size = input_size.second; auto size = input_size.second;
GELOGD("Do H2D for %zu input, dst size is %zu, src length is %lu.", index, size, update_buffers[index].length); GELOGD("Do h2d for %zu input, dst size is %zu, src length is %lu.", index, size, update_buffers[index].length);
GE_CHK_RT_RET(rtMemcpyAsync(dst_addr, size, update_buffers[index].data, update_buffers[index].length, GE_CHK_RT_RET(rtMemcpyAsync(dst_addr, size, update_buffers[index].data, update_buffers[index].length,
RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); RT_MEMCPY_HOST_TO_DEVICE_EX, stream));
update_buffers[index].data = dst_addr; update_buffers[index].data = dst_addr;
dst_addr = reinterpret_cast<uint8_t *>(dst_addr + size); dst_addr = dst_addr + size;
} }
return SUCCESS; return SUCCESS;
} }

@ -81,8 +81,13 @@ StreamResource *SingleOpManager::GetResource(uintptr_t resource_id, rtStream_t s
auto it = stream_resources_.find(resource_id); auto it = stream_resources_.find(resource_id);
StreamResource *res = nullptr; StreamResource *res = nullptr;
if (it == stream_resources_.end()) { if (it == stream_resources_.end()) {
res = new (std::nothrow) StreamResource(resource_id); res = new(std::nothrow) StreamResource(resource_id);
if (res != nullptr) { if (res != nullptr) {
if (res->Init() != SUCCESS) {
GELOGE(FAILED, "[Malloc][Memory]Failed to malloc device buffer.");
delete res;
return nullptr;
}
res->SetStream(stream); res->SetStream(stream);
stream_resources_.emplace(resource_id, res); stream_resources_.emplace(resource_id, res);
} }

Loading…
Cancel
Save