|
|
|
@ -50,8 +50,10 @@ enum MatchCountPriority : int {
|
|
|
|
|
MATCH_OUTPUT_DTYPE_COUNT,
|
|
|
|
|
MATCH_COUNT_PRIORITY_END
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Sentinel index value (-1).
// NOTE(review): presumably returned/compared when mixed data types are not supported; the users of this
// constant are outside this view - confirm.
const int kUnSupportMixedDataTypeIndex = -1;
|
|
|
|
|
// Lookup table keyed by primitive (op) name. Each value is a list of format strings that
// SetCastAndWeightFormat indexes with the node's kAttrPynativeNextIndex attribute after finding the
// entry by the node's kAttrPynativeNextOpName attribute.
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
|
|
|
|
|
// Conv2D: two entries.
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}},
|
|
|
|
|
// FusedBatchNorm: five entries, all kOpFormat_NC1HWC0.
{prim::kPrimFusedBatchNorm->name(),
|
|
|
|
|
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}}};
|
|
|
|
|
|
|
|
|
|
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
@ -313,10 +315,41 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
|
|
|
|
|
}
|
|
|
|
|
return filtered_kernel_info_list;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
|
|
|
|
void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) ||
|
|
|
|
|
!AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() << "] attr of " << kAttrPynativeNextIndex << " or "
|
|
|
|
|
<< kAttrPynativeNextOpName << " has been not setted yet!";
|
|
|
|
|
}
|
|
|
|
|
auto next_index = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPynativeNextIndex);
|
|
|
|
|
auto next_op_name = AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrPynativeNextOpName);
|
|
|
|
|
auto iter = kNextOpFormatList.find(next_op_name);
|
|
|
|
|
if (iter == kNextOpFormatList.end()) {
|
|
|
|
|
MS_LOG(WARNING) << "The op name " << next_op_name << "has been not setted in the next op map ";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (iter->second.size() < next_index) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Next input index " << next_index << "is out of range in the next op map max size is "
|
|
|
|
|
<< iter->second.size();
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) != prim::kPrimCast->name()) {
|
|
|
|
|
MS_LOG(INFO) << "Only supported to change the node Cast's build info!!!";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto format = iter->second[next_index];
|
|
|
|
|
auto info_builder =
|
|
|
|
|
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(kernel_node));
|
|
|
|
|
info_builder->SetInputsFormat({format});
|
|
|
|
|
info_builder->SetOutputsFormat({format});
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(selected_kernel_info);
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
|
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
|
|
|
@ -329,7 +362,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|
|
|
|
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (selected_kernel_info.GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
|
|
|
|
|
if (selected_kernel_info->GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// we set special device info of a input tensor.
|
|
|
|
@ -344,17 +377,17 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|
|
|
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
|
|
|
|
|
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
|
|
|
|
|
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)};
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
|
|
|
|
|
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
|
|
|
|
|
}
|
|
|
|
@ -388,7 +421,10 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
|
|
|
|
|
// Set kernel info to the anfnode
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
|
|
|
|
// Set format and data type for input tensor.
|
|
|
|
|
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) {
|
|
|
|
|
SetCastAndWeightFormat(kernel_node);
|
|
|
|
|
}
|
|
|
|
|
SetTensorDeviceInfo(kernel_node);
|
|
|
|
|
return select_status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -428,7 +464,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern
|
|
|
|
|
auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
|
|
|
|
|
// Set format and data type for input tensor.
|
|
|
|
|
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
|
|
|
|
|
SetTensorDeviceInfo(kernel_node);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << " <<<";
|
|
|
|
|
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
|
|
|
|
|