|
|
|
@ -120,6 +120,24 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel
|
|
|
|
|
return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) {
|
|
|
|
|
using Shape = std::vector<size_t>;
|
|
|
|
|
auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0);
|
|
|
|
|
auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
|
|
|
|
|
std::vector<Shape> shapes;
|
|
|
|
|
std::vector<TypeId> types;
|
|
|
|
|
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
|
|
|
|
|
if (cast_index == index) {
|
|
|
|
|
shapes.emplace_back(cast_shape);
|
|
|
|
|
types.emplace_back(cast_dtype);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
|
|
|
|
|
types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_query);
|
|
|
|
@ -151,9 +169,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
|
|
|
|
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
|
|
|
|
|
<< (*alternative_kernel_info)->ToString();
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
|
|
|
|
|
ChangeNodeInferInfo(next_cnode, node, cast_index);
|
|
|
|
|
if (node->inputs().size() < kCastInputNum) {
|
|
|
|
|
auto op_name = AnfAlgo::GetCNodeName(node);
|
|
|
|
|
MS_LOG(EXCEPTION) << "op[" << op_name << "] has wrong input num:";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:";
|
|
|
|
|
}
|
|
|
|
|
return node->input(1);
|
|
|
|
|
}
|
|
|
|
@ -223,7 +241,11 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
|
|
|
|
|
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
|
|
|
|
|
<< (*kernel_info_it)->ToString();
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get());
|
|
|
|
|
|
|
|
|
|
ChangeNodeInferInfo(prior_op, cur_node, output_idx);
|
|
|
|
|
if (!single_output) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(x_node);
|
|
|
|
|
ChangeNodeInferInfo(x_node->cast<CNodePtr>(), cur_node, 0);
|
|
|
|
|
}
|
|
|
|
|
auto prior_name = AnfAlgo::GetCNodeName(prior_op);
|
|
|
|
|
if (prior_name == kFive2FourOpName) {
|
|
|
|
|
AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op);
|
|
|
|
|