|
|
|
@ -32,13 +32,13 @@ namespace opt {
|
|
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
|
|
|
|
namespace {
|
|
|
|
|
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
|
|
|
|
const AnfNodePtr &node,
|
|
|
|
|
const kernel::KernelBuildInfo ori_build_info) {
|
|
|
|
|
const AnfNodePtr &node, const TypeId device_type,
|
|
|
|
|
const kernel::KernelBuildInfo &ori_build_info) {
|
|
|
|
|
KernelBuildInfoBuilder builder;
|
|
|
|
|
builder.SetInputsFormat({input_format});
|
|
|
|
|
builder.SetOutputsFormat({output_format});
|
|
|
|
|
builder.SetInputsDeviceType({ori_build_info.GetInputDeviceType(0)});
|
|
|
|
|
builder.SetOutputsDeviceType({ori_build_info.GetOutputDeviceType(0)});
|
|
|
|
|
builder.SetInputsDeviceType({device_type});
|
|
|
|
|
builder.SetOutputsDeviceType({device_type});
|
|
|
|
|
builder.SetKernelType(ori_build_info.kernel_type());
|
|
|
|
|
builder.SetFusionType(ori_build_info.fusion_type());
|
|
|
|
|
builder.SetProcessor(ori_build_info.processor());
|
|
|
|
@ -56,11 +56,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
|
|
|
|
|
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(trans_node);
|
|
|
|
|
std::vector<kernel::Axis> padding_axis;
|
|
|
|
|
if (AnfAlgo::IsRealKernel(input)) {
|
|
|
|
|
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
|
|
|
|
} else {
|
|
|
|
|
padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
|
|
|
|
|
}
|
|
|
|
|
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
|
|
|
|
|
if (need_padding) {
|
|
|
|
|
// If padding is needed, set the transdata node's shape to the padded shape.
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
|
|
|
|
@ -129,15 +125,8 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
|
|
|
|
AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const KernelSelectPtr &kernel_select) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
std::string output_format;
|
|
|
|
|
std::vector<size_t> origin_shape;
|
|
|
|
|
if (!AnfAlgo::IsRealKernel(node)) {
|
|
|
|
|
output_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0);
|
|
|
|
|
origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
|
|
|
|
} else {
|
|
|
|
|
output_format = AnfAlgo::GetOutputFormat(node, 0);
|
|
|
|
|
origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
|
|
|
|
}
|
|
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
|
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
|
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
|
|
|
|
|
<< node->DebugString();
|
|
|
|
@ -186,6 +175,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|
|
|
|
AnfNodePtr trans_node = nullptr;
|
|
|
|
|
AnfNodePtr input_node = node;
|
|
|
|
|
AnfNodePtr trans_data = nullptr;
|
|
|
|
|
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (origin_format.empty() || dest_format.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
|
|
|
|
@ -196,6 +186,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|
|
|
|
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
|
|
|
|
|
}
|
|
|
|
@ -231,7 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|
|
|
|
MS_EXCEPTION_IF_NULL(trans_data);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
|
|
|
|
|
auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
|
|
|
|
|
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info);
|
|
|
|
|
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
|
|
|
|
|
return trans_node;
|
|
|
|
|
}
|
|
|
|
|