|
|
|
@ -33,7 +33,6 @@ namespace opt {
|
|
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
|
|
|
|
namespace {
|
|
|
|
|
// Formats that describe 3-D (depth-bearing) tensor layouts.
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
|
|
|
|
// Formats considered "common"; outputs already in one of these need no trans op.
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
|
|
|
|
|
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
|
|
|
|
|
std::vector<AnfNodePtr> trans_inputs;
|
|
|
|
@ -82,45 +81,18 @@ std::string InitDefaultFormat(const AnfNodePtr &node) {
|
|
|
|
|
return default_format;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
|
|
|
|
|
AnfNodePtr trans_node = nullptr;
|
|
|
|
|
CNodePtr trans_data = nullptr;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
// Init
|
|
|
|
|
std::string default_format = InitDefaultFormat(node);
|
|
|
|
|
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
|
|
|
|
|
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
|
|
|
|
|
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
|
|
|
|
|
std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
|
|
|
|
|
: AnfAlgo::GetOutputReshapeType(node, insert_index);
|
|
|
|
|
auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
|
|
|
|
|
: AnfAlgo::GetOutputInferShape(input_node, insert_index);
|
|
|
|
|
bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
|
|
|
|
|
: trans::IsNeedPadding(input_format, input_node_out_shape.size());
|
|
|
|
|
if (!need_padding) {
|
|
|
|
|
// don't need padding insert transdata only
|
|
|
|
|
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
|
|
|
|
trans_node = trans_data;
|
|
|
|
|
} else if (is_insert_input) {
|
|
|
|
|
// if need padding & is input need insert a transdata
|
|
|
|
|
// reshape[padding shape] -> transdata[padding shape] -> node
|
|
|
|
|
auto padding_shape =
|
|
|
|
|
trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
|
|
|
|
|
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
|
|
|
|
|
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
|
|
|
|
trans_node = trans_data;
|
|
|
|
|
trans_data->set_abstract(input_node->abstract());
|
|
|
|
|
} else {
|
|
|
|
|
// if need padding & is output need insert a transdata
|
|
|
|
|
// node -> transdata[padding shape] -> reshape[ori_shape]
|
|
|
|
|
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
|
|
|
|
auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
|
|
|
|
|
trans_node = reshape_node;
|
|
|
|
|
void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(trans_node);
|
|
|
|
|
auto real_input_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first;
|
|
|
|
|
if (!real_input_node->isa<CNode>()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto op_name = AnfAlgo::GetCNodeName(real_input_node);
|
|
|
|
|
if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(trans_node) == prim::kPrimReshape->name()) {
|
|
|
|
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(trans_node, 0);
|
|
|
|
|
auto type = AnfAlgo::GetPrevNodeOutputInferDataType(trans_node, 0);
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
|
|
|
|
|
}
|
|
|
|
|
// refresh the transdata's format to ori format & dst format
|
|
|
|
|
RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
|
|
|
|
|
return trans_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
|
|
|
|
@ -161,15 +133,6 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReFreshInferShape(const AnfNodePtr &node, const std::string &op_name) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
if (op_name == kBasicLSTMCellWeightGradOpName && AnfAlgo::GetCNodeName(node) == prim::kPrimReshape->name()) {
|
|
|
|
|
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
|
|
|
|
auto type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({type}, {{shape[0], shape[1]}}, node.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
|
|
|
const KernelSelectPtr &kernel_select) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
@ -177,10 +140,6 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
|
|
|
|
size_t out_num = AnfAlgo::GetOutputTensorNum(node);
|
|
|
|
|
std::string op_name;
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
op_name = AnfAlgo::GetCNodeName(node);
|
|
|
|
|
}
|
|
|
|
|
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
|
|
|
|
|
std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
|
|
|
|
|
if (output_format == kOpFormat_NC1KHKWHWC0) {
|
|
|
|
@ -191,7 +150,6 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|
|
|
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
|
|
|
|
|
if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
|
|
|
|
|
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
|
|
|
|
|
ReFreshInferShape(trans_op, op_name);
|
|
|
|
|
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
|
|
|
|
|
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
|
|
|
|
|
}
|
|
|
|
@ -205,6 +163,50 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|
|
|
|
return make_tuple;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
// Inserts a TransData node (plus a Reshape when padding is required) on one
// edge of `node`: the input edge at `insert_index` when `is_insert_input` is
// true, otherwise the output at `insert_index`. Returns the node that callers
// should splice into the graph in place of the original edge endpoint.
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
  MS_EXCEPTION_IF_NULL(node);
  // Resolve the source node, the edge's formats, the reshape type and the
  // inferred shape; all of these differ between the input- and output-edge cases.
  std::string default_format = InitDefaultFormat(node);
  AnfNodePtr src_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
  std::string src_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
  std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
                                                   : AnfAlgo::GetOutputReshapeType(node, insert_index);
  auto src_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
                                   : AnfAlgo::GetOutputInferShape(src_node, insert_index);
  // Padding is judged against whichever end of the edge is NOT the default format.
  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, src_shape.size())
                                      : trans::IsNeedPadding(src_format, src_shape.size());

  CNodePtr trans_data = nullptr;
  AnfNodePtr result_node = nullptr;
  if (need_padding) {
    if (is_insert_input) {
      // Input edge with padding: reshape[padded shape] -> transdata[padded shape] -> node.
      auto padded_shape =
        trans::PaddingShapeTo4d(src_shape, AnfAlgo::GetInputReshapeType(node, insert_index));
      auto reshape_node = CreateReshapeNode(func_graph, src_node, kernel_select, padded_shape);
      trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
      trans_data->set_abstract(src_node->abstract());
      result_node = trans_data;
    } else {
      // Output edge with padding: node -> transdata[padded shape] -> reshape[original shape].
      trans_data = NewTransOpNode(func_graph, src_node, kernel_select, need_padding, prim::KPrimTransData->name());
      result_node = CreateReshapeNode(func_graph, trans_data, kernel_select, src_shape);
    }
  } else {
    // No padding needed: a single transdata suffices.
    trans_data = NewTransOpNode(func_graph, src_node, kernel_select, need_padding, prim::KPrimTransData->name());
    result_node = trans_data;
  }
  // Refresh the transdata's kernel build info to the (ori format, dst format) pair.
  RefreshKernelBuildInfo(src_format, dst_format, trans_data, padding_axis);
  if (!is_insert_input) {
    ReFreshInferShape(result_node, node);
  }
  return result_node;
}
|
|
|
|
|
|
|
|
|
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
|
|
|
|
const AnfNodePtr &trans_data, const std::vector<Axis> &reshape_type,
|
|
|
|
|
const TypeId &type_id) {
|
|
|
|
|