|
|
|
@ -40,6 +40,8 @@ namespace device {
|
|
|
|
|
namespace ascend {
|
|
|
|
|
using mindspore::kernel::tbe::TbeUtils;
|
|
|
|
|
using std::make_shared;
|
|
|
|
|
constexpr size_t kMaxAttrMemListSize = 192;
|
|
|
|
|
|
|
|
|
|
static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
|
|
|
|
|
kernel::KernelModPtr kernel_mod_ptr = nullptr;
|
|
|
|
|
KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
|
|
|
|
@ -159,6 +161,30 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr
|
|
|
|
|
new_nodes->push_back(clear_zero);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
|
|
|
|
|
const mindspore::CNodePtr &stream_node,
|
|
|
|
|
const std::vector<AnfNodePtr> &fusion_clear_inputs,
|
|
|
|
|
const std::vector<size_t> &clean_size_list,
|
|
|
|
|
std::vector<mindspore::CNodePtr> *new_nodes) {
|
|
|
|
|
auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
|
|
|
|
auto new_value_node = NewValueNode(clear_zero_prim);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_value_node);
|
|
|
|
|
std::vector<AnfNodePtr> inputs = {new_value_node};
|
|
|
|
|
inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end());
|
|
|
|
|
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(clear_zero);
|
|
|
|
|
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract);
|
|
|
|
|
clear_zero->set_abstract(abstract);
|
|
|
|
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
builder->SetKernelType(KernelType::TBE_KERNEL);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero);
|
|
|
|
|
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get());
|
|
|
|
|
new_nodes->insert(new_nodes->begin(), clear_zero);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool IsAtomicNode(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
|
|
|
|
@ -264,23 +290,23 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(
|
|
|
|
|
return comm_input_info_map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) {
|
|
|
|
|
std::vector<CNodePtr> new_nodes;
|
|
|
|
|
std::vector<size_t> clean_size_list;
|
|
|
|
|
std::vector<AnfNodePtr> fusion_clear_inputs;
|
|
|
|
|
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
|
|
|
|
|
for (const auto &anf_node : kernel_graph->execution_order()) {
|
|
|
|
|
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
|
|
|
|
bool is_comm_input = false;
|
|
|
|
|
// set communication input output index attr
|
|
|
|
|
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
|
|
|
|
|
auto indexes = comm_input_info_map[anf_node];
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
|
|
|
|
|
is_comm_input = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_comm_input) {
|
|
|
|
|
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
|
|
|
|
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
|
|
|
|
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
|
|
|
|
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
|
|
|
|
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
|
|
|
|
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
|
|
|
|
auto new_value_node = NewValueNode(clear_zero_prim);
|
|
|
|
@ -299,15 +325,85 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
// set the distinction label of clear same with anf
|
|
|
|
|
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
|
|
|
|
|
new_nodes.push_back(clear_zero);
|
|
|
|
|
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
|
|
|
|
|
if (IsAtomicNode(anf_node)) {
|
|
|
|
|
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
|
|
|
|
} else if (is_comm_input ||
|
|
|
|
|
(AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) {
|
|
|
|
|
auto clean_sizes = CalCleanZerosSize(anf_node);
|
|
|
|
|
if (!clean_sizes.empty()) {
|
|
|
|
|
auto clean_total_num = clean_size_list.size() + clean_sizes.size();
|
|
|
|
|
if (clean_total_num >= kMaxAttrMemListSize) {
|
|
|
|
|
// create clean node
|
|
|
|
|
auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front();
|
|
|
|
|
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
|
|
|
|
clean_size_list.clear();
|
|
|
|
|
fusion_clear_inputs.clear();
|
|
|
|
|
}
|
|
|
|
|
clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end());
|
|
|
|
|
fusion_clear_inputs.emplace_back(anf_node);
|
|
|
|
|
MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size()
|
|
|
|
|
<< ", clean_size_list: " << clean_size_list.size();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
new_nodes.push_back(anf_node);
|
|
|
|
|
new_nodes.emplace_back(anf_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) {
|
|
|
|
|
// create clean node
|
|
|
|
|
auto stream_node = new_nodes.front();
|
|
|
|
|
AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes);
|
|
|
|
|
}
|
|
|
|
|
kernel_graph->set_execution_order(new_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
|
|
|
|
|
bool is_dynamic_graph = kernel_graph->is_dynamic_shape();
|
|
|
|
|
if (!is_dynamic_graph && enable_fusion_clear) {
|
|
|
|
|
TbeClearZeroNodeFusion(kernel_graph);
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<CNodePtr> new_nodes;
|
|
|
|
|
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
|
|
|
|
|
for (const auto &anf_node : kernel_graph->execution_order()) {
|
|
|
|
|
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
|
|
|
|
|
bool is_comm_input = false;
|
|
|
|
|
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
|
|
|
|
|
auto indexes = comm_input_info_map[anf_node];
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
|
|
|
|
|
is_comm_input = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_comm_input) {
|
|
|
|
|
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
|
|
|
|
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
|
|
|
|
|
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
|
|
|
|
|
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(clear_zero_prim);
|
|
|
|
|
auto new_value_node = NewValueNode(clear_zero_prim);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_value_node);
|
|
|
|
|
std::vector<AnfNodePtr> inputs = {new_value_node};
|
|
|
|
|
inputs.push_back(anf_node);
|
|
|
|
|
CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(clear_zero);
|
|
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
clear_zero->set_kernel_info(kernel_info);
|
|
|
|
|
AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract);
|
|
|
|
|
AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
|
|
|
|
|
SelectKernelInfo(clear_zero);
|
|
|
|
|
// set the distinction label of clear same with anf
|
|
|
|
|
AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
|
|
|
|
|
new_nodes.push_back(clear_zero);
|
|
|
|
|
} else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
|
|
|
|
|
if (IsAtomicNode(anf_node)) {
|
|
|
|
|
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
new_nodes.push_back(anf_node);
|
|
|
|
|
}
|
|
|
|
|
kernel_graph->set_execution_order(new_nodes);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace ascend
|
|
|
|
|
} // namespace device
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|