|
|
|
@ -25,6 +25,14 @@
|
|
|
|
|
namespace ge {
|
|
|
|
|
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
|
|
|
|
|
GE_CHECK_NOTNULL(graph);
|
|
|
|
|
for (const auto &node : graph->GetAllNodes()) {
|
|
|
|
|
if (node->GetType() == STREAMSWITCH) {
|
|
|
|
|
auto sub_graph = node->GetOwnerComputeGraph();
|
|
|
|
|
if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) {
|
|
|
|
|
GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (graph->GetGraphUnknownFlag()) {
|
|
|
|
|
GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str());
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -63,6 +71,28 @@ Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Puts a memcpyasync node in front of every connected data input of `node`.
///
/// Sets known_sub_graph_ to true (this function does not reset it afterwards),
/// then walks the input data anchors of `node`. Anchors without a peer output
/// anchor are skipped. For each remaining anchor a new node is obtained from
/// CreateMemcpyAddrAsyncNode, using the graph that owns `node`, and wired in
/// with InsertMemcpyAddrAsyncNode.
///
/// @param node Node whose data inputs are processed; guarded by GE_CHECK_NOTNULL.
/// @return SUCCESS once all connected inputs are handled, INTERNAL_ERROR if a
///         node could not be created, or the status of the failed insertion.
Status MemcpyAddrAsyncPass::AddMemcpyAsyncNode(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GELOGI("Start add memcpyasync node in front of node %s", node->GetName().c_str());

  // CreateMemcpyAddrAsyncNode reads this flag to decide which op type to build.
  known_sub_graph_ = true;

  auto owner_graph = node->GetOwnerComputeGraph();
  for (auto &input_anchor : node->GetAllInDataAnchors()) {
    OutDataAnchorPtr src_anchor = input_anchor->GetPeerOutAnchor();
    // An input with no producer has nothing to copy from.
    GE_IF_BOOL_EXEC(src_anchor == nullptr, continue);

    auto copy_node = CreateMemcpyAddrAsyncNode(owner_graph, src_anchor, node);
    if (copy_node == nullptr) {
      GELOGE(INTERNAL_ERROR, "Create memcpyasync node failed.");
      return INTERNAL_ERROR;
    }

    Status insert_status = InsertMemcpyAddrAsyncNode(src_anchor, input_anchor, copy_node);
    if (insert_status != SUCCESS) {
      GELOGE(insert_status, "Insert memcpyasync node failed.");
      return insert_status;
    }
  }

  return SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) {
|
|
|
|
|
GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str());
|
|
|
|
|
for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
|
|
|
|
@ -208,9 +238,15 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr
|
|
|
|
|
static uint32_t new_node_index = 0;
|
|
|
|
|
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
|
|
|
|
|
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid.");
|
|
|
|
|
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++);
|
|
|
|
|
|
|
|
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC);
|
|
|
|
|
OpDescPtr op_desc = nullptr;
|
|
|
|
|
if (known_sub_graph_) { // insert memcpyasync node when known sub graph
|
|
|
|
|
string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(new_node_index++);
|
|
|
|
|
op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC);
|
|
|
|
|
} else {
|
|
|
|
|
string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++);
|
|
|
|
|
op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC);
|
|
|
|
|
}
|
|
|
|
|
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
|
|
|
|
|
|
|
|
|
|
if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) {
|
|
|
|
|