Fix HCCL kernel info bug and update the cast's kernel info

pull/2393/head
WilliamLian 5 years ago
parent a58b1a1435
commit 5f9d2759ee

@ -23,6 +23,8 @@
namespace mindspore {
namespace kernel {
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
kNumberTypeFloat32, kNumberTypeInt16};
MS_EXCEPTION_IF_NULL(kernel_info_list);
MS_EXCEPTION_IF_NULL(kernel_node);
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
@ -30,27 +32,27 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]";
return;
}
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
for (const auto &type : kHcclSupportTypes) {
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_type.push_back(type);
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_type.push_back(type);
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetKernelType(HCCL_KERNEL);
kernel_info_list->push_back(builder.Build());
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetKernelType(HCCL_KERNEL);
kernel_info_list->push_back(builder.Build());
}
} // namespace kernel
} // namespace mindspore

@ -120,6 +120,24 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel
return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index);
}
void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) {
using Shape = std::vector<size_t>;
auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0);
auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
std::vector<Shape> shapes;
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
if (cast_index == index) {
shapes.emplace_back(cast_shape);
types.emplace_back(cast_dtype);
continue;
}
shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
}
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get());
}
AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(kernel_query);
@ -151,9 +169,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
<< (*alternative_kernel_info)->ToString();
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
ChangeNodeInferInfo(next_cnode, node, cast_index);
if (node->inputs().size() < kCastInputNum) {
auto op_name = AnfAlgo::GetCNodeName(node);
MS_LOG(EXCEPTION) << "op[" << op_name << "] has wrong input num:";
MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:";
}
return node->input(1);
}
@ -223,7 +241,11 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
<< (*kernel_info_it)->ToString();
AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get());
ChangeNodeInferInfo(prior_op, cur_node, output_idx);
if (!single_output) {
MS_EXCEPTION_IF_NULL(x_node);
ChangeNodeInferInfo(x_node->cast<CNodePtr>(), cur_node, 0);
}
auto prior_name = AnfAlgo::GetCNodeName(prior_op);
if (prior_name == kFive2FourOpName) {
AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op);

Loading…
Cancel
Save