conv2d_unify_ir pass adapt for pynative

pull/10019/head
yuchaojie 4 years ago
parent a492cc470b
commit 0368d686d3

@ -37,6 +37,7 @@ constexpr auto kAttrPadList = "pad_list";
constexpr auto kAttrPads = "pads";
constexpr auto kAttrMode = "mode";
constexpr auto kAttrChannelMultiplier = "channel_multiplier";
constexpr auto kAttrPerm = "perm";
bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vector<size_t> out_shape) {
MS_EXCEPTION_IF_NULL(conv2d);
@ -86,9 +87,16 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
MS_EXCEPTION_IF_NULL(input_node);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto perm = std::vector<int64_t>{1, 0, 2, 3};
std::vector<AnfNodePtr> transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
CreatePermValueNode(graph, perm)};
std::vector<AnfNodePtr> transpose_inputs;
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node};
} else {
transpose_inputs = {NewValueNode(std::make_shared<Primitive>(kTransposeOpName)), input_node,
CreatePermValueNode(graph, perm)};
}
auto transpose = graph->NewCNode(transpose_inputs);
MS_EXCEPTION_IF_NULL(transpose);
transpose->set_scope(conv2d->scope());
@ -111,6 +119,9 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
auto output_names = std::vector<std::string>{"output"};
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), transpose);
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), transpose);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(perm), transpose);
}
return transpose;
}

Loading…
Cancel
Save