|
|
|
@ -101,9 +101,9 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|
|
|
|
auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second);
|
|
|
|
|
auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
|
|
|
|
|
auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
|
|
|
|
|
auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
|
|
|
|
auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
|
|
|
|
|
// insert trans
|
|
|
|
|
if (origin_format != cur_format) {
|
|
|
|
|
if (origin_format != cur_format && cur_shape.size() > 1) {
|
|
|
|
|
auto kernel_select = std::make_shared<KernelSelect>();
|
|
|
|
|
final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
|
|
|
|
|
kTransDataOpName, false);
|
|
|
|
|