|
|
|
@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
|
|
|
|
|
void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) {
|
|
|
|
|
hcclKernelMap_.emplace(name, std::move(fun));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
|
|
|
|
|
if (op_name_ == kReceive) {
|
|
|
|
|
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
|
|
|
|
|
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
|
|
|
|
|
MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_;
|
|
|
|
|
MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
hccl_data_type_list_.emplace_back(iter->second);
|
|
|
|
@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
|
|
|
|
|
return output_size_list_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
|
|
|
|
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
|
|
|
|
|
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) {
|
|
|
|
|
return workspace_size_list_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
|
|
|
|
|
return workspace_size_list_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
|
|
|
|
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
const std::vector<AddressPtr> &workspace,
|
|
|
|
|
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
|
|
|
|
|
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
|
|
|
|
|
if (hccl_type == kReceive) {
|
|
|
|
@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|
|
|
|
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void *workspace_addr = nullptr;
|
|
|
|
|
if (task.workspace_size != 0) {
|
|
|
|
|
if (workspace.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(workspace.at(0));
|
|
|
|
|
workspace_addr = workspace.at(0)->addr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
results.emplace_back(std::make_shared<HcclTaskInfo>(
|
|
|
|
|
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size,
|
|
|
|
|
task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type,
|
|
|
|
|
group_, NeedDump()));
|
|
|
|
|
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
|
|
|
|
|
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
|
|
|
|
|
op_type_, data_type, group_, NeedDump()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return results;
|
|
|
|
|