workspace of comm op can be reused

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/13399/head
zhoufeng 4 years ago
parent d831aba239
commit b7e5f956e5

@ -1 +1 @@
Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b
Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f

@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
namespace mindspore {
namespace kernel {
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) {
hcclKernelMap_.emplace(name, std::move(fun));
}
@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
if (op_name_ == kReceive) {
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_;
MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_;
return false;
}
hccl_data_type_list_.emplace_back(iter->second);
@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
return output_size_list_;
}
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
// Returns the workspace size list of this kernel, filling it on first use.
// The list is only populated when it is still empty and at least one hccl data
// type is known; the first data type is the one passed to hccl::CalcWorkspaceSize.
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
  const bool need_calc = workspace_size_list_.empty() && !hccl_data_type_list_.empty();
  if (need_calc) {
    workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
  }
  return workspace_size_list_;
}
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
if (hccl_type == kReceive) {
@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
}
void *workspace_addr = nullptr;
if (task.workspace_size != 0) {
if (workspace.empty()) {
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
}
MS_EXCEPTION_IF_NULL(workspace.at(0));
workspace_addr = workspace.at(0)->addr;
}
results.emplace_back(std::make_shared<HcclTaskInfo>(
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size,
task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type,
group_, NeedDump()));
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
op_type_, data_type, group_, NeedDump()));
}
return results;

@ -68,7 +68,7 @@ class HcclKernelFactory {
public:
static HcclKernelFactory &Get();
void Registe(const string &name, HcclKernelCreater &&fun);
void Register(const string &name, HcclKernelCreater &&fun);
static std::shared_ptr<HcclKernel> Get(const string &name);
private:
@ -78,7 +78,7 @@ class HcclKernelFactory {
class _HcclKernelRegister {
public:
_HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
HcclKernelFactory::Get().Registe(name, std::move(fun));
HcclKernelFactory::Get().Register(name, std::move(fun));
}
~_HcclKernelRegister() = default;
};

@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
// Assigns memory for a communication node by delegating, in order, to the
// input, output and workspace assignment helpers, each called with the same
// memory type and node.
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
AssignCommunicationNodeInputMem(type, node);
AssignCommunicationNodeOutputMem(type, node);
// NOTE(review): presumably reserves the workspace buffer(s) reported by the
// node's kernel (GetWorkspaceSizeList) — confirm against AssignWorkSpaceMem.
AssignWorkSpaceMem(type, node);
}
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {

@ -99,7 +99,7 @@ bool FinalizeHccl() {
if (ops_kernel_info_store != nullptr) {
auto ret = ops_kernel_info_store->Finalize();
if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory info store failed, ret = " << ret;
MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
return false;
}
}
@ -107,7 +107,7 @@ bool FinalizeHccl() {
if (ops_kernel_builder != nullptr) {
auto ret = ops_kernel_builder->Finalize();
if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory builder failed, ret = " << ret;
MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
return false;
}
}
@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
return true;
}
bool CalcOpRunningParam(const AnfNodePtr &node) { return true; }
// Asks the HCCL ops kernel builder how much workspace a communication node needs.
//
// A stub GE node is generated for `node`, CalcOpRunningParam is run on it, and the
// single entry of the op desc's workspace-bytes list is returned.
//
// Params:
//   node     - hccl node to evaluate; must not be null.
//   datatype - data type forwarded to GenerateStubGeNode.
// Returns:
//   The only element of op->GetWorkspaceBytes().
// Raises (via MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION)):
//   when the kernel builder, `node`, the stub GE node or its op desc is null,
//   when CalcOpRunningParam does not return ge::SUCCESS, or when the workspace
//   list does not contain exactly one entry.
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
  MS_EXCEPTION_IF_NULL(ops_kernel_builder);
  // `node` is dereferenced by the log statement below, so validate it first.
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype;
  auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
  MS_EXCEPTION_IF_NULL(ge_node);
  auto op = ge_node->GetOpDesc();
  MS_EXCEPTION_IF_NULL(op);

  MS_LOG(INFO) << "Start to call CalcOpRunningParam";
  ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
  if (ret != ge::SUCCESS) {
    // This used to be `return false;`, which an int64_t return type turns into a
    // silent "workspace size 0". Fail loudly instead, consistent with the
    // unexpected-size check below.
    MS_LOG(EXCEPTION) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
  }

  auto workspace_sizes = op->GetWorkspaceBytes();
  if (workspace_sizes.size() != 1) {
    MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size();
  }
  int64_t workspace_size = workspace_sizes[0];
  MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size;
  // Drop the stub graph explicitly once the size has been read out.
  ge_graph.reset();
  return workspace_size;
}
// Exposes the ops kernel info store as an untyped, non-owning pointer
// (whatever ops_kernel_info_store.get() currently yields).
void *GetHcclOpsKernelInfoStore() {
  void *info_store = ops_kernel_info_store.get();
  return info_store;
}

@ -23,21 +23,18 @@
#include "mindspore/core/ir/anf.h"
#include "hccl/hccl_types.h"
#define MS_API __attribute__((visibility("default")))
namespace mindspore::hccl {
struct MS_API HcclTaskInfo {
struct HcclTaskInfo {
std::string private_def;
int64_t workspace_size;
int64_t stream_num;
};
MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
MS_API bool FinalizeHccl();
MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
MS_API bool CalcOpRunningParam(const AnfNodePtr &node);
MS_API void *GetHcclOpsKernelInfoStore();
MS_API std::string GetHcclType(const AnfNodePtr &node);
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
bool FinalizeHccl();
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype);
void *GetHcclOpsKernelInfoStore();
std::string GetHcclType(const AnfNodePtr &node);
} // namespace mindspore::hccl
#undef MS_API
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H

@ -64,7 +64,7 @@ namespace hccl {
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
bool FinalizeHccl() { return true; }
bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; }
bool CalcOpRunningParam(const AnfNodePtr &) { return true; }
// Stub implementation: ignores both arguments and always reports a workspace size of 0.
int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; }
void *GetHcclOpsKernelInfoStore() { return nullptr; }
std::string GetHcclType(const AnfNodePtr &) { return ""; }
} // namespace hccl

Loading…
Cancel
Save