Init GraphKernel.

- It provides a unified style to express graph and kernel for user.
- It provides a unified IR to represent graph and kernel for developer.
- It breaks the boundary between graph and kernel.
- It provides more opportunities to do compile optimization.
pull/2160/head
gong chen 5 years ago committed by Xian Weizhao
parent 01216a9a57
commit a6dfa281ea

3
.gitmodules vendored

@ -13,3 +13,6 @@
[submodule "graphengine"]
path = graphengine
url = https://gitee.com/mindspore/graphengine.git
[submodule "akg"]
path = akg
url = https://gitee.com/mindspore/akg.git

@ -86,10 +86,14 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
endif()
if (ENABLE_AKG AND ENABLE_D)
add_subdirectory("${CMAKE_SOURCE_DIR}/akg")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES)
add_subdirectory(tests)
endif()
include(cmake/package.cmake)
include(cmake/package.cmake)

1
akg

@ -0,0 +1 @@
Subproject commit c460176523d039c8995f1d71089753725ebc0792

@ -246,6 +246,9 @@ checkopts "$@"
echo "---------------- mindspore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
git submodule update --init --recursive akg
fi
build_exit()
{
@ -308,7 +311,7 @@ build_mindspore()
if [[ "X$USE_GLOG" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
fi
if [[ "X$ENABLE_AKG" = "Xon" ]]; then
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
fi
echo "${CMAKE_ARGS}"

@ -236,6 +236,16 @@ if (ENABLE_GPU)
endif ()
endif ()
if (ENABLE_D AND ENABLE_AKG)
set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg)
install(
DIRECTORY
${AKG_PATH}/akg
DESTINATION ${INSTALL_PY_DIR}/..
COMPONENT mindspore
)
endif ()
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
install(
DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing akg compile with json"""
import sys
def run_compiler(op_json):
"""
Run AKG compiler to compile op with subprocess, if this process of
compilation failed, an exception will be raised
Args:
op_json (str): json string of the op
Returns:
None
"""
p = __import__("akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
res = func(op_json)
if not res:
raise ValueError("Compile error")
if __name__ == "__main__":
run_compiler(sys.argv[1])

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import os
import subprocess
import sys
from multiprocessing import Pool, cpu_count
def _compile_akg_task(*json_strs):
"""
compile func called in single process
Parameters:
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
"""
akg_compiler = os.path.join(os.path.split(
os.path.realpath(__file__))[0], "compiler.py")
for json_str in json_strs:
res = subprocess.run(
[sys.executable, akg_compiler, json_str], text=True)
if res.returncode != 0:
raise ValueError("Failed, args: {}!".format(json_str))
def compile_akg_kernel_parallel(json_infos, process, waitime):
"""
compile kernel use multi processes
Parameters:
json_infos: list. list contain kernel info(task id and json str)
process: int. processes num
waittime: int. max time the function blocked
Returns:
True for all compile success, False for some failed.
"""
if not isinstance(json_infos, list):
raise ValueError("json_infos must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
if process == 0 and json_infos:
process = 1
cpu_proc_num = cpu_count()
max_proc_num = 16
process = min([cpu_proc_num, max_proc_num, process])
args = [[] for _ in range(process)]
for p, info in enumerate(json_infos):
args[p % process].append(info)
with Pool(processes=process) as pool:
res = pool.starmap_async(_compile_akg_task, args)
res.get(timeout=waitime)
return True

@ -1,107 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import json
import math
import os
import subprocess
import sys
from multiprocessing import Pool
def _compiletask(platform, *jsons):
"""
compile func called in single process
Parameters:
platform: str. AKG platform or TBE platform
*jsons: str. json str contain kernel info, suitable for json compile
api
"""
if platform == "AKG":
p = __import__("_akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
for json_item in jsons:
res = func(json_item)
if not res:
raise ValueError("Compile error")
if platform == "TBE":
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
for json_item in jsons:
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
if res.returncode != 0:
raise ValueError("Tbe compile error")
def compilekernelparallel(jsons, process, waitime):
"""
compile kernel use multi processes
Parameters:
jsons: list. json str list contain kernel info
process: int. processes num
waittime: int. max time the function blocked
"""
if not isinstance(jsons, list):
raise ValueError("jsons must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
jsons_akg = []
jsons_tbe = []
for json_ in jsons:
j = json.loads(json_)
if j["platform"] == "TBE":
jsons_tbe.append(json_)
continue
if j["platform"] == "AKG":
jsons_akg.append(json_)
continue
raise RuntimeError(
"not support this platform {0}".format(j["platform"]))
if jsons_akg:
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
else:
process_akg = 0
if process_akg == 0 and jsons_akg:
process_akg = 1
process_tbe = process-process_akg
if process_tbe == 0 and jsons_tbe:
process_tbe = 1
raise RuntimeWarning("we add a process for compile more operator")
args = [[] for _ in range(process_akg+process_tbe)]
args_lens = len(args)
for p in range(args_lens):
if p < process_tbe:
args[p].append("TBE")
else:
args[p].append("AKG")
jsons_tbe_lens = len(jsons_tbe)
for p in range(jsons_tbe_lens):
args[p % process_tbe].append(jsons_tbe[p])
jsons_akg_lens = len(jsons_akg)
for p in range(jsons_akg_lens):
args[process-p % process_akg-1].append(jsons_akg[p])
for p in range(args_lens):
args[p] = tuple(args[p])
with Pool(processes=process) as pool:
res = pool.starmap_async(_compiletask, args)
res.get(timeout=waitime)
return True

@ -39,7 +39,7 @@ if(ENABLE_GPU)
"device/gpu/*.cu"
"kernel/gpu/*.cu"
"kernel/akg/gpu/*.cc"
"kernel/akg/akgkernelbuild.cc"
"kernel/akg/akg_kernel_build.cc"
"kernel/akg/akg_kernel_attrs_process.cc"
)

@ -428,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
auto temp_shape = shape;
std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) {
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For [1] and [1024] shape we can trait it as NZ shape
return shape;
}
if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
} else {

@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
}
buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
buffer << "#flags :" << std::endl;
for (const auto &flag : graph->flags()) {
buffer << flag.first << " : " << flag.second << std::endl;
buffer << "#attrs :" << std::endl;
for (const auto &attr : graph->attrs()) {
buffer << attr.first << " : ";
if (attr.second->isa<BoolImm>()) {
buffer << GetValue<bool>(attr.second);
} else if (attr.second->isa<StringImm>()) {
buffer << GetValue<std::string>(attr.second);
}
buffer << std::endl;
}
}
@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
fout << std::endl;
for (const auto &sg : *sub_graphs) {
fout << "subgraph flag:" << std::endl;
fout << "subgraph attr:" << std::endl;
MS_EXCEPTION_IF_NULL(sg.first);
for (const auto &flag : sg.first->flags()) {
fout << flag.first << " : " << flag.second << std::endl;
for (const auto &attr : sg.first->attrs()) {
fout << attr.first << " : ";
if (attr.second->isa<BoolImm>()) {
fout << GetValue<bool>(attr.second);
} else if (attr.second->isa<StringImm>()) {
fout << GetValue<std::string>(attr.second);
}
fout << std::endl;
}
fout << "subgraph @" << sg.first->ToString() << ".";
fout << sg.first->debug_info()->get_id() << "(";

@ -548,9 +548,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
ValuePtr value_ptr = nullptr;
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(primitive);
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
if (primitive != nullptr) {
value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
} else {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(func_graph);
value_ptr = func_graph->get_attr(kStreamNeedActivedFirst);
}
if (value_ptr == nullptr) {
continue;
}

@ -26,10 +26,12 @@
#include "kernel/kernel.h"
#include "kernel/tbe/tbe_kernel_build.h"
#include "kernel/tbe/tbe_kernel_parallel_build.h"
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
#include "kernel/aicpu/aicpu_kernel_build.h"
#include "kernel/hccl/hccl_kernel_build.h"
#include "kernel/rts/rt_kernel_build.h"
#include "kernel/tbe/tbe_utils.h"
#include "kernel/common_utils.h"
#include "operator/ops.h"
#include "session/anf_runtime_algorithm.h"
#include "./common.h"
@ -91,6 +93,7 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph
static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
std::vector<AnfNodePtr> tbe_nodes;
std::vector<AnfNodePtr> akg_nodes;
std::vector<AnfNodePtr> other_nodes;
for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
MS_EXCEPTION_IF_NULL(anf_node);
@ -105,19 +108,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke
}
break;
}
case KernelType::AKG_KERNEL: {
akg_nodes.push_back(anf_node);
break;
}
default: {
other_nodes.push_back(anf_node);
break;
}
}
}
bool ret = kernel::TbeOpParallelBuild(tbe_nodes);
bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes);
auto bin_map = kernel::tbe::KernelMeta::GetInstance();
(void)bin_map->ReadIndex(kernel::kCceKernelMeta);
for (const auto &anf_node : other_nodes) {
kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
}
return ret;
return tbe_ret && akg_ret;
}
static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
@ -234,7 +244,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) {
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim);

@ -15,16 +15,27 @@
*/
#include "device/ascend/kernel_select_ascend.h"
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <algorithm>
#include <map>
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_query.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"
#include <unordered_map>
#include <unordered_set>
#include "common/utils.h"
#include "debug/anf_ir_dump.h"
#include "operator/ops.h"
#include "ir/func_graph.h"
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/common_utils.h"
#include "kernel/kernel_query.h"
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace device {
@ -121,12 +132,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
}
auto pri_match_format = GetPriorityMatchFormat(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_anf_node = kernel_node->input(input_index + 1);
// we do not take ValueNode into consideration in graph kernel.
if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
continue;
}
}
auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
}
if (kernel_build_info.GetInputDeviceType(input_index) ==
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
// we match output fix precision first.
auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
if (prev_device_type == kTypeUnknown) {
prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
}
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
@ -146,41 +168,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
}
}
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (real_input_node->isa<CNode>()) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
bool is_ref = false;
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
// we set special device info of a input tensor.
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
MS_EXCEPTION_IF_NULL(support_index);
int index = kUnSupportMixedDataTypeIndex;
@ -467,6 +454,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
}
} // namespace
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (real_input_node->isa<CNode>()) {
continue;
}
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
// we set special device info of a input tensor.
bool is_ref = false;
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
@ -498,11 +530,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
return select_status;
}
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
MS_EXCEPTION_IF_NULL(kernel_node);
kernel::KernelQuery(kernel_node, &kernel_info_list);
if (AnfAlgo::IsGraphKernel(kernel_node)) {
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
MS_EXCEPTION_IF_NULL(func_graph);
SelectGraphKernelInfo(kernel_node, func_graph);
return kStatusAllMatched;
}
kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
if (select_status == kNoMatched) {

@ -27,7 +27,10 @@ enum KernelSelectStatus {
kStatusReducePrecision = 1,
kStatusRaisePrecision = 2,
};
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node);
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
} // namespace ascend
} // namespace device
} // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -24,7 +24,7 @@ namespace device {
namespace ascend {
void GraphDescReporter::ReportData() {
for (const auto &node : cnode_list_) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
MS_LOG(WARNING) << "Skip non tbe kernel";
continue;
}

@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() {
size_t task_index = 0;
for (const auto &node : cnode_list_) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
MS_LOG(WARNING) << "Skip non tbe kernel";
++task_index;
continue;

@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
MS_EXCEPTION_IF_NULL(anf_node_ptr);
if (anf_node_ptr->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2.";
// akg process
// set atomic clean addr
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs);
auto graph = anf_node_ptr->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
if (node_users[anf_node_ptr].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
}
auto depend_node = node_users[anf_node_ptr].pop().first;
if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
MS_LOG(EXCEPTION) << "Checking Depend node failed";
}
if (node_users[depend_node].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty.";
}
auto post_node = node_users[depend_node].pop().first;
for (auto index : clean_output_indexs) {
auto device_address = AnfAlgo::GetOutputAddr(post_node, index);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
input->size = device_address->size_;
kernel_inputs->push_back(input);
}
MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size();
}
return;
}
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
input->size = device_address->size_;
kernel_inputs->push_back(input);
}
MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
}
// set clean workspace address
if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {

@ -16,7 +16,7 @@
#include "device/gpu/gpu_kernel_build.h"
#include <string>
#include "kernel/kernel.h"
#include "kernel/akg/akgkernelbuild.h"
#include "kernel/akg/akg_kernel_build.h"
#include "kernel/akg/gpu/akg_gpu_kernel_build.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "operator/ops.h"
@ -37,7 +37,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) {
continue;
}
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AUTO_DIFF_KERNEL) {
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) {
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
if (!gpu_kernel_ptr) {
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed";

@ -184,7 +184,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build());
kernel_type = AUTO_DIFF_KERNEL;
kernel_type = AKG_KERNEL;
}
if (!result) {

@ -26,6 +26,8 @@
#include "ir/func_graph.h"
#include "ir/primitive_base.h"
#include "operator/ops.h"
namespace mindspore {
// namespace to support intermediate representation definition
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() {
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) {
if (cnode == nullptr) {
return false;
}
if (value != nullptr) {
return cnode->IsApply(value);
}
return false;
const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
return prim != nullptr;
}
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {

@ -124,6 +124,7 @@ class AnfNode : public Base {
const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
AbstractBasePtr abstract() const { return abstract_; }
@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) {
std::string GetCNodeFuncName(CNodePtr cnode);
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value);
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
// used to check whether an AnfNode is a cnode with a Primitive as first input
// used to get PrimitivePtr from a cnode first input
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
// used to check whether an AnfNode is a valuenode having some Primitive value

@ -70,7 +70,7 @@ std::string CNode::fullname_with_scope() {
}
fullname_with_scope_ = name;
} else {
// cnode input 0 should be primitive ptr
// cnode input 0 should be primitive ptr or funcgraph ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
@ -84,11 +84,23 @@ std::string CNode::fullname_with_scope() {
return fullname_with_scope_;
}
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
auto prim = input_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim);
fullname_with_scope_ =
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
fullname_with_scope_ = scope()->name() + "/";
if (prim != nullptr) {
fullname_with_scope_ += prim->name();
} else {
auto func_graph = input_value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_flag != nullptr) {
auto fg_name = GetValue<std::string>(fg_flag);
fullname_with_scope_ += "GraphKernel_" + fg_name;
} else {
fullname_with_scope_ += func_graph->ToString();
}
}
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
}
return fullname_with_scope_;

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save