|
|
|
@ -21,11 +21,45 @@
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
|
|
|
|
|
|
|
|
|
#include <boost/optional.hpp>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
// Minimal stand-in for C++17 std::optional / Boost.Optional: a value of type
// T that may or may not be present, stored inline (no heap allocation).

// Tag type that selects Maybe's in-place (forwarding) constructor.
struct InPlace {};

// Tag value to pass to Maybe's in-place constructor.
// `static constexpr` gives each translation unit its own constant, so
// including this header from several .cc files does not define the same
// external symbol more than once.
static constexpr InPlace in_place{};

// Optional-like holder for a single T.
//
// An empty Maybe holds no T at all: T's constructor has not run, and its
// destructor will not run either. Copying or moving a Maybe copies or moves
// the contained T through T's own constructors (never bytewise), so types
// that own resources are handled correctly.
template <typename T>
class Maybe {
 private:
  // Raw, suitably aligned storage; a T lives here only while
  // is_initialized is true.
  typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
  bool is_initialized{false};

  T* ptr() { return reinterpret_cast<T*>(&data); }
  const T* ptr() const { return reinterpret_cast<const T*>(&data); }

  // Builds a T in the storage. Precondition: the storage is currently empty.
  template <typename... Args>
  void construct(Args&&... args) {
    ::new (static_cast<void*>(&data)) T(std::forward<Args>(args)...);
    is_initialized = true;
  }

  // Destroys the contained T, if any, leaving the Maybe empty.
  void reset() {
    if (is_initialized) {
      // Clear the flag first so the object reads as empty even if ~T() throws.
      is_initialized = false;
      ptr()->~T();
    }
  }

 public:
  // Creates an empty Maybe.
  Maybe() {}

  // Creates a Maybe holding T(args...), constructed directly in place.
  template <typename... Args>
  explicit Maybe(InPlace, Args&&... args) {
    construct(std::forward<Args>(args)...);
  }

  // Copies the contained value (if any) with T's copy constructor.
  Maybe(const Maybe& other) {
    if (other.is_initialized) construct(*other.ptr());
  }

  // Moves the contained value (if any) with T's move constructor. The source
  // stays engaged and holds a moved-from T, mirroring std::optional.
  Maybe(Maybe&& other) noexcept(std::is_nothrow_move_constructible<T>::value) {
    if (other.is_initialized) construct(std::move(*other.ptr()));
  }

  // Replaces the current content with a copy of other's content.
  Maybe& operator=(const Maybe& other) {
    if (this != &other) {
      reset();
      if (other.is_initialized) construct(*other.ptr());
    }
    return *this;
  }

  // Replaces the current content with other's content, moved.
  Maybe& operator=(Maybe&& other) noexcept(
      std::is_nothrow_move_constructible<T>::value) {
    if (this != &other) {
      reset();
      if (other.is_initialized) construct(std::move(*other.ptr()));
    }
    return *this;
  }

  // Destroys the contained value only when one exists.
  ~Maybe() { reset(); }

  // True when a value is present.
  // NOTE(review): left implicit so existing `bool b = maybe;` style uses keep
  // compiling; consider making it `explicit` once callers are audited.
  operator bool() const { return is_initialized; }

  // Access to the contained value.
  // Precondition: a value is present; calling this on an empty Maybe is
  // undefined behaviour (no check is performed).
  T& value() { return *ptr(); }
  const T& value() const { return *ptr(); }
};

// Convenience factory: returns a Maybe<T> holding T(args...).
template <typename T, typename... Args>
Maybe<T> MakeMaybe(Args&&... args) {
  // A fresh tag is used instead of naming `in_place`, so this header-defined
  // template does not odr-use a per-translation-unit object.
  return Maybe<T>(InPlace{}, std::forward<Args>(args)...);
}
|
|
|
|
|
|
|
|
|
|
// Owning handle to an ir::Graph (std::unique_ptr, so ownership moves with it).
using graph_ptr = std::unique_ptr<ir::Graph>;
|
|
|
|
|
// A raw ir::Graph pointer paired with an optional int (Maybe<int>).
// NOTE(review): the int presumably carries a running count of fused
// subgraphs, and the pointer looks non-owning -- confirm against the .cc.
using GraphWithStats = std::pair<ir::Graph*, Maybe<int>>;
|
|
|
|
|
|
|
|
|
|
// Declaration only; the definition is not in this header.
// NOTE(review): presumably re-links edges in `graph` after `from` has been
// replaced by `to` -- confirm against the definition.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
|
|
|
|
|
// Declaration only; the definition is not in this header.
// NOTE(review): presumably reports whether `to` can be reached from `from`
// by following edges of `graph` -- confirm against the definition.
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
|
|
|
|
@ -33,8 +67,10 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
|
|
|
|
|
|
|
|
|
|
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
|
|
|
|
|
private:
|
|
|
|
|
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
|
|
|
|
|
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
|
|
|
|
|
GraphWithStats FuseConvAsX(const std::string& name_scope,
|
|
|
|
|
const GraphWithStats& graph_with_stats) const;
|
|
|
|
|
GraphWithStats FuseConvAsY(const std::string& name_scope,
|
|
|
|
|
const GraphWithStats& graph_with_stats) const;
|
|
|
|
|
|
|
|
|
|
template <typename RetType>
|
|
|
|
|
using GetNodeFunc =
|
|
|
|
@ -48,12 +84,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
|
|
|
|
|
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
|
|
|
|
|
const CanFuseFunc& can_fuse_func);
|
|
|
|
|
|
|
|
|
|
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* graph);
|
|
|
|
|
int get_stats() const { return *fusion_stats; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::shared_ptr<int> fusion_stats;
|
|
|
|
|
ConvFunc get_node_from_conv_op;
|
|
|
|
|
ElementwiseAddFunc get_node_from_elementwise_add_op;
|
|
|
|
|
CanFuseFunc can_fuse_func;
|
|
|
|
|
|
|
|
|
|
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* graph);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|