@@ -51,11 +51,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = node;
+  AnfNodePtr input_node = nullptr;
   CNodePtr trans_data = nullptr;
   std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
   std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
-  std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
+  std::vector<kernel::Axis> padding_axis;
   MS_EXCEPTION_IF_NULL(node);
   // if insert transdata for input we need to change the input
   if (is_insert_input) {
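This hunk pairs with the next one: input_node and padding_axis lose their eager initializers because the following hunk assigns both inside an explicit if/else, so each branch states exactly the value it uses and the output-reshape query no longer runs on the input path that used to overwrite it immediately. A minimal compilable sketch of the declare-empty-then-assign-per-branch pattern; the axis values below are stand-ins, not the real AnfAlgo calls:

    #include <vector>

    // Stand-in values replace the AnfAlgo reshape-type queries; only the
    // control-flow shape matters here, not the real API.
    std::vector<int> ReshapeTypeFor(bool is_insert_input) {
      std::vector<int> padding_axis;  // declared empty, as in the new code
      if (is_insert_input) {
        padding_axis = {0, 2, 3, 1};  // input path, e.g. GetInputReshapeType
      } else {
        padding_axis = {0, 1, 2, 3};  // output path, e.g. GetOutputReshapeType
      }
      return padding_axis;
    }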
@@ -66,12 +66,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
     padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
+  } else {
+    input_node = node;
+    padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
   }
+
+  auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
   bool need_padding = false;
   if (is_insert_input) {
-    need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
+    need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size()));
   } else {
-    need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
+    need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size()));
   }
   if (!need_padding) {
     // don't need padding insert transdata only
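Besides completing that if/else, this hunk hoists the repeated AnfAlgo::GetOutputInferShape(input_node, 0) query into the local input_node_out_shape, which the two padding checks here and the reshape hunks below all reuse. A compilable sketch of the same hoist, with a toy QueryShape() standing in for the real metadata walk:

    #include <cstddef>
    #include <vector>

    // Toy stand-in for AnfAlgo::GetOutputInferShape(input_node, 0); assume
    // the real call re-walks node metadata on every invocation.
    std::vector<size_t> QueryShape() { return {16, 3, 224, 224}; }

    bool NeedPaddingBefore(bool is_insert_input) {
      // old shape of the code: the query runs in whichever branch executes
      return is_insert_input ? QueryShape().size() > 4 : QueryShape().size() > 2;
    }

    bool NeedPaddingAfter(bool is_insert_input) {
      auto shape = QueryShape();  // new shape: query once, reuse the local
      return is_insert_input ? shape.size() > 4 : shape.size() > 2;
    }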
@@ -80,8 +85,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
   } else if (is_insert_input) {
     // if need padding & is input need insert a transdata
     // reshape[padding shape] -> transdata[padding shape] -> node
-    auto padding_shape =
-      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
+    auto padding_shape = trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, 0));
     auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
     trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
     trans_node = trans_data;
@@ -89,8 +93,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     // if need padding & is output need insert a transdata
     // node -> transdata[padding shape] -> reshape[ori_shape]
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
-    auto reshape_node =
-      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
+    auto reshape_node = CreateReshapeNode(func_graph, trans_data, kernel_select, input_node_out_shape);
     trans_node = reshape_node;
   }
   // refresh the transdata's format to ori format & dst format
@@ -140,10 +143,10 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
                                           const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
-  std::vector<AnfNodePtr> make_tuple_inputs;
-  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
-  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
+  size_t out_num = AnfAlgo::GetOutputTensorNum(node);
+  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
     if (output_format == kOpFormat_NC1KHKWHWC0) {
       MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
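Two independent cleanups here: make_tuple_inputs is built with a braced initializer instead of a declare-then-push_back pair, and the loop bound AnfAlgo::GetOutputTensorNum(node) is hoisted into out_num so it is evaluated once rather than on every iteration; the input loops in the later hunks get the same treatment. A self-contained sketch of the bound hoist, with a toy counter in place of the real query:

    #include <cstddef>
    #include <vector>

    // Toy stand-in for AnfAlgo::GetOutputTensorNum(node), which is not a
    // trivially inlinable accessor in the real code.
    size_t GetOutputTensorNum() { return 4; }

    std::vector<size_t> VisitOutputs() {
      std::vector<size_t> visited;
      size_t out_num = GetOutputTensorNum();  // evaluated once, not per pass
      for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
        visited.push_back(output_idx);
      }
      return visited;
    }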
@@ -151,12 +154,12 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
+    if (origin_shape.size() > 1 && kCommonFormatSet.find(output_format) == kCommonFormatSet.end()) {
       auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
       if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node, output_idx)) {
         kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
       }
-      make_tuple_inputs.emplace_back(trans_op);
+      make_tuple_inputs.push_back(trans_op);
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
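The condition swap puts the cheap origin_shape.size() > 1 test before the kCommonFormatSet lookup, so the && short-circuits and the set search is skipped entirely for scalar and 1-D outputs; emplace_back also becomes push_back to match the surrounding calls. A sketch of the reordering, with a stand-in format set rather than the real kCommonFormatSet:

    #include <cstddef>
    #include <set>
    #include <string>

    // Stand-in for kCommonFormatSet: the set lookup is the expensive
    // operand, the rank check the cheap one.
    const std::set<std::string> kCommonFormats = {"DefaultFormat", "NCHW", "NHWC"};

    bool NeedsTransOp(const std::string &output_format, size_t rank) {
      // cheap test first: for rank <= 1 the set search never runs
      return rank > 1 && kCommonFormats.find(output_format) == kCommonFormats.end();
    }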
@@ -188,15 +191,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                         const bool need_padding, const std::string &op_name) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(input);
-  std::vector<AnfNodePtr> trans_inputs;
-  auto prim = std::make_shared<Primitive>(op_name);
-  trans_inputs.push_back(NewValueNode(prim));
-  trans_inputs.push_back(input);
-  CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
+  CNodePtr trans_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(op_name)), input});
   MS_EXCEPTION_IF_NULL(trans_node);
+  auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
   if (need_padding) {
     // if need padding we should set the transdata node's shape to the padding shape
-    auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
                                         trans_node.get());
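The five-statement construction of trans_node collapses into a single NewCNode call taking a braced initializer list, and padding_axis is now computed unconditionally before the if (need_padding) branch instead of inside it. The next hunk applies the same collapse to the Cast constructor. A toy before/after of the construction; the AnfNodePtr and NewCNode stand-ins below are simplified placeholders, not the MindSpore API:

    #include <memory>
    #include <string>
    #include <vector>

    // Simplified placeholders: a "node" is just a shared string here.
    using AnfNodePtr = std::shared_ptr<std::string>;
    AnfNodePtr NewValueNode(const std::string &op) { return std::make_shared<std::string>(op); }
    AnfNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) { return inputs.front(); }

    AnfNodePtr BuildBefore(const AnfNodePtr &input, const std::string &op_name) {
      std::vector<AnfNodePtr> trans_inputs;           // several statements...
      trans_inputs.push_back(NewValueNode(op_name));
      trans_inputs.push_back(input);
      return NewCNode(trans_inputs);
    }

    AnfNodePtr BuildAfter(const AnfNodePtr &input, const std::string &op_name) {
      return NewCNode({NewValueNode(op_name), input});  // ...become one expression
    }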
@@ -224,11 +223,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
   MS_EXCEPTION_IF_NULL(func_graph);
   std::string input_format = format;
   std::string output_format = format;
-  std::vector<AnfNodePtr> new_cast_inputs;
-  auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
-  new_cast_inputs.push_back(NewValueNode(prim));
-  new_cast_inputs.push_back(input);
-  CNodePtr cast = func_graph->NewCNode(new_cast_inputs);
+  CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
   MS_EXCEPTION_IF_NULL(cast);
   // set kernel build info
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
@@ -280,7 +275,8 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
+  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  for (size_t input_index = 0; input_index < in_num; ++input_index) {
     AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
     MS_EXCEPTION_IF_NULL(input_node);
     new_inputs.push_back(input_node);
@@ -301,8 +297,10 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
 CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
-    const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
+  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  for (size_t input_index = 0; input_index < in_num; ++input_index) {
+    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
+    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     TypeId origin_type(kTypeUnknown);
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
     auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
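From here on, the producing output is resolved once via AnfAlgo::GetPrevNodeOutput(cnode, input_index), and the later type and shape queries read the returned (node, index) pair directly instead of going through GetPrevNodeOutput* wrappers that re-resolve it on every call; the final hunk converts the remaining call sites. A compilable sketch of the idea, with dummy node types in place of the real ones:

    #include <cstddef>
    #include <utility>

    // Dummy node type; the real code deals in AnfNodePtr and TypeId.
    struct Node { int type = 0; };
    using NodeWithIndex = std::pair<const Node *, size_t>;

    // Stand-in for AnfAlgo::GetPrevNodeOutput: resolving the producing
    // (node, output index) pair is the work worth doing only once.
    NodeWithIndex GetPrevNodeOutput(const Node *, size_t) {
      static const Node producer{42};
      return {&producer, 0};
    }

    int GetOutputInferDataType(const Node *n, size_t) { return n->type; }
    int GetOutputDeviceDataType(const Node *n, size_t) { return n->type; }

    bool TypesMatch(const Node *cnode, size_t input_index) {
      auto prev_node = GetPrevNodeOutput(cnode, input_index);  // resolved once
      // every later attribute query reads the cached pair directly
      int infer = GetOutputInferDataType(prev_node.first, prev_node.second);
      int device = GetOutputDeviceDataType(prev_node.first, prev_node.second);
      return infer == device;
    }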
@@ -311,20 +309,19 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
       // weight
       origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
       if (origin_type == kTypeUnknown) {
-        origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
+        origin_type = AnfAlgo::GetOutputDeviceDataType(prev_node.first, prev_node.second);
       }
     } else {
       // feature map
-      origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
+      origin_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     }
     const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
-    const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
-    const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
+    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
     // In graph kernel, we check parameter,
     // the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
     if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
       new_inputs.push_back(cur_input);
-    } else if (origin_type != device_type) {
+    } else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
       auto cast =
         AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
       MS_EXCEPTION_IF_NULL(cast);
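The last hunk also folds the device_type declaration into a C++17 if-init-statement, so the device-type query runs only when the graph-kernel bypass above was not taken, and the variable's scope shrinks to the condition and its body. A minimal standalone example of the construct:

    #include <iostream>

    enum class TypeId { kFloat16, kFloat32 };

    // Stand-in for AnfAlgo::GetInputDeviceDataType(cnode, input_index).
    TypeId QueryDeviceType() { return TypeId::kFloat16; }

    int main() {
      TypeId origin_type = TypeId::kFloat32;
      // init-statement: device_type exists only within this if; it is
      // computed only when control actually reaches the condition.
      if (TypeId device_type = QueryDeviceType(); origin_type != device_type) {
        std::cout << "insert cast\n";
      }
      return 0;
    }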