|
|
|
@ -64,7 +64,7 @@ class PIsEqual {
|
|
|
|
|
template <typename T = AnfNodePtr>
|
|
|
|
|
class PatternNode : public PBase<PatternNode<T> > {
|
|
|
|
|
public:
|
|
|
|
|
T GetNode(const AnfNodePtr &node) const {
|
|
|
|
|
T GetNode(const AnfNodePtr &) const {
|
|
|
|
|
if (!captured_) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
|
|
|
|
|
}
|
|
|
|
@ -107,11 +107,11 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
|
|
|
|
|
auto inputs = cnode->inputs();
|
|
|
|
|
if (inputs.size() == 3) {
|
|
|
|
|
// Binary Prim assumes only two inputs
|
|
|
|
|
if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
|
|
|
|
|
if (!x_.TryCapture(inputs[1]) || !y_.TryCapture(inputs[2])) {
|
|
|
|
|
// If the operation is commutative, then check with inversed operands
|
|
|
|
|
if (is_commutative_) {
|
|
|
|
|
Reset();
|
|
|
|
|
if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) {
|
|
|
|
|
if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
@ -207,30 +207,77 @@ class PCNode : public PBase<PCNode<TArgs...> > {
|
|
|
|
|
AnfNodePtr GetNode(const AnfNodePtr &node) const {
|
|
|
|
|
tuple_utils::PTupleGetNode get_node(node);
|
|
|
|
|
tuple_utils::apply_func_tuple(&get_node, args_);
|
|
|
|
|
return NewCNode(get_node.args_, node->func_graph());
|
|
|
|
|
auto prim_cnode = get_node.args_;
|
|
|
|
|
// In case this PCNode has captured extra nodes
|
|
|
|
|
if (extra_nodes_.size() > 0) {
|
|
|
|
|
prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end());
|
|
|
|
|
}
|
|
|
|
|
return NewCNode(prim_cnode, node->func_graph());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool TryCapture_(const AnfNodePtr &node) const {
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto inputs = cnode->inputs();
|
|
|
|
|
if (inputs.size() != sizeof...(TArgs)) {
|
|
|
|
|
|
|
|
|
|
auto pattern_arg_len = sizeof...(TArgs);
|
|
|
|
|
// There aren't enough inputs in Node to fill up the Pattern
|
|
|
|
|
if (inputs.size() < pattern_arg_len) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
tuple_utils::PTupleCapture capture_func(inputs);
|
|
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_);
|
|
|
|
|
return capture_func.captured_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Pattern must exactly match the number of Node inputs.
|
|
|
|
|
if (!has_min_extra_nodes_) {
|
|
|
|
|
// Inputs in Node perfectly match number of tokens in Pattern.
|
|
|
|
|
if ((inputs.size() - 1) == pattern_arg_len) {
|
|
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end());
|
|
|
|
|
tuple_utils::PTupleCapture capture_func(tokens);
|
|
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_);
|
|
|
|
|
return capture_func.captured_;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Pattern may accept extra (non specified) nodes at the end of the CNode
|
|
|
|
|
// There must be at least `min_extra_nodes` additional nodes in the inputs.
|
|
|
|
|
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) {
|
|
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len);
|
|
|
|
|
tuple_utils::PTupleCapture capture_func(tokens);
|
|
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_);
|
|
|
|
|
// If it could capture the initial set of nodes specified in the Pattern
|
|
|
|
|
// and there are enough extra inputs to add
|
|
|
|
|
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) {
|
|
|
|
|
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return capture_func.captured_;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// This function sets the PCNode object to capture at least `min_extra_nodes_` nodes after the last one
|
|
|
|
|
/// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
|
|
|
|
|
/// more nodes after the last one specified when building the PCNode.
|
|
|
|
|
const PCNode<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
|
|
|
|
|
has_min_extra_nodes_ = true;
|
|
|
|
|
min_extra_nodes_ = min_extra_nodes;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Reset() const {
|
|
|
|
|
tuple_utils::PTupleResetCapture reset;
|
|
|
|
|
tuple_utils::apply_func_tuple(&reset, args_);
|
|
|
|
|
has_min_extra_nodes_ = false;
|
|
|
|
|
extra_nodes_.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::tuple<typename TArgs::Internal...> args_;
|
|
|
|
|
mutable AnfNodePtrList extra_nodes_;
|
|
|
|
|
mutable bool has_min_extra_nodes_{false};
|
|
|
|
|
mutable size_t min_extra_nodes_{0};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename... TArgs>
|
|
|
|
@ -243,6 +290,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
|
|
|
|
|
tuple_utils::apply_func_tuple(&get_node, args_);
|
|
|
|
|
auto prim_cnode = get_node.args_;
|
|
|
|
|
prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
|
|
|
|
|
|
|
|
|
|
// In case this PPrimitive has captured extra nodes
|
|
|
|
|
if (extra_nodes_.size() > 0) {
|
|
|
|
|
prim_cnode.insert(prim_cnode.begin(), extra_nodes_.begin(), extra_nodes_.end());
|
|
|
|
|
}
|
|
|
|
|
return NewCNode(prim_cnode, node->func_graph());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -250,35 +302,66 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
|
|
|
|
|
if (IsPrimitiveCNode(node, prim_)) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto inputs = cnode->inputs();
|
|
|
|
|
if ((inputs.size() - 1) != sizeof...(TArgs)) {
|
|
|
|
|
// Number of arguments in Primitive Pattern (not including the Primitive node)
|
|
|
|
|
auto pattern_arg_len = sizeof...(TArgs);
|
|
|
|
|
// There aren't enough inputs in Node to fill up the Pattern
|
|
|
|
|
if ((inputs.size() - 1) < pattern_arg_len) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
|
|
|
|
|
tuple_utils::PTupleCapture capture_func(rest);
|
|
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_);
|
|
|
|
|
// Pattern must exactly match the number of Node inputs.
|
|
|
|
|
if (!has_min_extra_nodes_) {
|
|
|
|
|
// Inputs in Node perfectly match number of tokens in Pattern.
|
|
|
|
|
if ((inputs.size() - 1) == pattern_arg_len) {
|
|
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end());
|
|
|
|
|
tuple_utils::PTupleCapture capture_func(tokens);
|
|
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_);
|
|
|
|
|
return capture_func.captured_;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return capture_func.captured_;
|
|
|
|
|
// Pattern may accept extra (non specified) nodes at the end of the Primitive
|
|
|
|
|
// There must be at least `min_extra_nodes` additional nodes in the inputs.
|
|
|
|
|
if ((inputs.size() - 1) >= pattern_arg_len + min_extra_nodes_) {
|
|
|
|
|
AnfNodePtrList tokens(inputs.begin() + 1, inputs.begin() + 1 + pattern_arg_len);
|
|
|
|
|
tuple_utils::PTupleCapture capture_func(tokens);
|
|
|
|
|
tuple_utils::apply_func_tuple(&capture_func, args_);
|
|
|
|
|
// If it could capture the initial set of nodes specified in the Pattern
|
|
|
|
|
// and there are enough extra inputs to add
|
|
|
|
|
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) {
|
|
|
|
|
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return capture_func.captured_;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If set to true, TryCapture will try to capture the nodes in iversed nodes as well (only for two input case)
|
|
|
|
|
const PPrimitive<TArgs...> &Commutative(const bool &is_commutative = true) const {
|
|
|
|
|
is_commutative_ = is_commutative;
|
|
|
|
|
/// This function sets the PPrimitive object to capture at least `min_extra_nodes_` nodes after the last one
|
|
|
|
|
/// defined in the Pattern. e.g. `min_extra_nodes_ = 1` means the Pattern will be valid if there is one or
|
|
|
|
|
/// more nodes after the last one specified when building the PPrimitive.
|
|
|
|
|
const PPrimitive<TArgs...> &MinExtraNodes(const size_t &min_extra_nodes = 0) const {
|
|
|
|
|
has_min_extra_nodes_ = true;
|
|
|
|
|
min_extra_nodes_ = min_extra_nodes;
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Reset() const {
|
|
|
|
|
tuple_utils::PTupleResetCapture reset;
|
|
|
|
|
tuple_utils::apply_func_tuple(&reset, args_);
|
|
|
|
|
has_min_extra_nodes_ = false;
|
|
|
|
|
extra_nodes_.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const PrimitivePtr prim_;
|
|
|
|
|
std::tuple<typename TArgs::Internal...> args_;
|
|
|
|
|
mutable bool is_commutative_{false};
|
|
|
|
|
mutable AnfNodePtrList extra_nodes_;
|
|
|
|
|
mutable bool has_min_extra_nodes_{false};
|
|
|
|
|
mutable size_t min_extra_nodes_{0};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|