From b7e5f956e5290aa7a81e4064ff3f9ae6d4c2ec5d Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Thu, 14 Jan 2021 20:22:10 +0800 Subject: [PATCH] workspace of comm op can be reused Signed-off-by: zhoufeng --- graphengine | 2 +- .../kernel_compiler/hccl/hccl_kernel.cc | 31 ++++++++++++++----- .../kernel_compiler/hccl/hccl_kernel.h | 4 +-- .../ccsrc/runtime/device/kernel_runtime.cc | 1 + .../runtime/hccl_adapter/hccl_adapter.cc | 29 +++++++++++++++-- .../ccsrc/runtime/hccl_adapter/hccl_adapter.h | 17 +++++----- tests/ut/cpp/stub/ge/ge_task_launch_stub.cc | 2 +- 7 files changed, 62 insertions(+), 24 deletions(-) diff --git a/graphengine b/graphengine index f65be61197..40e5c42a12 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b +Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 49f2f68b81..cfc792843a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) { namespace mindspore { namespace kernel { -void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { +void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) { hcclKernelMap_.emplace(name, std::move(fun)); } @@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) { if (op_name_ == kReceive) { auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { - MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_; + MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_; return false; } hccl_data_type_list_.emplace_back(iter->second); @@ -180,9 +180,17 @@ const std::vector &HcclKernel::GetOutputSizeList() const { return output_size_list_; } -const std::vector &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } +const std::vector &HcclKernel::GetWorkspaceSizeList() const { + if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) { + return workspace_size_list_; + } + + workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0])); + return workspace_size_list_; +} -std::vector HcclKernel::GenTask(const std::vector &inputs, const std::vector &, +std::vector HcclKernel::GenTask(const std::vector &inputs, + const std::vector &workspace, const std::vector &outputs, uint32_t stream_id) { std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); if (hccl_type == kReceive) { @@ -221,10 +229,19 @@ std::vector HcclKernel::GenTask(const std::vector &inpu MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; } + void *workspace_addr = nullptr; + if (task.workspace_size != 0) { + if (workspace.empty()) { + MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty"; + } + MS_EXCEPTION_IF_NULL(workspace.at(0)); + workspace_addr = workspace.at(0)->addr; + } + results.emplace_back(std::make_shared( - kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size, - task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type, - group_, NeedDump())); + kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr, + task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, + op_type_, data_type, group_, NeedDump())); } return results; diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h index 9f2780a51c..4930ba6b6d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -68,7 +68,7 @@ class HcclKernelFactory { public: static HcclKernelFactory &Get(); - void Registe(const string &name, HcclKernelCreater &&fun); + void Register(const string &name, HcclKernelCreater &&fun); static std::shared_ptr Get(const string &name); private: @@ -78,7 +78,7 @@ class HcclKernelFactory { class _HcclKernelRegister { public: _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { - HcclKernelFactory::Get().Registe(name, std::move(fun)); + HcclKernelFactory::Get().Register(name, std::move(fun)); } ~_HcclKernelRegister() = default; }; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index abc93b2a46..02b1e0beb4 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { AssignCommunicationNodeInputMem(type, node); AssignCommunicationNodeOutputMem(type, node); + AssignWorkSpaceMem(type, node); } void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { diff --git a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc index f5782de019..b0b5885ca5 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc +++ b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc @@ -99,7 +99,7 @@ bool FinalizeHccl() { if (ops_kernel_info_store != nullptr) { auto ret = ops_kernel_info_store->Finalize(); if (ret != ge::SUCCESS) { - MS_LOG(ERROR) << "Destory info store failed, ret = " << ret; + MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret; return false; } } @@ -107,7 +107,7 @@ bool FinalizeHccl() { if (ops_kernel_builder != nullptr) { auto ret = ops_kernel_builder->Finalize(); if (ret != ge::SUCCESS) { - MS_LOG(ERROR) << "Destory builder failed, ret = " << ret; + MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret; return false; } } @@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vectorDebugString() << " ,dtype is " << datatype; + auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype); + MS_EXCEPTION_IF_NULL(ge_node); + auto op = ge_node->GetOpDesc(); + MS_EXCEPTION_IF_NULL(op); + + MS_LOG(INFO) << "Start to call CalcOpRunningParam"; + ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node); + if (ret != ge::SUCCESS) { + MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret; + return false; + } + + auto workspace_sizes = op->GetWorkspaceBytes(); + if (workspace_sizes.size() != 1) { + MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size(); + } + int64_t workspace_size = workspace_sizes[0]; + MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size; + ge_graph.reset(); + return workspace_size; +} void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } diff --git a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h index cd82fbcc0a..6488a0ffb6 100644 --- a/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h +++ b/mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h @@ -23,21 +23,18 @@ #include "mindspore/core/ir/anf.h" #include "hccl/hccl_types.h" -#define MS_API __attribute__((visibility("default"))) - namespace mindspore::hccl { -struct MS_API HcclTaskInfo { +struct HcclTaskInfo { std::string private_def; int64_t workspace_size; int64_t stream_num; }; -MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); -MS_API bool FinalizeHccl(); -MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector *task_info_lists); -MS_API bool CalcOpRunningParam(const AnfNodePtr &node); -MS_API void *GetHcclOpsKernelInfoStore(); -MS_API std::string GetHcclType(const AnfNodePtr &node); +bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); +bool FinalizeHccl(); +bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector *task_info_lists); +int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype); +void *GetHcclOpsKernelInfoStore(); +std::string GetHcclType(const AnfNodePtr &node); } // namespace mindspore::hccl -#undef MS_API #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H diff --git a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc index 331c23dba1..24fbde1fc9 100644 --- a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc +++ b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc @@ -64,7 +64,7 @@ namespace hccl { bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } bool FinalizeHccl() { return true; } bool GenTask(const AnfNodePtr &, HcclDataType, std::vector *) { return true; } -bool CalcOpRunningParam(const AnfNodePtr &) { return true; } +int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; } void *GetHcclOpsKernelInfoStore() { return nullptr; } std::string GetHcclType(const AnfNodePtr &) { return ""; } } // namespace hccl