Workspace memory of communication ops can be reused

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/13399/head
zhoufeng 4 years ago
parent d831aba239
commit b7e5f956e5

@ -1 +1 @@
Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f

@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) {
hcclKernelMap_.emplace(name, std::move(fun)); hcclKernelMap_.emplace(name, std::move(fun));
} }
@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
if (op_name_ == kReceive) { if (op_name_ == kReceive) {
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_; MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_;
return false; return false;
} }
hccl_data_type_list_.emplace_back(iter->second); hccl_data_type_list_.emplace_back(iter->second);
@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
return output_size_list_; return output_size_list_;
} }
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) {
return workspace_size_list_;
}
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
return workspace_size_list_;
}
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) { const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
if (hccl_type == kReceive) { if (hccl_type == kReceive) {
@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
} }
void *workspace_addr = nullptr;
if (task.workspace_size != 0) {
if (workspace.empty()) {
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
}
MS_EXCEPTION_IF_NULL(workspace.at(0));
workspace_addr = workspace.at(0)->addr;
}
results.emplace_back(std::make_shared<HcclTaskInfo>( results.emplace_back(std::make_shared<HcclTaskInfo>(
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size, kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type, task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
group_, NeedDump())); op_type_, data_type, group_, NeedDump()));
} }
return results; return results;

@ -68,7 +68,7 @@ class HcclKernelFactory {
public: public:
static HcclKernelFactory &Get(); static HcclKernelFactory &Get();
void Registe(const string &name, HcclKernelCreater &&fun); void Register(const string &name, HcclKernelCreater &&fun);
static std::shared_ptr<HcclKernel> Get(const string &name); static std::shared_ptr<HcclKernel> Get(const string &name);
private: private:
@ -78,7 +78,7 @@ class HcclKernelFactory {
class _HcclKernelRegister { class _HcclKernelRegister {
public: public:
_HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
HcclKernelFactory::Get().Registe(name, std::move(fun)); HcclKernelFactory::Get().Register(name, std::move(fun));
} }
~_HcclKernelRegister() = default; ~_HcclKernelRegister() = default;
}; };

@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
AssignCommunicationNodeInputMem(type, node); AssignCommunicationNodeInputMem(type, node);
AssignCommunicationNodeOutputMem(type, node); AssignCommunicationNodeOutputMem(type, node);
AssignWorkSpaceMem(type, node);
} }
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {

@ -99,7 +99,7 @@ bool FinalizeHccl() {
if (ops_kernel_info_store != nullptr) { if (ops_kernel_info_store != nullptr) {
auto ret = ops_kernel_info_store->Finalize(); auto ret = ops_kernel_info_store->Finalize();
if (ret != ge::SUCCESS) { if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory info store failed, ret = " << ret; MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
return false; return false;
} }
} }
@ -107,7 +107,7 @@ bool FinalizeHccl() {
if (ops_kernel_builder != nullptr) { if (ops_kernel_builder != nullptr) {
auto ret = ops_kernel_builder->Finalize(); auto ret = ops_kernel_builder->Finalize();
if (ret != ge::SUCCESS) { if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory builder failed, ret = " << ret; MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
return false; return false;
} }
} }
@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
return true; return true;
} }
bool CalcOpRunningParam(const AnfNodePtr &node) { return true; } int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype;
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
MS_EXCEPTION_IF_NULL(ge_node);
auto op = ge_node->GetOpDesc();
MS_EXCEPTION_IF_NULL(op);
MS_LOG(INFO) << "Start to call CalcOpRunningParam";
ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
return false;
}
auto workspace_sizes = op->GetWorkspaceBytes();
if (workspace_sizes.size() != 1) {
MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size();
}
int64_t workspace_size = workspace_sizes[0];
MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size;
ge_graph.reset();
return workspace_size;
}
void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); }

@ -23,21 +23,18 @@
#include "mindspore/core/ir/anf.h" #include "mindspore/core/ir/anf.h"
#include "hccl/hccl_types.h" #include "hccl/hccl_types.h"
#define MS_API __attribute__((visibility("default")))
namespace mindspore::hccl { namespace mindspore::hccl {
struct MS_API HcclTaskInfo { struct HcclTaskInfo {
std::string private_def; std::string private_def;
int64_t workspace_size; int64_t workspace_size;
int64_t stream_num; int64_t stream_num;
}; };
MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file); bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
MS_API bool FinalizeHccl(); bool FinalizeHccl();
MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists); bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
MS_API bool CalcOpRunningParam(const AnfNodePtr &node); int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype);
MS_API void *GetHcclOpsKernelInfoStore(); void *GetHcclOpsKernelInfoStore();
MS_API std::string GetHcclType(const AnfNodePtr &node); std::string GetHcclType(const AnfNodePtr &node);
} // namespace mindspore::hccl } // namespace mindspore::hccl
#undef MS_API
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H

@ -64,7 +64,7 @@ namespace hccl {
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
bool FinalizeHccl() { return true; } bool FinalizeHccl() { return true; }
bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; } bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; }
bool CalcOpRunningParam(const AnfNodePtr &) { return true; } int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; }
void *GetHcclOpsKernelInfoStore() { return nullptr; } void *GetHcclOpsKernelInfoStore() { return nullptr; }
std::string GetHcclType(const AnfNodePtr &) { return ""; } std::string GetHcclType(const AnfNodePtr &) { return ""; }
} // namespace hccl } // namespace hccl

Loading…
Cancel
Save