|
|
|
@ -48,6 +48,22 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
|
|
|
|
|
return reshape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetTransNodeAttr(const CNodePtr &trans_node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(trans_node);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) {
|
|
|
|
|
std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
|
|
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
|
|
|
|
|
if (input_format == kOpFormat_DEFAULT) {
|
|
|
|
|
input_format = kOpFormat_NCHW;
|
|
|
|
|
}
|
|
|
|
|
if (output_format == kOpFormat_DEFAULT) {
|
|
|
|
|
output_format = kOpFormat_NCHW;
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
|
|
|
|
|
AnfNodePtr trans_node = nullptr;
|
|
|
|
@ -173,6 +189,7 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
|
|
|
|
|
builder->SetInputsDeviceType({type_id});
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
|
|
|
|
SetTransNodeAttr(trans_data->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
|
|
|
|