diff --git a/ge/graph/manager/rdma_pool_allocator.cc b/ge/graph/manager/rdma_pool_allocator.cc index 03e01bd2..93d1fd1d 100644 --- a/ge/graph/manager/rdma_pool_allocator.cc +++ b/ge/graph/manager/rdma_pool_allocator.cc @@ -202,7 +202,7 @@ Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); return INTERNAL_ERROR; } - base_addr = reinterpret_cast(reinterpret_cast(rdma_base_addr_)); + base_addr = static_cast(reinterpret_cast(rdma_base_addr_)); mem_size = rdma_mem_size_; return SUCCESS; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index f47a02fd..413ff54c 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -701,6 +701,9 @@ Status HybridModelBuilder::LoadGraph() { GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), "[%s] Failed to identify ref outputs.", parent_node_item->NodeName().c_str()); + GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), + "[%s] Failed to identify same outputs.", + parent_node_item->NodeName().c_str()); // if parent is function control op. need add a virtual partitioned call if (parent_node_item->IsControlOp()) { @@ -1162,6 +1165,46 @@ Status HybridModelBuilder::InitRuntimeParams() { return SUCCESS; } +Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { + GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); + auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); + GE_CHECK_NOTNULL(subgraph); + auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); + if (net_output_node == nullptr) { + GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); + return SUCCESS; + } + + auto net_output_desc = net_output_node->GetOpDesc(); + GE_CHECK_NOTNULL(net_output_desc); + + std::map connected_inputs; + for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { + auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (out_data_anchor == nullptr) { + continue; + } + auto src_node = out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); + auto it = connected_inputs.find(input_key); + if (it == connected_inputs.end()) { + connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); + } else { + GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), + in_data_anchor->GetIdx(), + it->second, + src_node->GetName().c_str(), + out_data_anchor->GetIdx()); + node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); + } + } + return SUCCESS; +} + Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index d78d622b..86e5ec01 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -59,6 +59,7 @@ class HybridModelBuilder { Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); Status LoadTasks(); Status IdentifyVariableOutputs(NodeItem &node_item); + Status IdentifySameInputs(NodeItem &node_item); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index c10cf13e..0136425a 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -83,6 +83,7 @@ struct NodeItem { const NodeExecutor *node_executor = nullptr; std::map ref_outputs; std::map reuse_inputs; + std::map reuse_outputs; std::vector is_input_shape_static; bool is_output_shape_static = true; diff --git a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc index 0d6f52e8..723cb7bc 100644 --- a/ge/hybrid/node_executor/hccl/hccl_node_executor.cc +++ b/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -189,13 +189,20 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector(reinterpret_cast(tv->MutableData())); - addr_infos.resize(dims.front()); - for (auto idx = 0; idx < dims.front(); ++idx) { + auto row_num = dims.front(); + addr_infos.resize(row_num); + auto device_len = tv->GetSize() / row_num; + if (device_len <= 0 || device_len > data[kVarTableIdxLen]) { + GELOGE(FAILED, "Local embedding length is out of range."); + return FAILED; + } + + for (auto idx = 0; idx < row_num; ++idx) { FMK_INT64_MULCHECK(idx, kVarTableRowCnt); auto line_idx = idx * kVarTableRowCnt; addr_infos[idx] = {static_cast(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr, - data[line_idx + kVarTableIdxLen]}; - local_addr += data[line_idx + kVarTableIdxLen]; + device_len}; + local_addr += device_len; } return SUCCESS; diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index 29fc777b..7bdc587b 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -221,16 +221,22 @@ Status TaskContext::AllocateOutput(int index, GE_CHECK_NOTNULL(ref_tensor); outputs_start_[index] = *ref_tensor; } else { - auto reuse_input = node_item_->reuse_inputs.find(index); - if (reuse_input != node_item_->reuse_inputs.end()) { - GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); - outputs_start_[index] = inputs_start_[reuse_input->second]; + auto reuse_output_it = node_item_->reuse_outputs.find(index); + if (reuse_output_it != node_item_->reuse_outputs.end()) { + GELOGD("[%s] reuse output [%d] with output [%d]", GetNodeName(), index, reuse_output_it->second); + outputs_start_[index] = outputs_start_[reuse_output_it->second]; } else { - GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); - GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", - node_item_->NodeName().c_str(), - index, - outputs_start_[index].GetSize()); + auto reuse_input = node_item_->reuse_inputs.find(index); + if (reuse_input != node_item_->reuse_inputs.end()) { + GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); + outputs_start_[index] = inputs_start_[reuse_input->second]; + } else { + GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); + GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", + node_item_->NodeName().c_str(), + index, + outputs_start_[index].GetSize()); + } } }