|
|
|
@ -37,6 +37,7 @@ constexpr auto kAttrPadList = "pad_list";
|
|
|
|
|
constexpr auto kAttrPads = "pads";
|
|
|
|
|
constexpr auto kAttrMode = "mode";
|
|
|
|
|
constexpr auto kAttrChannelMultiplier = "channel_multiplier";
|
|
|
|
|
constexpr auto kAttrPerm = "perm";
|
|
|
|
|
|
|
|
|
|
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conv2d);
|
|
|
|
@ -86,9 +87,16 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(conv2d);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
auto perm = std::vector<int64_t>{1, 0, 2, 3};
|
|
|
|
|
std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
|
|
|
|
|
CreatePermValueNode(graph, perm)};
|
|
|
|
|
std::vector<AnfNodePtr> transpose_inputs;
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
|
|
|
|
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node};
|
|
|
|
|
} else {
|
|
|
|
|
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
|
|
|
|
|
CreatePermValueNode(graph, perm)};
|
|
|
|
|
}
|
|
|
|
|
auto transpose = graph->NewCNode(transpose_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(transpose);
|
|
|
|
|
transpose->set_scope(conv2d->scope());
|
|
|
|
@ -111,6 +119,9 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
|
|
|
|
|
auto output_names = std::vector<std::string>{"output"};
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose);
|
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), transpose);
|
|
|
|
|
}
|
|
|
|
|
return transpose;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|