|
|
|
@ -98,16 +98,23 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
|
|
|
|
|
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
|
|
|
|
if (hccl_node->size() != 2) {
|
|
|
|
|
MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2";
|
|
|
|
|
return;
|
|
|
|
|
bool has_insert_memcpy = false;
|
|
|
|
|
AnfNodePtr memcpy_async = nullptr;
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
|
|
|
|
|
for (size_t i = 1; i < hccl_node->size(); ++i) {
|
|
|
|
|
auto input = hccl_node->input(i);
|
|
|
|
|
if (NeedInsertMemcpy(graph, input)) {
|
|
|
|
|
memcpy_async = CreateMemcpyAsyncOp(graph, input);
|
|
|
|
|
has_insert_memcpy = true;
|
|
|
|
|
new_inputs.push_back(memcpy_async);
|
|
|
|
|
} else {
|
|
|
|
|
new_inputs.push_back(input);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input = hccl_node->input(1);
|
|
|
|
|
if (NeedInsertMemcpy(graph, input)) {
|
|
|
|
|
auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
|
|
|
|
|
if (has_insert_memcpy) {
|
|
|
|
|
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
|
|
|
|
|
new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async});
|
|
|
|
|
new_hccl_node->set_inputs(new_inputs);
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
|
|
|
|
@ -115,7 +122,9 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
|
|
|
|
|
MS_LOG(DEBUG) << "end replace";
|
|
|
|
|
|
|
|
|
|
// transer hccl op's control to the memcpy_async
|
|
|
|
|
TransferControl(new_hccl_node, memcpy_async, graph);
|
|
|
|
|
if (hccl_node->size() == 2) {
|
|
|
|
|
TransferControl(new_hccl_node, memcpy_async, graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|