|
|
|
@ -33,7 +33,25 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string
|
|
|
|
|
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) {
|
|
|
|
|
return {0, 3, 1, 2};
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format;
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Transpose can be replaceed by nop reshape in some situations.
|
|
|
|
|
// 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1}
|
|
|
|
|
// 2. out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2}
|
|
|
|
|
bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) {
|
|
|
|
|
if (out_shape.size() != 4) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D.";
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> perm1 = {0, 2, 3, 1};
|
|
|
|
|
std::vector<int> perm2 = {0, 3, 1, 2};
|
|
|
|
|
if (transpose_perm == perm1) {
|
|
|
|
|
return (out_shape[1] == 1 && out_shape[2] == 1);
|
|
|
|
|
} else if (transpose_perm == perm2) {
|
|
|
|
|
return (out_shape[2] == 1 && out_shape[3] == 1);
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -56,8 +74,16 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
|
|
|
|
|
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
|
|
|
|
int used_node_index, const std::vector<int> &transpose_perm) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
// 1.Create a transpose node.
|
|
|
|
|
auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
|
|
|
|
|
// 0.Judge whether it is a fake transpose
|
|
|
|
|
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
|
|
|
|
|
bool is_fake = IsFakeTranspose(transed_shape, transpose_perm);
|
|
|
|
|
// 1.Create a transpose node or a fake transpose node:reshape.
|
|
|
|
|
mindspore::PrimitivePtr transpose_prim;
|
|
|
|
|
if (is_fake) {
|
|
|
|
|
transpose_prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
|
|
|
|
|
} else {
|
|
|
|
|
transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(transpose_prim);
|
|
|
|
|
// 2.Set the input of transpose.
|
|
|
|
|
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
|
|
|
|
@ -66,7 +92,9 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
|
|
|
|
|
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
|
|
|
|
|
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
|
|
|
|
if (!is_fake) {
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
|
|
|
|
}
|
|
|
|
|
// 4.Set the input of used_node.
|
|
|
|
|
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
|
|
|
|
<< ", index: " << used_node_index;
|
|
|
|
|