|
|
|
@ -16,12 +16,30 @@
|
|
|
|
|
|
|
|
|
|
#include "kernel/hccl/hccl_kernel_metadata.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "kernel/hccl/hcom_util.h"
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
|
|
|
|
|
const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
|
|
|
|
|
auto op_name = AnfAlgo::GetCNodeName(kernel_node);
|
|
|
|
|
auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
|
|
|
|
|
if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
|
|
|
|
|
return format;
|
|
|
|
|
}
|
|
|
|
|
if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) {
|
|
|
|
|
return kOpFormat_DEFAULT;
|
|
|
|
|
}
|
|
|
|
|
if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) {
|
|
|
|
|
return kOpFormat_DEFAULT;
|
|
|
|
|
}
|
|
|
|
|
return format;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
|
|
|
|
const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
|
|
|
|
|
kNumberTypeFloat32, kNumberTypeInt16};
|
|
|
|
@ -36,13 +54,13 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|
|
|
|
std::vector<std::string> inputs_format{};
|
|
|
|
|
std::vector<TypeId> inputs_type{};
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
|
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
|
|
|
|
|
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
|
|
|
|
|
inputs_type.push_back(type);
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> outputs_format;
|
|
|
|
|
std::vector<TypeId> outputs_type;
|
|
|
|
|
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
|
|
|
|
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
|
|
|
|
|
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
|
|
|
|
|
outputs_type.push_back(type);
|
|
|
|
|
}
|
|
|
|
|
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
|
|
|
|
|