@@ -15,23 +15,25 @@
  */
 
 #include "hybrid/node_executor/hccl/hccl_node_executor.h"
 #include "common/ge/ge_util.h"
 #include "common/ge/plugin_manager.h"
 #include "common/math/math_util.h"
 #include "framework/common/debug/ge_log.h"
 #include "graph/attr_value.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/util/hcom_util.h"
 #include "graph/runtime_inference_context.h"
 #include "hccl/hcom.h"
+#include "graph/utils/type_utils.h"
 #include "hybrid/executor/hybrid_execution_context.h"
 
+namespace ge {
 namespace {
-const size_t kVarTableDims = 2;
-const size_t kVarTableRowCnt = 3;
-const size_t kVarTableIdxAddr = 1;
-const size_t kVarTableIdxLen = 2;
+constexpr size_t kVarTableDims = 2;
+constexpr size_t kVarTableRowCnt = 3;
+constexpr size_t kVarTableIdxAddr = 1;
+constexpr size_t kVarTableIdxLen = 2;
+const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
+const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
+const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
 } // namespace
-namespace ge {
 namespace hybrid {
 
 REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, HcclNodeExecutor);
@@ -142,11 +144,22 @@ Status RdmaNodeTask::Init(TaskContext &context) {
   GE_CHECK_NOTNULL(peer_node->GetOpDesc());
 
   remote_index_ = {peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx()};
-  if (node_item.node->GetType() == HCOMREMOTEREAD) {
+  if (kRdmaReadTypes.count(node_item.node->GetType()) > 0) {
     local_index_ = 0;
   } else {
     local_index_ = op_desc->GetInputIndexByName("local");
   }
+  int32_t offset_idx = node_item.op_desc->GetInputIndexByName("local_offset");
+  if ((offset_idx != -1) && (node_item.op_desc->GetInputDescPtr(offset_idx) != nullptr)) {
+    skip_flag_ = true;
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx));
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor());
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode());
+    GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc());
+    offset_index_ = {
+        node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc()->GetId(),
+        node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetIdx() };
+  }
   return SUCCESS;
 }
 
@@ -158,8 +171,13 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
   GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
   auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
   if (data == nullptr) {
-    GELOGE(FAILED, "Tensor data is nullptr.");
-    return FAILED;
+    if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+      GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
+      return SUCCESS;
+    } else {
+      GELOGE(FAILED, "Tensor data is nullptr.");
+      return FAILED;
+    }
   }
   auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
   if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
@@ -183,30 +201,63 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
       auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
       GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
     }
+  } else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
+    AllocationAttr attr;
+    attr.SetMemType(RDMA_HBM);
+    GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
   }
 
   TensorValue *tv;
-  if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
-    tv = context.MutableOutput(0);
+  if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
+    tv = context.MutableOutput(local_index_);
   } else {
     tv = context.MutableInput(local_index_);
   }
   GE_CHECK_NOTNULL(tv);
-  auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
   auto row_num = dims.front();
   addr_infos.resize(row_num);
-  auto device_len = tv->GetSize() / row_num;
-  if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
-    GELOGE(FAILED, "Local embedding length is out of range.");
-    return FAILED;
-  }
+  if (skip_flag_) {
+    int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
+    GE_CHECK_NOTNULL(context.GetNodeItem().op_desc->GetInputDescPtr(offset_idx));
+    auto data_type = context.GetNodeItem().op_desc->GetInputDesc(offset_idx).GetDataType();
+
+    Tensor offset_tensor;
+    GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor))
+    if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) {
+      GELOGE(PARAM_INVALID, "num of offset and remote addr mismatch, offset size=%zu, remote_addr size=%lld, dtype=%s",
+             offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str());
+      return PARAM_INVALID;
+    }
 
-  for (auto idx = 0; idx < row_num; ++idx) {
-    FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
-    auto line_idx = idx * kVarTableRowCnt;
-    addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
-                       device_len};
-    local_addr += device_len;
+    auto addr_offset = reinterpret_cast<uint64_t *>(offset_tensor.GetData());
+    GE_CHECK_NOTNULL(addr_offset);
+    auto base_addr = reinterpret_cast<float *>(tv->MutableData());
+    GE_CHECK_NOTNULL(base_addr);
+
+    for (auto idx = 0; idx < row_num; idx++) {
+      FMK_INT64_MULCHECK(idx, kVarTableRowCnt)
+      auto line_idx = idx * kVarTableRowCnt;
+      addr_infos[idx] = { static_cast<uint32_t>(data[line_idx]),
+                          data[line_idx + kVarTableIdxAddr],
+                          reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(base_addr + addr_offset[idx])),
+                          data[line_idx + kVarTableIdxLen] };
+    }
+  } else {
+    auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
+    auto device_len = tv->GetSize() / row_num;
+    if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
+      GELOGE(FAILED, "Local embedding length is out of range, expect %lld, but %lld exactly.",
+             data[kVarTableIdxLen], device_len);
+      return FAILED;
+    }
+
+    for (auto idx = 0; idx < row_num; ++idx) {
+      FMK_INT64_MULCHECK(idx, kVarTableRowCnt)
+      auto line_idx = idx * kVarTableRowCnt;
+      addr_infos[idx] = { static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
+                          device_len };
+      local_addr += device_len;
+    }
   }
 
   return SUCCESS;
@@ -226,6 +277,10 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
   }
   vector<HcomRemoteAccessAddrInfo> addr_infos;
   GE_CHK_STATUS_RET(ExtractTensor(context, addr_infos));
+  if (addr_infos.empty()) {
+    done_callback();
+    return SUCCESS;
+  }
 
   auto callback = [this](HcclResult status) {
     if (status != HCCL_SUCCESS) {
@@ -235,6 +290,11 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
     this->cond_.notify_all();
     GELOGI("rdma callback success.");
   };
+
+  std::string executor_type = context.GetNodeItem().NodeType();
+  if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+    executor_type = context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD ? HCOMREMOTEREAD : HCOMREMOTEWRITE;
+  }
   HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback);
   if (hccl_ret != HCCL_SUCCESS) {
     GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", hccl_ret);
@@ -262,7 +322,7 @@ Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const
 
   GE_CHK_STATUS_RET(task.Init(context), "hccl node load hccl so failed.");
   // allocate output mem, output mem or remote read will be calculated when node execute.
-  if (context.GetNodeItem().NodeType() != HCOMREMOTEREAD) {
+  if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) == 0) {
     GE_CHK_STATUS_RET(context.AllocateOutputs(), "hccl node task allocate output failed.");
   }
 
|
Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
|
|
|
|
|
GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str());
|
|
|
|
|
GE_CHECK_NOTNULL(node);
|
|
|
|
|
if (node->GetType() == HCOMREMOTEREAD || node->GetType() == HCOMREMOTEWRITE) {
|
|
|
|
|
if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) {
|
|
|
|
|
task = MakeShared<RdmaNodeTask>();
|
|
|
|
|
} else {
|
|
|
|
|
task = MakeShared<HcclNodeTask>();
|
|
|
|
|