From d6308151e0afa90834c9bdea72e018ff47c2a5fc Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Mon, 18 Jan 2021 15:10:16 +0800 Subject: [PATCH] reuse workspace memory of hccl op Signed-off-by: zhoufeng --- ge/ge_runtime/task/hccl_task.cc | 16 ++-------------- inc/framework/ge_runtime/task_info.h | 5 ++++- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc index 06165053..c3040697 100644 --- a/ge/ge_runtime/task/hccl_task.cc +++ b/ge/ge_runtime/task/hccl_task.cc @@ -52,15 +52,7 @@ HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptrworkspace_size() > 0) { - rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } + workspace_mem_ = task_info_->workspace_addr(); } GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl."); diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index f59c6454..4530bff7 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -271,13 +271,14 @@ class FusionEndTaskInfo : public TaskInfo { class HcclTaskInfo : public TaskInfo { public: HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, - void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num, + void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, const std::vector &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag) : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), hccl_type_(hccl_type), input_data_addr_(input_data_addr), output_data_addr_(output_data_addr), + workspace_addr_(workspace_addr), workspace_size_(workspace_size), hccl_stream_num_(hccl_stream_num), private_def_(private_def), @@ -292,6 +293,7 @@ class HcclTaskInfo : public TaskInfo { const std::string &hccl_type() const { return hccl_type_; } void *input_data_addr() const { return input_data_addr_; } void *output_data_addr() const { return output_data_addr_; } + void *workspace_addr() const { return workspace_addr_; } int64_t workspace_size() const { return workspace_size_; } int64_t hccl_stream_num() const { return hccl_stream_num_; } const std::vector &private_def() const { return private_def_; } @@ -306,6 +308,7 @@ class HcclTaskInfo : public TaskInfo { std::string hccl_type_; void *input_data_addr_; void *output_data_addr_; + void *workspace_addr_; int64_t workspace_size_; int64_t hccl_stream_num_; std::vector private_def_;