!2504 modify code review

Merge pull request !2504 from Destiny/master-mod-review
pull/2504/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 02e6dbb695

@ -25,6 +25,18 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
namespace {
// sort format according the number of occurrences.
bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
if (a.second != b.second) {
return a.second > b.second;
} else if (a.first == kOpFormat_DEFAULT) {
return a.second + 1 > b.second;
} else if (b.first == kOpFormat_DEFAULT) {
return a.second > b.second + 1;
}
return a.second > b.second;
}
TypeId GetPrimitivePrecision(const CNodePtr &cnode) { TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
auto primitive = AnfAlgo::GetCNodePrimitive(cnode); auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
@ -44,6 +56,7 @@ TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
return except_type; return except_type;
} }
} // namespace
void ResetKernelBuildInfo(const CNodePtr &kernel_node) { void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
@ -185,15 +198,12 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format,
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_kernel_node); MS_EXCEPTION_IF_NULL(input_kernel_node);
if (!input_kernel_node->isa<Parameter>()) { if (!input_kernel_node->isa<Parameter>()) {
auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i); ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)];
++all_input_formats[pre_format];
continue; continue;
} }
auto para = input_kernel_node->cast<ParameterPtr>(); auto para = input_kernel_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para);
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
auto pre_format = AnfAlgo::GetOutputFormat(para, 0); ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)];
++all_input_formats[pre_format];
continue; continue;
} }
*use_same_format = false; *use_same_format = false;
@ -207,17 +217,8 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format,
for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
pairs.push_back(std::make_pair(iter->first, iter->second)); pairs.push_back(std::make_pair(iter->first, iter->second));
} }
auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
if (a.second != b.second) { std::sort(pairs.begin(), pairs.end(), cmp_format_num);
return a.second > b.second;
} else if (a.first == kOpFormat_DEFAULT) {
return a.second + 1 > b.second;
} else if (b.first == kOpFormat_DEFAULT) {
return a.second > b.second + 1;
}
return a.second > b.second;
};
std::sort(pairs.begin(), pairs.end(), cmp_func);
*default_format = pairs.begin()->first; *default_format = pairs.begin()->first;
} }
@ -237,10 +238,9 @@ void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format,
} }
} }
void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list, void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
const std::string &default_format, bool use_same_format, const std::string &default_format, bool use_same_format,
std::vector<std::string> *graph_input_format, std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) {
std::vector<TypeId> *graph_input_type) {
MS_EXCEPTION_IF_NULL(graph_input_format); MS_EXCEPTION_IF_NULL(graph_input_format);
MS_EXCEPTION_IF_NULL(graph_input_type); MS_EXCEPTION_IF_NULL(graph_input_type);
// We set same format to all inputs of graph kernel subgraph, and process this latter. // We set same format to all inputs of graph kernel subgraph, and process this latter.
@ -338,21 +338,22 @@ void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_
} }
} }
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list, void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng, const FuncGraphManagerPtr &mng, const std::string &default_format,
const std::string &default_format, std::vector<std::string> *graph_input_format, std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type,
std::vector<TypeId> *graph_input_type) { std::vector<bool> *need_update) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(graph_input_format); MS_EXCEPTION_IF_NULL(graph_input_format);
MS_EXCEPTION_IF_NULL(graph_input_type); MS_EXCEPTION_IF_NULL(graph_input_type);
// update graph input format and dtype use inner ops. MS_EXCEPTION_IF_NULL(need_update);
// check graph input format and dtype use inner ops.
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (graph_input_format->size() != input_num) { if (graph_input_format->size() != input_num || graph_input_type->size() != input_num ||
need_update->size() != input_num) {
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
<< "], [%" << graph_input_format->size() << "] != [%" << input_num << "]"; << "], [" << graph_input_format->size() << "] != [" << input_num << "]";
} }
std::vector<bool> need_update(input_num, false);
auto &node_users = mng->node_users(); auto &node_users = mng->node_users();
for (size_t i = 0; i < input_num; ++i) { for (size_t i = 0; i < input_num; ++i) {
auto &input = input_list[i]; auto &input = input_list[i];
@ -372,36 +373,48 @@ void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNo
<< kernel_node->DebugString() << kernel_node->DebugString()
<< "] selected different format. we use defult: " << default_format; << "] selected different format. we use defult: " << default_format;
(*graph_input_format)[i] = default_format; (*graph_input_format)[i] = default_format;
need_update[i] = true; (*need_update)[i] = true;
} }
if (kernel_node->input(i + 1)->isa<Parameter>()) { if (kernel_node->input(i + 1)->isa<Parameter>() ||
auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)); AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) {
if (user_dtype != (*graph_input_type)[i]) { continue;
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
<< kernel_node->DebugString()
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
(*graph_input_type)[i] = default_dtype;
need_update[i] = true;
}
} }
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
<< kernel_node->DebugString()
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
(*graph_input_type)[i] = default_dtype;
(*need_update)[i] = true;
} }
} }
}
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update,
const std::vector<std::string> &graph_input_format,
const std::vector<TypeId> &graph_input_type) {
MS_EXCEPTION_IF_NULL(kernel_node);
// update graph input format and dtype use inner ops.
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (graph_input_format.size() != input_num || graph_input_type.size() != input_num ||
need_update.size() != input_num) {
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
<< "], [" << graph_input_format.size() << "] != [" << input_num << "]";
}
for (size_t i = 0; i < input_num; ++i) { for (size_t i = 0; i < input_num; ++i) {
if (!need_update[i]) { if (!need_update[i]) {
continue; continue;
} }
need_update[i] = false;
MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
<< "] to: " << (*graph_input_format)[i]; << "] to: " << graph_input_format[i];
MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
<< "] to: " << TypeIdLabel((*graph_input_type)[i]); << "] to: " << TypeIdLabel(graph_input_type[i]);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
std::vector<std::string> outputs_format = {(*graph_input_format)[i]}; std::vector<std::string> outputs_format = {graph_input_format[i]};
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]}; std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
builder.SetOutputsFormat(outputs_format); builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type); builder.SetOutputsDeviceType(outputs_device_type);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
@ -487,8 +500,8 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
std::vector<std::string> graph_input_format; std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type; std::vector<TypeId> graph_input_type;
UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
&graph_input_type); &graph_input_type);
auto mng = func_graph->manager(); auto mng = func_graph->manager();
if (mng == nullptr) { if (mng == nullptr) {
@ -502,8 +515,10 @@ void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
// update graph input format and dtype use inner ops. // update graph input format and dtype use inner ops.
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format, std::vector<bool> need_update(AnfAlgo::GetInputTensorNum(kernel_node), false);
&graph_input_type); CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type,
&need_update);
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type);
// set fix_precision for kernel when the me prim has fix_precision attr // set fix_precision for kernel when the me prim has fix_precision attr
UpdateKernelInfo(node_list); UpdateKernelInfo(node_list);

Loading…
Cancel
Save