|
|
|
@ -44,10 +44,14 @@ struct TestIsReachable {
|
|
|
|
|
using func = std::function<bool(const std::string&, const std::string&)>;
|
|
|
|
|
|
|
|
|
|
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
|
|
|
|
|
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
|
|
|
|
|
const std::string& name) -> Node* {
|
|
|
|
|
auto hash = [](const Node* node) -> std::string {
|
|
|
|
|
return node->Name() + std::to_string(node->id());
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto find_node = [&](const std::unique_ptr<ir::Graph>& graph,
|
|
|
|
|
const std::string& name) -> Node* {
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
|
if (name == node.Name()) {
|
|
|
|
|
if (name == hash(&node)) {
|
|
|
|
|
return &node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -55,13 +59,17 @@ struct TestIsReachable {
|
|
|
|
|
return nullptr;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return [&](std::string from, const std::string to) -> bool {
|
|
|
|
|
// update the from and to strings to hashed equivs in loop from graph traits
|
|
|
|
|
return [&](std::string from, std::string to) -> bool {
|
|
|
|
|
if (from == to) return true;
|
|
|
|
|
|
|
|
|
|
std::map<std::string, bool> visited;
|
|
|
|
|
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
|
visited[node.Name()] = false;
|
|
|
|
|
auto hashed = hash(&node);
|
|
|
|
|
if (node.Name() == from) from = hashed;
|
|
|
|
|
if (node.Name() == to) to = hashed;
|
|
|
|
|
visited[hashed] = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
visited[from] = true;
|
|
|
|
@ -72,15 +80,15 @@ struct TestIsReachable {
|
|
|
|
|
while (!queue.empty()) {
|
|
|
|
|
auto cur = find_node(graph, queue.front());
|
|
|
|
|
queue.pop_front();
|
|
|
|
|
|
|
|
|
|
if (cur == nullptr) return false;
|
|
|
|
|
|
|
|
|
|
for (auto n : cur->outputs) {
|
|
|
|
|
if (n->Name() == to) return true;
|
|
|
|
|
auto hashed_name = hash(n);
|
|
|
|
|
if (hashed_name == to) return true;
|
|
|
|
|
|
|
|
|
|
if (!visited[n->Name()]) {
|
|
|
|
|
visited[n->Name()] = true;
|
|
|
|
|
queue.push_back(n->Name());
|
|
|
|
|
if (!visited[hashed_name]) {
|
|
|
|
|
visited[hashed_name] = true;
|
|
|
|
|
queue.push_back(hashed_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -166,6 +174,28 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
|
|
|
|
|
RunPassAndAssert(&prog, "a", "relu", 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass,
|
|
|
|
|
ConvolutionProjectionAsYWithElementwiseAddRelu) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e", "f"},
|
|
|
|
|
{"bias", "weights", "bias2", "weights2"});
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
|
|
|
|
|
// right branch
|
|
|
|
|
SetOp(&prog, "conv2d",
|
|
|
|
|
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "c"});
|
|
|
|
|
|
|
|
|
|
// left branch
|
|
|
|
|
SetOp(&prog, "conv2d",
|
|
|
|
|
{{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
|
|
|
|
|
{"Output", "f"});
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
RunPassAndAssert(&prog, "a", "relu", 2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass,
|
|
|
|
|
ConvolutionAsYWithElementwiseAddReluNoBias) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
|
|
|
|
|