|
|
|
@ -31,29 +31,26 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<TypeId> data_type_list{kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32};
|
|
|
|
|
std::vector<std::string> input_format, output_format;
|
|
|
|
|
std::vector<TypeId> input_type, output_type;
|
|
|
|
|
for (const auto &data_type : data_type_list) {
|
|
|
|
|
for (const auto &format : kOpFormatList) {
|
|
|
|
|
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
|
|
|
|
input_format.clear();
|
|
|
|
|
input_format.push_back(format);
|
|
|
|
|
input_type.clear();
|
|
|
|
|
input_type.push_back(data_type);
|
|
|
|
|
output_format.clear();
|
|
|
|
|
output_format.push_back(format);
|
|
|
|
|
output_type.clear();
|
|
|
|
|
output_type.push_back(data_type);
|
|
|
|
|
std::vector<std::string> inputs_format{};
|
|
|
|
|
std::vector<TypeId> inputs_type{};
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
|
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
|
|
|
|
|
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
builder->SetInputsFormat(input_format);
|
|
|
|
|
builder->SetInputsDeviceType(input_type);
|
|
|
|
|
builder->SetOutputsFormat(output_format);
|
|
|
|
|
builder->SetOutputsDeviceType(output_type);
|
|
|
|
|
builder->SetKernelType(HCCL_KERNEL);
|
|
|
|
|
kernel_info_list->emplace_back(builder->Build());
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> outputs_format;
|
|
|
|
|
std::vector<TypeId> outputs_type;
|
|
|
|
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
|
|
|
|
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
|
|
|
|
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
|
|
|
|
}
|
|
|
|
|
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
|
|
|
|
builder.SetInputsFormat(inputs_format);
|
|
|
|
|
builder.SetInputsDeviceType(inputs_type);
|
|
|
|
|
builder.SetOutputsFormat(outputs_format);
|
|
|
|
|
builder.SetOutputsDeviceType(outputs_type);
|
|
|
|
|
builder.SetKernelType(HCCL_KERNEL);
|
|
|
|
|
kernel_info_list->push_back(builder.Build());
|
|
|
|
|
}
|
|
|
|
|
} // namespace kernel
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|