|
|
@ -25,6 +25,18 @@
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace device {
|
|
|
|
namespace device {
|
|
|
|
namespace ascend {
|
|
|
|
namespace ascend {
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
// sort format according the number of occurrences.
|
|
|
|
|
|
|
|
bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
|
|
|
|
|
|
|
|
if (a.second != b.second) {
|
|
|
|
|
|
|
|
return a.second > b.second;
|
|
|
|
|
|
|
|
} else if (a.first == kOpFormat_DEFAULT) {
|
|
|
|
|
|
|
|
return a.second + 1 > b.second;
|
|
|
|
|
|
|
|
} else if (b.first == kOpFormat_DEFAULT) {
|
|
|
|
|
|
|
|
return a.second > b.second + 1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return a.second > b.second;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
|
|
|
|
TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
|
|
@ -44,6 +56,7 @@ TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
|
|
|
|
|
|
|
|
|
|
|
|
return except_type;
|
|
|
|
return except_type;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
|
|
|
|
void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
@ -185,15 +198,12 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format,
|
|
|
|
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
|
|
|
|
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
|
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
|
|
|
if (!input_kernel_node->isa<Parameter>()) {
|
|
|
|
if (!input_kernel_node->isa<Parameter>()) {
|
|
|
|
auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
|
|
|
|
++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)];
|
|
|
|
++all_input_formats[pre_format];
|
|
|
|
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto para = input_kernel_node->cast<ParameterPtr>();
|
|
|
|
auto para = input_kernel_node->cast<ParameterPtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(para);
|
|
|
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
|
|
|
|
auto pre_format = AnfAlgo::GetOutputFormat(para, 0);
|
|
|
|
++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)];
|
|
|
|
++all_input_formats[pre_format];
|
|
|
|
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
*use_same_format = false;
|
|
|
|
*use_same_format = false;
|
|
|
@ -207,17 +217,8 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format,
|
|
|
|
for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
|
|
|
|
for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
|
|
|
|
pairs.push_back(std::make_pair(iter->first, iter->second));
|
|
|
|
pairs.push_back(std::make_pair(iter->first, iter->second));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
|
|
|
|
|
|
|
|
if (a.second != b.second) {
|
|
|
|
std::sort(pairs.begin(), pairs.end(), cmp_format_num);
|
|
|
|
return a.second > b.second;
|
|
|
|
|
|
|
|
} else if (a.first == kOpFormat_DEFAULT) {
|
|
|
|
|
|
|
|
return a.second + 1 > b.second;
|
|
|
|
|
|
|
|
} else if (b.first == kOpFormat_DEFAULT) {
|
|
|
|
|
|
|
|
return a.second > b.second + 1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return a.second > b.second;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
std::sort(pairs.begin(), pairs.end(), cmp_func);
|
|
|
|
|
|
|
|
*default_format = pairs.begin()->first;
|
|
|
|
*default_format = pairs.begin()->first;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -237,10 +238,9 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
|
|
|
|
void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
|
|
|
|
const std::string &default_format, bool use_same_format,
|
|
|
|
const std::string &default_format, bool use_same_format,
|
|
|
|
std::vector<std::string> *graph_input_format,
|
|
|
|
std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) {
|
|
|
|
std::vector<TypeId> *graph_input_type) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_format);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_format);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_type);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_type);
|
|
|
|
// We set same format to all inputs of graph kernel subgraph, and process this latter.
|
|
|
|
// We set same format to all inputs of graph kernel subgraph, and process this latter.
|
|
|
@ -338,21 +338,22 @@ void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
|
|
|
|
void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
|
|
|
|
const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng,
|
|
|
|
const FuncGraphManagerPtr &mng, const std::string &default_format,
|
|
|
|
const std::string &default_format, std::vector<std::string> *graph_input_format,
|
|
|
|
std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type,
|
|
|
|
std::vector<TypeId> *graph_input_type) {
|
|
|
|
std::vector<bool> *need_update) {
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_format);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_format);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_type);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_input_type);
|
|
|
|
// update graph input format and dtype use inner ops.
|
|
|
|
MS_EXCEPTION_IF_NULL(need_update);
|
|
|
|
|
|
|
|
// check graph input format and dtype use inner ops.
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
if (graph_input_format->size() != input_num) {
|
|
|
|
if (graph_input_format->size() != input_num || graph_input_type->size() != input_num ||
|
|
|
|
|
|
|
|
need_update->size() != input_num) {
|
|
|
|
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
|
|
|
|
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
|
|
|
|
<< "], [%" << graph_input_format->size() << "] != [%" << input_num << "]";
|
|
|
|
<< "], [" << graph_input_format->size() << "] != [" << input_num << "]";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::vector<bool> need_update(input_num, false);
|
|
|
|
|
|
|
|
auto &node_users = mng->node_users();
|
|
|
|
auto &node_users = mng->node_users();
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
auto &input = input_list[i];
|
|
|
|
auto &input = input_list[i];
|
|
|
@ -372,36 +373,48 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNo
|
|
|
|
<< kernel_node->DebugString()
|
|
|
|
<< kernel_node->DebugString()
|
|
|
|
<< "] selected different format. we use defult: " << default_format;
|
|
|
|
<< "] selected different format. we use defult: " << default_format;
|
|
|
|
(*graph_input_format)[i] = default_format;
|
|
|
|
(*graph_input_format)[i] = default_format;
|
|
|
|
need_update[i] = true;
|
|
|
|
(*need_update)[i] = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (kernel_node->input(i + 1)->isa<Parameter>() ||
|
|
|
|
|
|
|
|
AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (kernel_node->input(i + 1)->isa<Parameter>()) {
|
|
|
|
|
|
|
|
auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1));
|
|
|
|
|
|
|
|
if (user_dtype != (*graph_input_type)[i]) {
|
|
|
|
|
|
|
|
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
|
|
|
|
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
|
|
|
|
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
|
|
|
|
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
|
|
|
|
<< kernel_node->DebugString()
|
|
|
|
<< kernel_node->DebugString()
|
|
|
|
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
|
|
|
|
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
|
|
|
|
(*graph_input_type)[i] = default_dtype;
|
|
|
|
(*graph_input_type)[i] = default_dtype;
|
|
|
|
need_update[i] = true;
|
|
|
|
(*need_update)[i] = true;
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
|
|
|
|
|
|
|
|
const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update,
|
|
|
|
|
|
|
|
const std::vector<std::string> &graph_input_format,
|
|
|
|
|
|
|
|
const std::vector<TypeId> &graph_input_type) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
|
|
|
// update graph input format and dtype use inner ops.
|
|
|
|
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
|
|
|
|
if (graph_input_format.size() != input_num || graph_input_type.size() != input_num ||
|
|
|
|
|
|
|
|
need_update.size() != input_num) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
|
|
|
|
|
|
|
|
<< "], [" << graph_input_format.size() << "] != [" << input_num << "]";
|
|
|
|
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
if (!need_update[i]) {
|
|
|
|
if (!need_update[i]) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
need_update[i] = false;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
|
|
|
|
MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
|
|
|
|
<< "] to: " << (*graph_input_format)[i];
|
|
|
|
<< "] to: " << graph_input_format[i];
|
|
|
|
MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
|
|
|
|
MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
|
|
|
|
<< "] to: " << TypeIdLabel((*graph_input_type)[i]);
|
|
|
|
<< "] to: " << TypeIdLabel(graph_input_type[i]);
|
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
|
|
|
std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
|
|
|
|
std::vector<std::string> outputs_format = {graph_input_format[i]};
|
|
|
|
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
|
|
|
|
std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
|
|
|
|
builder.SetOutputsFormat(outputs_format);
|
|
|
|
builder.SetOutputsFormat(outputs_format);
|
|
|
|
builder.SetOutputsDeviceType(outputs_device_type);
|
|
|
|
builder.SetOutputsDeviceType(outputs_device_type);
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
|
|
@ -487,7 +500,7 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> graph_input_format;
|
|
|
|
std::vector<std::string> graph_input_format;
|
|
|
|
std::vector<TypeId> graph_input_type;
|
|
|
|
std::vector<TypeId> graph_input_type;
|
|
|
|
UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
|
|
|
|
UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
|
|
|
|
&graph_input_type);
|
|
|
|
&graph_input_type);
|
|
|
|
|
|
|
|
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
auto mng = func_graph->manager();
|
|
|
@ -502,8 +515,10 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
|
|
|
|
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
|
|
|
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
|
|
|
|
|
|
|
|
|
|
|
// update graph input format and dtype use inner ops.
|
|
|
|
// update graph input format and dtype use inner ops.
|
|
|
|
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format,
|
|
|
|
std::vector<bool> need_update(AnfAlgo::GetInputTensorNum(kernel_node), false);
|
|
|
|
&graph_input_type);
|
|
|
|
CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type,
|
|
|
|
|
|
|
|
&need_update);
|
|
|
|
|
|
|
|
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type);
|
|
|
|
|
|
|
|
|
|
|
|
// set fix_precision for kernel when the me prim has fix_precision attr
|
|
|
|
// set fix_precision for kernel when the me prim has fix_precision attr
|
|
|
|
UpdateKernelInfo(node_list);
|
|
|
|
UpdateKernelInfo(node_list);
|
|
|
|