|
|
@ -36,15 +36,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace kernel {
|
|
|
|
namespace kernel {
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
constexpr int32_t PROCESS_NUM = 16;
|
|
|
|
constexpr int32_t PROCESS_NUM = 16;
|
|
|
|
constexpr int32_t TIME_OUT = 300;
|
|
|
|
constexpr int32_t TIME_OUT = 300;
|
|
|
|
|
|
|
|
|
|
|
|
bool AkgAscendKernelBuilder::AkgOpParallelBuild(
|
|
|
|
void SetKernelMod(const KernelPackPtr &kernel_pack, const AkgKernelJsonGenerator &json_generator,
|
|
|
|
const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args) {
|
|
|
|
const AnfNodePtr &anf_node) {
|
|
|
|
|
|
|
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(kernel_pack);
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
|
|
|
|
|
|
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> AkgAscendKernelBuilder::GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args) {
|
|
|
|
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
|
|
|
|
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
|
|
|
|
std::vector<std::string> jsons;
|
|
|
|
std::vector<std::string> jsons;
|
|
|
|
std::unordered_set<std::string> kernel_name_set;
|
|
|
|
std::unordered_set<std::string> kernel_name_set;
|
|
|
|
std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> repeat_nodes;
|
|
|
|
|
|
|
|
for (const auto &[json_generator, anf_node] : build_args) {
|
|
|
|
for (const auto &[json_generator, anf_node] : build_args) {
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
auto kernel_name = json_generator.kernel_name();
|
|
|
|
auto kernel_name = json_generator.kernel_name();
|
|
|
@ -53,15 +61,12 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
|
|
|
|
if (cached_kernel_pack != nullptr) {
|
|
|
|
if (cached_kernel_pack != nullptr) {
|
|
|
|
MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
|
|
|
SetKernelMod(cached_kernel_pack, json_generator, anf_node);
|
|
|
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
|
|
|
|
|
|
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
|
|
|
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (kernel_name_set.count(kernel_name) != 0) {
|
|
|
|
if (kernel_name_set.count(kernel_name) != 0) {
|
|
|
|
repeat_nodes.push_back({json_generator, anf_node});
|
|
|
|
repeat_nodes_.push_back({json_generator, anf_node});
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
kernel_name_set.insert(kernel_name);
|
|
|
|
kernel_name_set.insert(kernel_name);
|
|
|
@ -69,7 +74,43 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
|
|
|
|
kernel::SaveJsonInfo(kernel_name, kernel_json);
|
|
|
|
kernel::SaveJsonInfo(kernel_name, kernel_json);
|
|
|
|
jsons.push_back(kernel_json);
|
|
|
|
jsons.push_back(kernel_json);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return jsons;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool AkgAscendKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) {
|
|
|
|
|
|
|
|
for (const auto &[json_generator, anf_node] : build_args) {
|
|
|
|
|
|
|
|
auto kernel_name = json_generator.kernel_name();
|
|
|
|
|
|
|
|
auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
|
|
|
|
|
|
|
|
if (new_kernel_pack == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
SetKernelMod(new_kernel_pack, json_generator, anf_node);
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool AkgAscendKernelBuilder::HandleRepeatNodes() {
|
|
|
|
|
|
|
|
for (const auto &[json_generator, anf_node] : repeat_nodes_) {
|
|
|
|
|
|
|
|
auto kernel_name = json_generator.kernel_name();
|
|
|
|
|
|
|
|
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
|
|
|
|
|
|
|
|
if (cached_kernel_pack == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Use cached kernel failed, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
|
|
|
|
SetKernelMod(cached_kernel_pack, json_generator, anf_node);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool AkgAscendKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args) {
|
|
|
|
|
|
|
|
repeat_nodes_.clear();
|
|
|
|
|
|
|
|
auto jsons = GetNotCachedKernelJsons(build_args);
|
|
|
|
if (jsons.empty()) {
|
|
|
|
if (jsons.empty()) {
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -89,56 +130,35 @@ bool AkgAscendKernelBuilder::AkgOpParallelBuild(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// All unique done here, cache them and set kernel.
|
|
|
|
// All unique done here, cache them and set kernel.
|
|
|
|
for (const auto &[json_generator, anf_node] : build_args) {
|
|
|
|
if (!InsertToCache(build_args)) {
|
|
|
|
auto kernel_name = json_generator.kernel_name();
|
|
|
|
MS_LOG(ERROR) << "Insert cache failed.";
|
|
|
|
auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
|
|
|
|
return false;
|
|
|
|
if (new_kernel_pack == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
|
|
|
|
|
|
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Handle repeated nodes.
|
|
|
|
if (!HandleRepeatNodes()) {
|
|
|
|
for (const auto &[json_generator, anf_node] : repeat_nodes) {
|
|
|
|
MS_LOG(ERROR) << "Handle repeat nodes failed.";
|
|
|
|
auto kernel_name = json_generator.kernel_name();
|
|
|
|
return false;
|
|
|
|
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
|
|
|
|
|
|
|
|
if (cached_kernel_pack == nullptr) return false;
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
|
|
|
|
|
|
|
<< anf_node->fullname_with_scope() << "].";
|
|
|
|
|
|
|
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
|
|
|
|
|
|
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
|
|
|
|
|
|
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
|
|
|
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
|
|
|
std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> json_and_node;
|
|
|
|
std::vector<JsonNodePair> json_and_node;
|
|
|
|
for (const auto &anf_node : anf_nodes) {
|
|
|
|
for (const auto &anf_node : anf_nodes) {
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(anf_node);
|
|
|
|
AkgKernelJsonGenerator akg_kernel_json_generator;
|
|
|
|
AkgKernelJsonGenerator akg_kernel_json_generator;
|
|
|
|
KernelPackPtr kernel_pack = nullptr;
|
|
|
|
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
auto cnode = anf_node->cast<CNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
if (AnfAlgo::IsGraphKernel(cnode)) {
|
|
|
|
if (AnfAlgo::IsGraphKernel(cnode)) {
|
|
|
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
|
|
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
if (mng == nullptr) {
|
|
|
|
if (mng == nullptr) {
|
|
|
|
mng = Manage(func_graph, true);
|
|
|
|
mng = Manage(func_graph, true);
|
|
|
|
func_graph->set_manager(mng);
|
|
|
|
func_graph->set_manager(mng);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
std::vector<AnfNodePtr> node_list, input_list, output_list;
|
|
|
|
std::vector<AnfNodePtr> node_list;
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> input_list;
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> output_list;
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]";
|
|
|
|
MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]";
|
|
|
|
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
|
|
|
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
|
|
|
if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
|
|
|
|
if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
|
|
|
@ -146,7 +166,7 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
if (!akg_kernel_json_generator.CollectJson(anf_node)) {
|
|
|
|
if (!akg_kernel_json_generator.CollectJson(anf_node)) {
|
|
|
|
MS_EXCEPTION(UnknownError) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
|
|
|
|
MS_EXCEPTION(UnknownError) << "Akg build failed basic op[" << anf_node->fullname_with_scope() << "].";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
json_and_node.push_back({akg_kernel_json_generator, anf_node});
|
|
|
|
json_and_node.push_back({akg_kernel_json_generator, anf_node});
|
|
|
|