|
|
|
@ -20,51 +20,33 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace patterns {
|
|
|
|
|
|
|
|
|
|
template <typename IT, typename FindFunc, typename ReplaceFunc>
|
|
|
|
|
static void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
|
|
|
|
|
if (s == e) return;
|
|
|
|
|
|
|
|
|
|
auto it = std::find_if(s, e, f);
|
|
|
|
|
|
|
|
|
|
if (it != e) {
|
|
|
|
|
r(*it);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
it++;
|
|
|
|
|
ReplaceAllOccurances(it, e, f, r);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
namespace {
|
|
|
|
|
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
|
auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs),
|
|
|
|
|
[from](Node* n) { return n == from; });
|
|
|
|
|
auto from_in_inputs =
|
|
|
|
|
std::find(std::begin(node.inputs), std::end(node.inputs), from);
|
|
|
|
|
|
|
|
|
|
if (same != std::end(node.inputs)) {
|
|
|
|
|
if (from_in_inputs != std::end(node.inputs)) {
|
|
|
|
|
IR_NODE_LINK_TO(to, (&node));
|
|
|
|
|
|
|
|
|
|
auto inputs = node.Op()->Inputs();
|
|
|
|
|
|
|
|
|
|
using input_type = VariableNameMap::value_type;
|
|
|
|
|
|
|
|
|
|
ReplaceAllOccurances(
|
|
|
|
|
std::begin(inputs), std::end(inputs),
|
|
|
|
|
[from](const input_type& i) -> bool {
|
|
|
|
|
auto params = i.second;
|
|
|
|
|
auto pi =
|
|
|
|
|
std::find_if(std::begin(params), std::end(params),
|
|
|
|
|
std::bind(std::equal_to<std::string>(),
|
|
|
|
|
from->Name(), std::placeholders::_1));
|
|
|
|
|
return pi != std::end(params);
|
|
|
|
|
},
|
|
|
|
|
[to, &node](const input_type& i) {
|
|
|
|
|
node.Op()->SetInput(i.first, {to->Name()});
|
|
|
|
|
});
|
|
|
|
|
std::for_each(std::begin(inputs), std::end(inputs),
|
|
|
|
|
[from, to, &node](const input_type& i) -> void {
|
|
|
|
|
auto param_names = i.second;
|
|
|
|
|
auto pi = std::find(std::begin(param_names),
|
|
|
|
|
std::end(param_names), from->Name());
|
|
|
|
|
|
|
|
|
|
if (pi != std::end(param_names)) {
|
|
|
|
|
node.Op()->SetInput(i.first, {to->Name()});
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace patterns
|
|
|
|
|
} // namespace
|
|
|
|
|
using graph_ptr = std::unique_ptr<ir::Graph>;
|
|
|
|
|
|
|
|
|
|
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
@ -116,7 +98,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, conv_output);
|
|
|
|
|
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
|
|
|
|
|
CorrectGraphEdges(g, elementwise_add_out, conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|