!1681 modify hccl mutiple input and output op register

Merge pull request !1681 from Maoweiyong/fix_hccl_mutiple_input_output
pull/1681/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit b60a353b94

@ -31,29 +31,26 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
return;
}
std::vector<TypeId> data_type_list{kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32};
std::vector<std::string> input_format, output_format;
std::vector<TypeId> input_type, output_type;
for (const auto &data_type : data_type_list) {
for (const auto &format : kOpFormatList) {
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
input_format.clear();
input_format.push_back(format);
input_type.clear();
input_type.push_back(data_type);
output_format.clear();
output_format.push_back(format);
output_type.clear();
output_type.push_back(data_type);
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
}
builder->SetInputsFormat(input_format);
builder->SetInputsDeviceType(input_type);
builder->SetOutputsFormat(output_format);
builder->SetOutputsDeviceType(output_type);
builder->SetKernelType(HCCL_KERNEL);
kernel_info_list->emplace_back(builder->Build());
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetKernelType(HCCL_KERNEL);
kernel_info_list->push_back(builder.Build());
}
} // namespace kernel
} // namespace mindspore

Loading…
Cancel
Save