You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
7.2 KiB
233 lines
7.2 KiB
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include <algorithm>
|
|
#include <cstring>
|
|
#include <exception>
|
|
#include <functional>
|
|
#include <iterator>
|
|
#include <list>
|
|
#include <map>
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
#include "paddle/fluid/framework/ir/pass.h"
|
|
#include "paddle/fluid/framework/ir/pass_test_util.h"
|
|
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
|
|
#include "paddle/fluid/framework/op_proto_maker.h"
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
namespace ir {
|
|
namespace test {
|
|
|
|
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
|
|
const std::vector<InOutVarNamePair>& inputs,
|
|
const std::vector<InOutVarNamePair>& outputs,
|
|
bool use_mkldnn) {
|
|
auto* op = prog->MutableBlock(0)->AppendOp();
|
|
op->SetType(op_type_name);
|
|
op->SetAttr("use_mkldnn", use_mkldnn);
|
|
|
|
for (const auto& input : inputs) {
|
|
op->SetInput(input.first, {input.second});
|
|
}
|
|
for (const auto& output : outputs) {
|
|
op->SetOutput(output.first, {output.second});
|
|
}
|
|
|
|
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
|
|
static_cast<int>(OpRole::kForward));
|
|
return op;
|
|
}
|
|
|
|
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
|
|
auto hash = [](const Node* node) -> std::string {
|
|
return node->Name() + std::to_string(node->id());
|
|
};
|
|
|
|
auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
|
|
for (auto& node : GraphTraits::DFS(graph)) {
|
|
if (name == hash(&node)) {
|
|
return &node;
|
|
}
|
|
}
|
|
|
|
return nullptr;
|
|
};
|
|
|
|
if (from == to) return true;
|
|
|
|
std::map<std::string, bool> visited;
|
|
// update the from and to strings to hashed equivs in loop from graph traits
|
|
for (auto& node : GraphTraits::DFS(graph)) {
|
|
auto hashed = hash(&node);
|
|
if (node.Name() == from) {
|
|
from = hashed;
|
|
}
|
|
if (node.Name() == to) {
|
|
to = hashed;
|
|
}
|
|
visited[hashed] = false;
|
|
}
|
|
|
|
visited[from] = true;
|
|
|
|
std::list<std::string> queue;
|
|
queue.push_back(from);
|
|
|
|
while (!queue.empty()) {
|
|
auto cur = find_node(graph, queue.front());
|
|
queue.pop_front();
|
|
if (cur == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
for (auto n : cur->outputs) {
|
|
auto hashed_name = hash(n);
|
|
if (hashed_name == to) {
|
|
return true;
|
|
}
|
|
|
|
if (!visited[hashed_name]) {
|
|
visited[hashed_name] = true;
|
|
queue.push_back(hashed_name);
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool AssertOpsCount(const Graph& graph,
|
|
std::vector<OpTypeCountPair> op_type_count) {
|
|
for (auto* node : graph.Nodes()) {
|
|
if (!node->IsOp()) {
|
|
continue;
|
|
}
|
|
|
|
const std::string op_type_name = node->Op()->Type();
|
|
auto op_it =
|
|
std::find_if(std::begin(op_type_count), std::end(op_type_count),
|
|
[op_type_name](const OpTypeCountPair& p) {
|
|
return op_type_name == p.first;
|
|
});
|
|
if (op_it != std::end(op_type_count)) {
|
|
op_it->second--;
|
|
}
|
|
}
|
|
|
|
bool result{true};
|
|
|
|
for (const OpTypeCountPair& p : op_type_count) {
|
|
result = result && (p.second == 0);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
|
|
const std::vector<std::string>& persistent_vars) {
|
|
ProgramDesc prog;
|
|
|
|
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
|
|
auto var = prog.MutableBlock(0)->Var(var_name);
|
|
var->SetType(proto::VarType::LOD_TENSOR);
|
|
return var;
|
|
};
|
|
|
|
for (const auto& v : transient_vars) {
|
|
add_var_to_prog(v);
|
|
}
|
|
|
|
for (const auto& v : persistent_vars) {
|
|
auto* var = add_var_to_prog(v);
|
|
var->SetPersistable(true);
|
|
}
|
|
|
|
return prog;
|
|
}
|
|
|
|
bool RunPassAndAssert(Graph* graph, const std::string& pass_name,
|
|
const std::string& from, const std::string& to,
|
|
int removed_nodes_count, int added_nodes_count) {
|
|
if (!TestIsReachable(*graph, from, to)) return false;
|
|
|
|
int original_nodes_num = graph->Nodes().size();
|
|
auto pass = PassRegistry::Instance().Get(pass_name);
|
|
pass->Apply(graph);
|
|
int current_nodes_num = graph->Nodes().size();
|
|
|
|
if (!TestIsReachable(*graph, from, to)) return false;
|
|
|
|
int expected_nodes_num =
|
|
original_nodes_num - removed_nodes_count + added_nodes_count;
|
|
return expected_nodes_num == current_nodes_num;
|
|
}
|
|
|
|
template <typename T>
|
|
void InitLoDTensorHolder(const Scope& scope,
|
|
const paddle::platform::Place& place,
|
|
const std::string& var_name,
|
|
const std::vector<int64_t>& dims, const T* data) {
|
|
auto var = scope.FindLocalVar(var_name);
|
|
auto tensor = var->GetMutable<LoDTensor>();
|
|
auto* tensor_mem_ptr = tensor->mutable_data<T>(make_ddim(dims), place);
|
|
if (data != nullptr) {
|
|
std::memcpy(tensor_mem_ptr, data, tensor->memory_size());
|
|
} else {
|
|
std::memset(tensor_mem_ptr, 0, tensor->memory_size());
|
|
}
|
|
}
|
|
|
|
// Instantiate for below data types.
|
|
template void InitLoDTensorHolder<float>(const Scope&,
|
|
const paddle::platform::Place&,
|
|
const std::string&,
|
|
const std::vector<int64_t>&,
|
|
const float*);
|
|
template void InitLoDTensorHolder<int>(const Scope&,
|
|
const paddle::platform::Place&,
|
|
const std::string&,
|
|
const std::vector<int64_t>&, const int*);
|
|
template void InitLoDTensorHolder<double>(const Scope&,
|
|
const paddle::platform::Place&,
|
|
const std::string&,
|
|
const std::vector<int64_t>&,
|
|
const double*);
|
|
|
|
OpDesc* GetOp(const ProgramDesc& prog, const std::string& op_type,
|
|
const std::string& output_name,
|
|
const std::string& output_arg_name) {
|
|
return GetOp(prog.Block(0), op_type, output_name, output_arg_name);
|
|
}
|
|
|
|
OpDesc* GetOp(const BlockDesc& block_desc, const std::string& op_type,
|
|
const std::string& output_name,
|
|
const std::string& output_arg_name) {
|
|
auto all_ops = block_desc.AllOps();
|
|
for (auto* op_desc : all_ops) {
|
|
if (op_desc->Type() == op_type && op_desc->HasOutput(output_name)) {
|
|
const auto& arg_names = op_desc->Outputs().at(output_name);
|
|
for (const auto& name : arg_names) {
|
|
if (name == output_arg_name) return op_desc;
|
|
}
|
|
}
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
} // namespace test
|
|
} // namespace ir
|
|
} // namespace framework
|
|
} // namespace paddle
|