|
|
|
@ -15,16 +15,27 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "device/ascend/kernel_select_ascend.h"
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include "kernel/oplib/oplib.h"
|
|
|
|
|
#include "kernel/kernel_query.h"
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
|
|
|
|
|
#include "common/utils.h"
|
|
|
|
|
#include "debug/anf_ir_dump.h"
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "device/kernel_info.h"
|
|
|
|
|
#include "kernel/common_utils.h"
|
|
|
|
|
#include "kernel/kernel_query.h"
|
|
|
|
|
#include "kernel/oplib/oplib.h"
|
|
|
|
|
#include "kernel/kernel_build_info.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace device {
|
|
|
|
@ -121,12 +132,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|
|
|
|
}
|
|
|
|
|
auto pri_match_format = GetPriorityMatchFormat(kernel_node);
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
|
auto input_anf_node = kernel_node->input(input_index + 1);
|
|
|
|
|
// we do not take ValueNode into consideration in graph kernel.
|
|
|
|
|
if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
|
|
|
|
|
if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
|
|
|
|
|
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
|
|
|
|
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
|
|
|
|
|
// we match output fix precision first.
|
|
|
|
|
auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
|
|
|
|
|
if (prev_device_type == kTypeUnknown) {
|
|
|
|
|
prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
|
|
|
|
}
|
|
|
|
|
if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
|
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
|
|
|
|
@ -146,41 +168,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
|
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
|
|
|
|
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
|
|
|
|
auto real_input_node = input_with_index.first;
|
|
|
|
|
if (real_input_node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
|
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
bool is_ref = false;
|
|
|
|
|
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
|
|
|
|
|
if (op_info != nullptr) {
|
|
|
|
|
is_ref = op_info->is_ref();
|
|
|
|
|
}
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
if (ms_context->execution_mode() == kPynativeMode &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// we set special device info of a input tensor.
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(support_index);
|
|
|
|
|
int index = kUnSupportMixedDataTypeIndex;
|
|
|
|
@ -467,6 +454,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
|
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
|
|
|
|
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_with_index.first);
|
|
|
|
|
auto real_input_node = input_with_index.first;
|
|
|
|
|
if (real_input_node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// we set special device info of a input tensor.
|
|
|
|
|
bool is_ref = false;
|
|
|
|
|
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
|
|
|
|
|
if (op_info != nullptr) {
|
|
|
|
|
is_ref = op_info->is_ref();
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
|
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
|
|
|
|
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
@ -498,11 +530,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
|
|
|
|
return select_status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
|
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
|
|
|
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
kernel::KernelQuery(kernel_node, &kernel_info_list);
|
|
|
|
|
if (AnfAlgo::IsGraphKernel(kernel_node)) {
|
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
SelectGraphKernelInfo(kernel_node, func_graph);
|
|
|
|
|
return kStatusAllMatched;
|
|
|
|
|
}
|
|
|
|
|
kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
|
|
|
|
|
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
|
|
|
|
|
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
|
|
|
|
|
if (select_status == kNoMatched) {
|
|
|
|
|