|
|
|
@ -23,9 +23,7 @@
|
|
|
|
|
#include "kernel/oplib/oplib.h"
|
|
|
|
|
#include "kernel/kernel_query.h"
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "kernel/kernel_build_info.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "debug/anf_ir_dump.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@ -45,7 +43,6 @@ enum MatchCountPriority : int {
|
|
|
|
|
MATCH_COUNT_PRIORITY_END
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
const size_t kMaxCount = 0xffffffff;
|
|
|
|
|
const int kUnSupportMixedDataTypeIndex = -1;
|
|
|
|
|
|
|
|
|
|
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
|
|
|
@ -91,7 +88,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
|
|
|
|
|
return priority_matched_format;
|
|
|
|
|
}
|
|
|
|
|
/**
|
|
|
|
|
* compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
|
|
|
|
|
* Compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
|
|
|
|
|
* if equal then next num location
|
|
|
|
|
* example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
|
|
|
|
|
*/
|
|
|
|
@ -167,8 +164,9 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|
|
|
|
if (op_info != nullptr) {
|
|
|
|
|
is_ref = op_info->is_ref();
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
|
|
|
|
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
if (ms_context->execution_mode() == kPynativeMode &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -221,6 +219,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
|
|
|
|
|
std::vector<TypeId> *node_mix_precision_datatype) {
|
|
|
|
|
AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_input);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
|
|
|
|
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
|
|
|
|
|
AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
|
|
|
|
|
node_mix_precision_datatype->push_back(input_origin_type);
|
|
|
|
@ -229,6 +228,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
|
|
|
|
|
void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
|
|
|
|
|
std::vector<int> *node_mix_precision_datatype_index,
|
|
|
|
|
std::vector<TypeId> *node_mix_precision_datatype) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
|
|
|
|
auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
|
|
|
|
AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index);
|
|
|
|
|
node_mix_precision_datatype->push_back(output_origin_type);
|
|
|
|
@ -239,12 +239,12 @@ void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_ind
|
|
|
|
|
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
|
|
|
|
|
std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
|
|
|
|
|
if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
|
|
|
|
|
<< node_mix_precision_datatype.size();
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
|
|
|
|
|
if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
|
|
|
|
|
MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
|
|
|
|
|
<< kernel_support_datatypes.size();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -265,10 +265,10 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
|
|
|
|
|
if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
|
|
|
|
|
auto find_iter = kernel_support_datatypes.find(iter->first);
|
|
|
|
|
if (find_iter == kernel_support_datatypes.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
|
|
|
|
|
MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
|
|
|
|
|
}
|
|
|
|
|
if (i >= find_iter->second.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node index " << i << "kernel datatype size " << find_iter->second.size();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size();
|
|
|
|
|
}
|
|
|
|
|
if (node_mix_precision_datatype[i] != find_iter->second[i]) {
|
|
|
|
|
iter = kernel_match_datatype_idx->erase(iter);
|
|
|
|
@ -279,7 +279,7 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
|
|
|
|
|
}
|
|
|
|
|
auto datatype_indexes = iter->second;
|
|
|
|
|
if (i >= datatype_indexes.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node datatype index: " << i << " kernel support size " << datatype_indexes.size();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size();
|
|
|
|
|
}
|
|
|
|
|
if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
|
|
|
|
|
iter = kernel_match_datatype_idx->erase(iter);
|
|
|
|
@ -415,8 +415,8 @@ std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
|
|
|
|
|
size_t selected_index = 0;
|
|
|
|
|
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
|
|
|
|
|
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
|
|
|
|
|
auto kernel_build_info = *(kernel_info_list[info_index]);
|
|
|
|
|
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
|
|
|
|
|
auto kernel_info_ptr = kernel_info_list[info_index];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info_ptr);
|
|
|
|
|
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
|
|
|
|
|
// Currently the selection policy is the match format count first, and then is datatype counts.
|
|
|
|
|
if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
|
|
|
|
|