|
|
|
@ -21,109 +21,70 @@
|
|
|
|
|
|
|
|
|
|
#include "optimizer/optimizer.h"
|
|
|
|
|
#include "optimizer/irpass.h"
|
|
|
|
|
#include "ir/visitor.h"
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "utils/graph_utils.h"
|
|
|
|
|
#include "operator/composite/composite.h"
|
|
|
|
|
#include "ir/pattern_matcher.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace irpass {
|
|
|
|
|
// {prim::kPrimMakeRef, X, Y, Z} -> Y
|
|
|
|
|
class MakeRefEliminater : public AnfVisitor {
|
|
|
|
|
class MakeRefEliminater : public OptimizerCaller {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
y_ = nullptr;
|
|
|
|
|
auto gety = [this](const AnfNodePtr &node) -> bool {
|
|
|
|
|
this->y_ = node;
|
|
|
|
|
return true;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
|
|
|
|
|
return y_;
|
|
|
|
|
PatternNode<AnfNodePtr> x, y, z;
|
|
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &) override {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
AnfNodePtr y_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
|
|
|
|
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
|
|
|
|
class GetRefParamEliminater : public AnfVisitor {
|
|
|
|
|
class GetRefParamEliminater : public OptimizerCaller {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
x_ = nullptr;
|
|
|
|
|
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
|
|
|
|
|
if (x_ != nullptr) {
|
|
|
|
|
return x_;
|
|
|
|
|
}
|
|
|
|
|
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
|
|
|
|
|
return x_;
|
|
|
|
|
PatternNode<AnfNodePtr> x;
|
|
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
|
|
|
|
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { x_ = node; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
AnfNodePtr x_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
|
|
|
|
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
|
|
|
|
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
|
|
|
|
class GetMakeRefEliminater : public AnfVisitor {
|
|
|
|
|
class GetMakeRefEliminater : public OptimizerCaller {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr || cnode->size() != 2) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimGetRefKey/Value, {...}}
|
|
|
|
|
auto ref = cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// {prim::kPrimMakeRef, X, Y, Z}
|
|
|
|
|
if (cnode->IsApply(prim::kPrimGetRefKey)) {
|
|
|
|
|
return ref->input(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cnode->IsApply(prim::kPrimGetRefValue)) {
|
|
|
|
|
return ref->input(2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
|
|
|
|
|
return ref->input(3);
|
|
|
|
|
}
|
|
|
|
|
PatternNode<AnfNodePtr> x, y, z;
|
|
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
|
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
|
|
|
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
|
|
|
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// IsValueNode<RefKey>
|
|
|
|
|
class ReplaceRefkeyByParam : public AnfVisitor {
|
|
|
|
|
class ReplaceRefkeyByParam : public OptimizerCaller {
|
|
|
|
|
public:
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
|
|
|
|
if (!IsValueNode<RefKey>(node)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto refkey = GetValueNode<RefKeyPtr>(node);
|
|
|
|
|
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(resource);
|
|
|
|
|
|
|
|
|
|
auto top_graph = resource->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(top_graph);
|
|
|
|
|
|
|
|
|
|
for (const auto &tnode : top_graph->parameters()) {
|
|
|
|
|
auto para = tnode->cast<ParameterPtr>();
|
|
|
|
|
if (para != nullptr && para->name() == refkey->tag()) {
|
|
|
|
|
return para;
|
|
|
|
|
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
|
|
|
|
|
auto refkey = GetValueNode<RefKeyPtr>(node);
|
|
|
|
|
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(resource);
|
|
|
|
|
|
|
|
|
|
auto top_graph = resource->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(top_graph);
|
|
|
|
|
|
|
|
|
|
for (const auto &tnode : top_graph->parameters()) {
|
|
|
|
|
auto para = tnode->cast<ParameterPtr>();
|
|
|
|
|
if (para != nullptr && para->name() == refkey->tag()) {
|
|
|
|
|
return para;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
};
|
|
|
|
|
PatternNode<AnfNodePtr> x;
|
|
|
|
|
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|