|
|
|
@ -205,6 +205,10 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) {
|
|
|
|
|
}
|
|
|
|
|
// process output
|
|
|
|
|
std::vector<size_t> output_indexs = {};
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, kernel_node)) {
|
|
|
|
|
output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrAtomicOutputIndexs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < output_num; ++i) {
|
|
|
|
|
auto param_output = parameters_indexs.at(input_num + workspace_num + i);
|
|
|
|
|
if (param_output == 1) {
|
|
|
|
@ -212,7 +216,10 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_LOG(INFO) << "Atomic clear output index: " << i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!output_indexs.empty()) {
|
|
|
|
|
std::set<size_t> s(output_indexs.begin(), output_indexs.end());
|
|
|
|
|
output_indexs.assign(s.begin(), s.end());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node);
|
|
|
|
|
}
|
|
|
|
|
// process workspace
|
|
|
|
@ -244,11 +251,49 @@ bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
|
|
|
|
|
const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map;
|
|
|
|
|
for (auto &kernel : kernel_graph->execution_order()) {
|
|
|
|
|
auto input_num = AnfAlgo::GetInputTensorNum(kernel);
|
|
|
|
|
if (mindspore::session::AnfRuntimeAlgorithm::IsCommunicationOp(kernel)) {
|
|
|
|
|
for (size_t i = 0; i < input_num; i++) {
|
|
|
|
|
auto input_node = kernel->input(i + 1);
|
|
|
|
|
auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
|
|
|
|
MS_LOG(INFO) << " Add atomic clean for single communication op input, comm:" << kernel->fullname_with_scope()
|
|
|
|
|
<< " input_node: " << kernel_input.first->fullname_with_scope()
|
|
|
|
|
<< " index: " << kernel_input.second;
|
|
|
|
|
auto iter = comm_input_info_map.find(kernel_input.first);
|
|
|
|
|
if (iter != comm_input_info_map.end()) {
|
|
|
|
|
iter->second.push_back(kernel_input.second);
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<size_t> indexes = {kernel_input.second};
|
|
|
|
|
comm_input_info_map[kernel_input.first] = indexes;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// remove duplicate index
|
|
|
|
|
for (auto &info : comm_input_info_map) {
|
|
|
|
|
std::set<size_t> s(info.second.begin(), info.second.end());
|
|
|
|
|
info.second.assign(s.begin(), s.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return comm_input_info_map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
std::vector<CNodePtr> new_nodes;
|
|
|
|
|
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
|
|
|
|
|
for (const auto &anf_node : kernel_graph->execution_order()) {
|
|
|
|
|
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
|
|
|
|
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
|
|
|
|
|
auto indexes = comm_input_info_map[anf_node];
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
|
|
|
|
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
|
|
|
|
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
|
|
|
|
|