From fdbe4c19ba0ac7b244b80e360b0f8e355de775e0 Mon Sep 17 00:00:00 2001
From: lvchangquan <lvchangquan@huawei.com>
Date: Fri, 24 Jul 2020 16:52:41 +0800
Subject: [PATCH]   use kernel_runtime::mem_manager to reduce rtMalloc and
 rtFree time in trans data format

---
 .../device/ascend/ascend_device_address.cc    | 48 ++++++++-----------
 .../ccsrc/runtime/device/kernel_runtime.cc    | 16 +++++--
 .../ccsrc/runtime/device/kernel_runtime.h     |  6 ++-
 3 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
index 026195cc29..9d4466c46d 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
@@ -153,6 +153,16 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
   return true;
 }
 
+DeviceAddressPtr AssignLaunchMemory(size_t size, const std::string &format, TypeId type) {
+  auto ms_context = MsContext::GetInstance();  // the global context supplies the target device id
+  MS_EXCEPTION_IF_NULL(ms_context);
+  auto device_id = ms_context->device_id();
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
+  MS_EXCEPTION_IF_NULL(runtime_instance);  // guard against a failed Ascend runtime lookup
+  auto address_ptr = runtime_instance->AssignSingleOpLaunchMemory(size, format, type);  // runtime does the work
+  return address_ptr;  // callers in this file null-check the returned pointer
+}
+
 size_t GetCommonAlignSize(size_t input_size) {
   return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
 }
@@ -325,18 +335,15 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
   AddressPtrList kernel_inputs = {input_address};
   AddressPtrList kernel_outputs = {output_address};
   AddressPtrList kernel_workspaces;
-  std::vector<void *> workspaces_address_ptr(workspace_size_list.size(), nullptr);
   if (!workspace_size_list.empty()) {
     for (size_t i = 0; i < workspace_size_list.size(); ++i) {
       auto workspace_size = GetCommonAlignSize(workspace_size_list[i]);
-      auto ret_malloc = rtMalloc(&workspaces_address_ptr[i], workspace_size, RT_MEMORY_HBM);
-      if (ret_malloc != RT_ERROR_NONE) {
-        MS_LOG(ERROR) << "Failed to rtMalloc memory";
-      }
+      auto workspace_address_ptr = AssignLaunchMemory(workspace_size, "", kTypeUnknown);
+      MS_EXCEPTION_IF_NULL(workspace_address_ptr);
       auto workspace_address = std::make_shared<kernel::Address>();
       MS_EXCEPTION_IF_NULL(workspace_address);
-      workspace_address->addr = workspaces_address_ptr[i];
-      workspace_address->size = workspace_size;
+      workspace_address->addr = workspace_address_ptr->GetMutablePtr();
+      workspace_address->size = workspace_address_ptr->GetSize();
       kernel_workspaces.push_back(workspace_address);
     }
   }
@@ -350,15 +357,6 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
   if (!ret) {
     MS_LOG(ERROR) << "Launch kernel failed.";
   }
-  SyncStream();
-  if (!workspace_size_list.empty()) {
-    for (size_t i = 0; i < workspace_size_list.size(); ++i) {
-      auto ret_free = rtFree(workspaces_address_ptr[i]);
-      if (ret_free != RT_ERROR_NONE) {
-        MS_LOG(ERROR) << "Failed to rtFree memory";
-      }
-    }
-  }
 }
 
 kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(const nlohmann::json &kernel_json) const {
@@ -418,19 +416,17 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
     size = device_dtype_size * shape_size;
   }
   size = GetCommonAlignSize(size);
-  void *output_address_ptr = nullptr;
-  auto ret_malloc = rtMalloc(&output_address_ptr, size, RT_MEMORY_HBM);
-  if (ret_malloc != RT_ERROR_NONE) {
-    MS_LOG(ERROR) << "Failed to rtMalloc memory";
-  }
+  auto output_address = AssignLaunchMemory(size, kOpFormat_NCHW, type_id_);
+  MS_EXCEPTION_IF_NULL(output_address);
   auto workspace_size_list = GetWorkspaceSizeList(kernel_json);
   // launch
-  LaunchTransData(kernel_mod_ptr, output_address_ptr, size, workspace_size_list);
+  LaunchTransData(kernel_mod_ptr, output_address->GetMutablePtr(), output_address->GetSize(), workspace_size_list);
+  SyncStream();
   if (type_id_ == type) {
-    SyncMemory(host_ptr, output_address_ptr, host_size, RT_MEMCPY_DEVICE_TO_HOST);
+    SyncMemory(host_ptr, output_address->GetPtr(), host_size, RT_MEMCPY_DEVICE_TO_HOST);
   } else {
     auto host = std::vector<uint8_t>(size);
-    SyncMemory(host.data(), output_address_ptr, size, RT_MEMCPY_DEVICE_TO_HOST);
+    SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
     auto shape_size = trans::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
@@ -439,10 +435,6 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
       return false;
     }
   }
-  auto ret_free = rtFree(output_address_ptr);
-  if (ret_free != RT_ERROR_NONE) {
-    MS_LOG(ERROR) << "Failed to rtFree memory";
-  }
   return sync_ok;
 }
 
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index 0ff65c4784..b86cdbed4f 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -842,9 +842,10 @@ void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
   MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
 }
 
-bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
-                                                  AddressPtrList kernel_outputs,
-                                                  AddressPtrList kernel_workspaces) const {
+bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr,
+                                                  const AddressPtrList &kernel_inputs,
+                                                  const AddressPtrList &kernel_outputs,
+                                                  const AddressPtrList &kernel_workspaces) const {
   MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
   auto ret = kernel_mod_ptr->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
   if (!ret) {
@@ -854,6 +855,15 @@ bool KernelRuntime::LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mo
   return true;
 }
 
+DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type) {
+  auto device_address = CreateDeviceAddress(nullptr, size, format, type);  // created without a device pointer
+  MS_EXCEPTION_IF_NULL(device_address);
+  MS_EXCEPTION_IF_NULL(mem_manager_);
+  auto base_ptr = mem_manager_->MallocMem(kDynamicMem, size);  // NOTE(review): result is not null-checked here
+  device_address->set_ptr(base_ptr);  // attach the memory obtained from mem_manager_ to the address object
+  return device_address;  // NOTE(review): nothing frees base_ptr here; confirm who reclaims kDynamicMem
+}
+
 #ifdef ENABLE_DUMP_E2E
 bool KernelRuntime::SetDumpConf() {
   dump_conf_ptr_ = std::make_shared<Dump>();
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h
index d368b3eedf..b613ee277a 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.h
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h
@@ -65,8 +65,9 @@ class KernelRuntime {
   virtual bool RunTask(const session::KernelGraph *graph);
   virtual bool GenTask(const session::KernelGraph *graph);
   bool LaunchKernel(const session::KernelGraph *graph);
-  bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, AddressPtrList kernel_inputs,
-                                     AddressPtrList kernel_outputs, AddressPtrList kernel_workspaces) const;
+  bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
+                                     const AddressPtrList &kernel_outputs,
+                                     const AddressPtrList &kernel_workspaces) const;
   virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
   virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
   virtual void ClearGraphRuntimeResource(uint32_t graph_id);
@@ -79,6 +80,7 @@ class KernelRuntime {
   // for GPU and D to impl
   virtual void ReleaseDeviceRes() {}
   void set_device_id(uint32_t device_id) { device_id_ = device_id; }
+  DeviceAddressPtr AssignSingleOpLaunchMemory(size_t size, const std::string &format, TypeId type);
 
  protected:
   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,