add op names mapping for one hwopt pass

pull/1987/head
huanghui 5 years ago
parent cc0add562b
commit 1044310783

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
#include <string>
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "common/utils.h" #include "common/utils.h"
@ -50,6 +51,8 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
square_sumv1->set_scope(sum->scope()); square_sumv1->set_scope(sum->scope());
AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1); AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1);
AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1); AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1);
auto names = MakeValue<std::vector<std::string>>({prim::kPrimSquare->name(), prim::kPrimReduceSum->name()});
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv1);
return square_sumv1; return square_sumv1;
} }
@ -71,6 +74,8 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
square_sumv2->set_scope(sum->scope()); square_sumv2->set_scope(sum->scope());
AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2); AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2);
AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2); AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2);
auto names = MakeValue<std::vector<std::string>>({prim::kPrimSquare->name(), prim::kPrimReduceSum->name()});
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv2);
return square_sumv2; return square_sumv2;
} }

@ -200,6 +200,7 @@ constexpr auto kAttrLabelIndex = "label_index";
constexpr auto kAttrLabelSwitchList = "label_switch_list"; constexpr auto kAttrLabelSwitchList = "label_switch_list";
constexpr auto kAttrNewAxisMask = "new_axis_mask"; constexpr auto kAttrNewAxisMask = "new_axis_mask";
constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask";
constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names";
// attr value // attr value
constexpr auto kValueTargetSwitch = "target_switch"; constexpr auto kValueTargetSwitch = "target_switch";

Loading…
Cancel
Save