|
|
|
@ -22,112 +22,13 @@ namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace patterns {
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
struct Pattern : public PatternBase {
|
|
|
|
|
Pattern(PDPattern* pattern, const std::string& name_scope)
|
|
|
|
|
: PatternBase{pattern, name_scope, ""} {}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::string name_scope() { return name_scope_; }
|
|
|
|
|
std::string repr() { return repr_; }
|
|
|
|
|
size_t id() { return id_; }
|
|
|
|
|
PDPattern* node_pattern() { return pattern; }
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
std::string node_name(std::string op_name) {
|
|
|
|
|
return PDNodeName(name_scope(), repr(), id(), op_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PDNode* retrieve_node(std::string op_name) {
|
|
|
|
|
return node_pattern()->RetrieveNode(node_name(op_name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PDNode* new_node(std::string op_name) {
|
|
|
|
|
return node_pattern()->NewNode(node_name(op_name));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
*/
|
|
|
|
|
/*
|
|
|
|
|
struct Conv {
|
|
|
|
|
std::string op_name() const { return "conv2d"; }
|
|
|
|
|
std::string input_name() const { return "Input"; }
|
|
|
|
|
std::string bias_name() const { return "Bias"; }
|
|
|
|
|
std::string filter_name() const { return "Filter"; }
|
|
|
|
|
std::string residual_data_name() const { return "ResidualData"; }
|
|
|
|
|
std::string output_name() const { return "Output"; }
|
|
|
|
|
|
|
|
|
|
std::function<PDNode*()> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&]() -> PDNode* {
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())->assert_is_op(op_name());
|
|
|
|
|
|
|
|
|
|
auto input_var = pattern->new_node(input_name())
|
|
|
|
|
->assert_is_op_input(op_name(), input_name());
|
|
|
|
|
|
|
|
|
|
auto bias_var = pattern->new_node(bias_name())
|
|
|
|
|
->assert_is_op_input(op_name(), bias_name());
|
|
|
|
|
|
|
|
|
|
auto filter_var = pattern->new_node(filter_name())
|
|
|
|
|
->assert_is_op_input(op_name(), filter_name());
|
|
|
|
|
|
|
|
|
|
auto output_var = pattern->new_node(output_name())
|
|
|
|
|
->assert_is_op_output(op_name(), output_name());
|
|
|
|
|
|
|
|
|
|
conv_op->LinksFrom({input_var, bias_var, filter_var});
|
|
|
|
|
conv_op->LinksTo({output_var});
|
|
|
|
|
|
|
|
|
|
return output_var;
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ElementwiseAdd {
|
|
|
|
|
std::string op_name() const { return "elementwise_add"; }
|
|
|
|
|
std::string x_name() const { return "X"; }
|
|
|
|
|
std::string y_name() const { return "Y"; }
|
|
|
|
|
std::string out_name() const { return "Out"; }
|
|
|
|
|
|
|
|
|
|
std::function<PDNode*(PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&](PDNode* conv_output) -> PDNode* {
|
|
|
|
|
auto elementwise_add_op =
|
|
|
|
|
pattern->new_node(op_name())->assert_is_op(op_name());
|
|
|
|
|
|
|
|
|
|
auto x_var =
|
|
|
|
|
pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name());
|
|
|
|
|
|
|
|
|
|
conv_output->assert_is_op_input(op_name(), y_name());
|
|
|
|
|
|
|
|
|
|
auto out_var = pattern->new_node(out_name())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output(op_name(), out_name());
|
|
|
|
|
|
|
|
|
|
elementwise_add_op->LinksFrom({x_var, conv_output});
|
|
|
|
|
elementwise_add_op->LinksTo({out_var});
|
|
|
|
|
|
|
|
|
|
return out_var;
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
*/
|
|
|
|
|
/*
|
|
|
|
|
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
std::shared_ptr<patterns::Pattern> pattern,
|
|
|
|
|
const std::string& op_name) {
|
|
|
|
|
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
|
|
|
|
|
"Node not found for PDNode %s", pattern->node_name(op_name));
|
|
|
|
|
Node* var = subgraph.at(pattern->retrieve_node(op_name));
|
|
|
|
|
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph");
|
|
|
|
|
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
void LinkNodes(Node* from, Node* to) {
|
|
|
|
|
static void LinkNodes(Node* from, Node* to) {
|
|
|
|
|
from->outputs.push_back(to);
|
|
|
|
|
to->inputs.push_back(from);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename IT, typename FindFunc, typename ReplaceFunc>
|
|
|
|
|
void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
|
|
|
|
|
static void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
|
|
|
|
|
if (s == e) return;
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(s, e, f);
|
|
|
|
@ -140,7 +41,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
|
|
|
|
|
ReplaceAllOccurances(it, e, f, r);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
static void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
|
auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs),
|
|
|
|
|
[from](Node* n) { return n == from; });
|
|
|
|
|