[oneDNN] Refactor fuse pass helper functions to one place. (#30460)
* Move pass tester helper functions to single common place. * Use helper functions in two more fuse pass tests.revert-31068-fix_conv3d_windows
parent
1d7bf1de2b
commit
c5ffad126c
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,174 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph_traits.h"
|
||||
#include "paddle/fluid/framework/ir/mkldnn/pass_test_util.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
namespace test {
|
||||
|
||||
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
|
||||
const std::vector<InOutVarNamePair>& inputs,
|
||||
const std::vector<InOutVarNamePair>& outputs,
|
||||
bool use_mkldnn) {
|
||||
auto op = prog->MutableBlock(0)->AppendOp();
|
||||
op->SetType(op_type_name);
|
||||
op->SetAttr("use_mkldnn", use_mkldnn);
|
||||
|
||||
for (const auto& input : inputs) {
|
||||
op->SetInput(input.first, {input.second});
|
||||
}
|
||||
for (const auto& output : outputs) {
|
||||
op->SetOutput(output.first, {output.second});
|
||||
}
|
||||
|
||||
return op;
|
||||
}
|
||||
|
||||
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
|
||||
auto hash = [](const Node* node) -> std::string {
|
||||
return node->Name() + std::to_string(node->id());
|
||||
};
|
||||
|
||||
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
|
||||
for (auto& node : GraphTraits::DFS(graph)) {
|
||||
if (name == hash(&node)) {
|
||||
return &node;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
if (from == to) return true;
|
||||
|
||||
std::map<std::string, bool> visited;
|
||||
// update the from and to strings to hashed equivs in loop from graph traits
|
||||
for (auto& node : GraphTraits::DFS(graph)) {
|
||||
auto hashed = hash(&node);
|
||||
if (node.Name() == from) {
|
||||
from = hashed;
|
||||
}
|
||||
if (node.Name() == to) {
|
||||
to = hashed;
|
||||
}
|
||||
visited[hashed] = false;
|
||||
}
|
||||
|
||||
visited[from] = true;
|
||||
|
||||
std::list<std::string> queue;
|
||||
queue.push_back(from);
|
||||
|
||||
while (!queue.empty()) {
|
||||
auto cur = find_node(graph, queue.front());
|
||||
queue.pop_front();
|
||||
if (cur == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto n : cur->outputs) {
|
||||
auto hashed_name = hash(n);
|
||||
if (hashed_name == to) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!visited[hashed_name]) {
|
||||
visited[hashed_name] = true;
|
||||
queue.push_back(hashed_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AssertOpsCount(const Graph& graph,
|
||||
std::vector<OpTypeCountPair> op_type_count) {
|
||||
for (auto* node : graph.Nodes()) {
|
||||
if (!node->IsOp()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string op_type_name = node->Op()->Type();
|
||||
auto op_it =
|
||||
std::find_if(std::begin(op_type_count), std::end(op_type_count),
|
||||
[op_type_name](const OpTypeCountPair& p) {
|
||||
return op_type_name == p.first;
|
||||
});
|
||||
if (op_it != std::end(op_type_count)) {
|
||||
op_it->second--;
|
||||
}
|
||||
}
|
||||
|
||||
bool result{true};
|
||||
|
||||
for (const OpTypeCountPair& p : op_type_count) {
|
||||
result = result && (p.second == 0);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
|
||||
const std::vector<std::string>& persistent_vars) {
|
||||
ProgramDesc prog;
|
||||
|
||||
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
|
||||
auto var = prog.MutableBlock(0)->Var(var_name);
|
||||
var->SetType(proto::VarType::LOD_TENSOR);
|
||||
return var;
|
||||
};
|
||||
|
||||
for (const auto& v : transient_vars) {
|
||||
add_var_to_prog(v);
|
||||
}
|
||||
|
||||
for (const auto& v : persistent_vars) {
|
||||
auto* var = add_var_to_prog(v);
|
||||
var->SetPersistable(true);
|
||||
}
|
||||
|
||||
return prog;
|
||||
}
|
||||
|
||||
bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
|
||||
const std::string& from, const std::string& to,
|
||||
int removed_nodes_count, int added_nodes_count) {
|
||||
if (!TestIsReachable(*graph, from, to)) return false;
|
||||
|
||||
int original_nodes_num = graph->Nodes().size();
|
||||
auto pass = PassRegistry::Instance().Get(pass_name);
|
||||
pass->Apply(graph);
|
||||
int current_nodes_num = graph->Nodes().size();
|
||||
|
||||
if (!TestIsReachable(*graph, from, to)) return false;
|
||||
|
||||
int expected_nodes_num =
|
||||
original_nodes_num - removed_nodes_count + added_nodes_count;
|
||||
return expected_nodes_num == current_nodes_num;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,119 @@
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/op_desc.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// -------------------------- helper functions --------------------------------
|
||||
namespace test {
|
||||
|
||||
/// The pair describing correlation between {input/output name, variable name}.
|
||||
using InOutVarNamePair = std::pair<std::string, std::string>;
|
||||
/// The pair describing number of occurrences of given op type.
|
||||
using OpTypeCountPair = std::pair<std::string, int>;
|
||||
|
||||
///
|
||||
/// @brief Creates the specified operator and sets up its inputs/outputs.
|
||||
///
|
||||
/// @param prog The program descriptor to which we add new op.
|
||||
/// @param[in] op_type_name The operator type name.
|
||||
/// @param[in] inputs The vector of input pairs: {input_name, variable
|
||||
/// name}
|
||||
/// @param[in] outputs The vector of output pairs {output_name, variable}
|
||||
/// @param[in] use_mkldnn The flag deciding whether or not to set
|
||||
/// 'use_mkldnn' attribute.
|
||||
///
|
||||
/// @return Returns pointer to the created operator descriptor.
|
||||
///
|
||||
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
|
||||
const std::vector<InOutVarNamePair>& inputs,
|
||||
const std::vector<InOutVarNamePair>& outputs,
|
||||
bool use_mkldnn = true);
|
||||
|
||||
///
|
||||
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
|
||||
///
|
||||
/// @param[in] graph The graph we're checking for reachability.
|
||||
/// @param[in] from The 'from' node name.
|
||||
/// @param[in] to The 'to' node name.
|
||||
///
|
||||
/// @return True if there is connection between nodes 'from' and 'to'.
|
||||
///
|
||||
bool TestIsReachable(const Graph& graph, std::string from, std::string to);
|
||||
|
||||
///
|
||||
/// @brief Search through graph and counts provided operator occurrences.
|
||||
///
|
||||
/// @param[in] graph The graph we search through.
|
||||
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
|
||||
///
|
||||
/// @note After going through all graph nodes this function asserts
|
||||
/// whether counted number for each requested op is as expected.
|
||||
///
|
||||
/// @return Returns true if occurrences of all ops is as expected.
|
||||
///
|
||||
bool AssertOpsCount(const Graph& graph,
|
||||
std::vector<OpTypeCountPair> op_type_count);
|
||||
|
||||
///
|
||||
/// @brief Builds a program descriptor.
|
||||
///
|
||||
/// @param[in] transient_vars The vector of transient variables names.
|
||||
/// @param[in] persistent_vars The vector of persistent variables names. Those
|
||||
/// will have persistable attribute set to true.
|
||||
///
|
||||
/// @return The program descriptor object.
|
||||
///
|
||||
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
|
||||
const std::vector<std::string>& persistent_vars);
|
||||
|
||||
///
|
||||
/// @brief Execute pass on provided graph and perform checks.
|
||||
///
|
||||
/// @note Check whether the balance of removed and added nodes after pass
|
||||
/// is as expected.
|
||||
///
|
||||
/// @param graph The graph we run pass on.
|
||||
/// @param[in] from The name of a 'starting' node sequence in a
|
||||
/// graph. This would be used to test for
|
||||
/// correct node connections.
|
||||
/// @param[in] to The name of a 'ending' node sequence in a
|
||||
/// graph. This would be used to test for
|
||||
/// correct node connections.
|
||||
/// @param[in] removed_nodes_count The number of nodes we expect will be
|
||||
/// removed/fused after pass execution.
|
||||
/// @param[in] added_nodes_count The number of nodes we expect will be added
|
||||
/// after pass execution.
|
||||
///
|
||||
/// @return Return true if all checks passed, otherwise false.
|
||||
///
|
||||
bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
|
||||
const std::string& from, const std::string& to,
|
||||
int removed_nodes_count, int added_nodes_count = 0);
|
||||
|
||||
} // namespace test
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue