|
|
@ -20,22 +20,21 @@
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include <algorithm>
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
|
|
|
|
#include "optimizer/optimizer.h"
|
|
|
|
|
|
|
|
#include "optimizer/irpass.h"
|
|
|
|
|
|
|
|
#include "ir/visitor.h"
|
|
|
|
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
#include "ir/func_graph.h"
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
#include "ir/func_graph_cloner.h"
|
|
|
|
#include "operator/ops.h"
|
|
|
|
#include "ir/optimizer_caller.h"
|
|
|
|
#include "ir/pattern_matcher.h"
|
|
|
|
#include "ir/pattern_matcher.h"
|
|
|
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
|
|
|
#include "optimizer/irpass.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace opt {
|
|
|
|
namespace opt {
|
|
|
|
namespace irpass {
|
|
|
|
namespace irpass {
|
|
|
|
// {prim::kPrimSwitch, true, X, Y}
|
|
|
|
// {prim::kPrimSwitch, true, X, Y}
|
|
|
|
// {prim::kPrimSwitch, false, X, Y}
|
|
|
|
// {prim::kPrimSwitch, false, X, Y}
|
|
|
|
class SwitchSimplify {
|
|
|
|
class SwitchSimplify : public OptimizerCaller {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br;
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br;
|
|
|
|
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
|
|
|
auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
|
|
|
|
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
|
|
|
|
auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
|
|
|
@ -54,9 +53,9 @@ class SwitchSimplify {
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} =>
|
|
|
|
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
|
|
|
|
// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
|
|
|
|
class FloatTupleGetItemSwitch {
|
|
|
|
class FloatTupleGetItemSwitch : public OptimizerCaller {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br, x;
|
|
|
|
MATCH_REPLACE_IF(node,
|
|
|
|
MATCH_REPLACE_IF(node,
|
|
|
|
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
|
|
|
|
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
|
|
|
@ -69,9 +68,9 @@ class FloatTupleGetItemSwitch {
|
|
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
|
|
|
|
// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
|
|
|
|
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
|
|
|
|
// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
|
|
|
|
class FloatEnvGetItemSwitch {
|
|
|
|
class FloatEnvGetItemSwitch : public OptimizerCaller {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
|
|
|
|
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
|
|
|
|
MATCH_REPLACE_IF(node,
|
|
|
|
MATCH_REPLACE_IF(node,
|
|
|
|
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
|
|
|
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
|
|
|
@ -93,9 +92,9 @@ AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfN
|
|
|
|
} // namespace internal
|
|
|
|
} // namespace internal
|
|
|
|
|
|
|
|
|
|
|
|
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
|
|
|
|
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
|
|
|
|
class ConvertSwitchReplacement {
|
|
|
|
class ConvertSwitchReplacement : public OptimizerCaller {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
|
|
|
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|