!5171 fix global_step error in multi-cases

Merge pull request !5171 from liangzelang/fix_global_step_error
pull/5171/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit e83899add6

@ -87,6 +87,7 @@ class DynamicMemPoolBestFit {
void ReleaseDeviceRes();
// Display the information of memory block and memory buf.
void DumpDynamicMemPoolInfo();
SizeMapMemBuf GetIdleMemBufMap() { return global_idle_mem_buf_map_; }
// Get the related memory statistics information.
size_t total_mem_statistics() const { return total_mem_statistics_; }

@ -29,6 +29,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/session_factory.h"
#include "backend/session/ascend_control_parser.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
namespace mindspore {
namespace session {
@ -37,7 +38,7 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2,
class AscendSession : public SessionBasic {
public:
AscendSession() { final_graph_id_ = kInvalidGraphId; }
~AscendSession() override = default;
~AscendSession() override { mindspore::device::ascend::AscendMemoryPool::GetInstance().ResetIdleMemBuf(); }
void Init(uint32_t device_id) override {
SessionBasic::Init(device_id);
context_ = std::make_shared<Context>(kAscendDevice, device_id);

@ -43,6 +43,13 @@ bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
return true;
}
void AscendMemoryPool::ResetIdleMemBuf() {
auto idle_mem_buf_map = DynamicMemPoolBestFit::GetIdleMemBufMap();
for (auto &it : idle_mem_buf_map) {
rtMemset(it.second->device_addr_, it.first, 0, it.first);
}
}
size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
if (size == 0) {
MS_LOG(EXCEPTION) << "The align memory size is a zero !";

@ -31,6 +31,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
void ResetIdleMemBuf();
void set_device_mem_size(uint64_t device_mem_size);
void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
void set_device_mem_pool_offset(uint64_t device_mem_pool_offset);

Loading…
Cancel
Save