commit
ad5f0e6018
@ -1,32 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Graph pass for eager deletion. Only the ApplyImpl hook is declared here;
// its definition is not part of this header.
class EagerDeletionPass : public ir::Pass {
 protected:
  // ir::Pass hook: receives ownership of `graph` and hands a graph back to
  // the caller.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
// Pass that collects every `while` / `while_grad` computation op in the
// graph, grouped by the scope index of its op handle, and forwards each group
// to operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp.
class WhileOpEagerDeletionPass : public ir::Pass {
 protected:
  // The graph itself is returned unchanged; the work happens on the
  // OperatorBase objects reachable from the computation op handles.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    using OpList = std::vector<OperatorBase *>;

    // scope index -> (while ops, while_grad ops)
    std::unordered_map<size_t, std::pair<OpList, OpList>> ops_by_scope;

    auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
    for (auto *handle : op_handles) {
      auto *compute = dynamic_cast<ComputationOpHandle *>(handle);
      // Only computation op handles can wrap a while / while_grad operator.
      if (compute == nullptr) continue;

      const auto op_type = compute->Name();
      const bool is_forward = (op_type == "while");
      if (!is_forward && !(op_type == "while_grad")) continue;

      // An entry is created lazily, i.e. only for scopes that really contain
      // a while or while_grad op.
      auto &bucket = ops_by_scope[compute->GetScopeIdx()];
      OpList &dest = is_forward ? bucket.first : bucket.second;
      dest.emplace_back(compute->GetOp());
    }

    for (auto &entry : ops_by_scope) {
      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
          entry.second.first, entry.second.second);
    }
    return graph;
  }
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Associates WhileOpEagerDeletionPass with the pass name
// `while_op_eager_deletion_pass` (presumably the key used to look the pass up
// in the pass registry -- confirm against REGISTER_PASS in ir/pass.h).
REGISTER_PASS(while_op_eager_deletion_pass,
|
||||
paddle::framework::details::WhileOpEagerDeletionPass);
|
@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
// implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
#include "paddle/fluid/string/pretty_log.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
using string::PrettyLogDetail;
|
||||
|
||||
// Counts, for every dequantize output node matched by the DequantAny pattern,
// how many times the pattern matched it, and stores the tally in
// `*nodes_keep_counter`. The total number of matches is reported through
// AddStatis.
void CPUQuantizeSquashPass::FindNodesToKeep(
    Graph* graph,
    std::unordered_map<const Node*, int>* nodes_keep_counter) const {
  GraphPatternDetector detector;
  patterns::DequantAny deq_any_pattern{detector.mutable_pattern(),
                                       "deqant_any"};
  deq_any_pattern();

  int match_count = 0;
  auto on_match = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern);

    // operator[] value-initializes a missing counter to 0, so a single
    // increment covers both the "first seen" and the "seen again" case.
    ++(*nodes_keep_counter)[dequant_out];
    ++match_count;
  };

  detector(graph, on_match);
  AddStatis(match_count);
}
|
||||
|
||||
// Squashes every dequantize->quantize pair matched by the DequantQuantAny
// pattern.
//
//  * Equal scales: the pair is a no-op, so the consumer (`next_op`) is rewired
//    to read the dequantize input directly and the quantize op/output are
//    removed.
//  * Different scales: the pair is replaced by a single `requantize` op that
//    reads the dequantize input and writes the original quantize output.
//
// `nodes_keep_counter` holds, per dequantize output, the number of consumers
// that still need it (filled by FindNodesToKeep). The counter is decremented
// for each squashed pair; the dequantize op and its output are only removed
// once no other consumer is left.
void CPUQuantizeSquashPass::Squash(
    Graph* graph,
    std::unordered_map<const Node*, int>* nodes_keep_counter) const {
  GraphPatternDetector gpd;
  patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
  squash_pattern();

  int found_squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash dequantize-quantize ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);

    auto* next_op_desc = next_op->Op();
    float dequant_scale = boost::get<float>(dequant_op->Op()->GetAttr("Scale"));
    float quant_scale = boost::get<float>(quant_op->Op()->GetAttr("Scale"));
    PADDLE_ENFORCE(nodes_keep_counter->find(dequant_out) !=
                   nodes_keep_counter->end());

    // Post-decrement on purpose: the dequantize op must stay while at least
    // one *other* consumer of its output remains after this pair is squashed.
    bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;

    // Exact comparison is intended: only bit-identical scales make the
    // dequantize->quantize round trip a no-op.
    if (dequant_scale == quant_scale) {
      // squash dequantize-quantize to nothing
      const std::string quant_out_var_name = quant_out->Name();
      const std::string dequant_in_var_name = dequant_in->Name();

      // Rewire every occurrence of quant_out among the consumer's inputs.
      // Each input slot may hold zero or several variables, so scan the whole
      // list instead of indexing element 0 (which is undefined for an empty
      // slot and misses quant_out when it is not the first variable).
      auto next_op_inputs = next_op_desc->InputNames();
      for (const auto& slot : next_op_inputs) {
        std::vector<std::string> slot_vars = next_op_desc->Input(slot);
        bool rewired = false;
        for (auto& var_name : slot_vars) {
          if (var_name == quant_out_var_name) {
            var_name = dequant_in_var_name;
            rewired = true;
          }
        }
        if (rewired) next_op_desc->SetInput(slot, slot_vars);
      }

      if (keep_dequant)
        GraphSafeRemoveNodes(graph, {quant_op, quant_out});
      else
        GraphSafeRemoveNodes(graph,
                             {dequant_op, quant_op, dequant_out, quant_out});

      IR_NODE_LINK_TO(dequant_in, next_op);
    } else {
      // squash dequantize-quantize to requantize op
      OpDesc desc;
      desc.SetType("requantize");
      desc.SetInput("Input", std::vector<std::string>({dequant_in->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({quant_out->Name()}));
      desc.SetAttr("Scale_in", dequant_scale);
      desc.SetAttr("Scale_out", quant_scale);

      auto requant_op = g->CreateOpNode(&desc);

      // quant_out is kept: it becomes the output of the new requantize op.
      if (keep_dequant)
        GraphSafeRemoveNodes(graph, {quant_op});
      else
        GraphSafeRemoveNodes(graph, {dequant_op, quant_op, dequant_out});

      IR_NODE_LINK_TO(dequant_in, requant_op);
      IR_NODE_LINK_TO(requant_op, quant_out);
    }

    // Both branches squash exactly one pair.
    found_squash_count++;
  };
  gpd(graph, handler);
  AddStatis(found_squash_count);
  PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
                  found_squash_count);
}
|
||||
|
||||
std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
PADDLE_ENFORCE(graph.get());
|
||||
FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
|
||||
|
||||
std::unordered_map<const Node*, int> nodes_keep_counter;
|
||||
FindNodesToKeep(graph.get(), &nodes_keep_counter);
|
||||
Squash(graph.get(), &nodes_keep_counter);
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// Associates CPUQuantizeSquashPass with the pass name
// `cpu_quantize_squash_pass`, the same name passed to FusePassBase::Init in
// ApplyImpl.
REGISTER_PASS(cpu_quantize_squash_pass,
|
||||
paddle::framework::ir::CPUQuantizeSquashPass);
|
@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
#include "paddle/fluid/framework/ir/pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
/*
|
||||
* Squash dequantize->quantize pair pattern into requantize op
|
||||
*/
|
||||
class CPUQuantizeSquashPass : public FusePassBase {
|
||||
public:
|
||||
virtual ~CPUQuantizeSquashPass() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
|
||||
/*
|
||||
* For each dequantize's output find the number of operators it is an input to
|
||||
*/
|
||||
void FindNodesToKeep(
|
||||
Graph* graph,
|
||||
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
|
||||
|
||||
/*
|
||||
* Squash dequantize-quantize ops pairs into requantize or nothing
|
||||
*/
|
||||
void Squash(Graph* graph,
|
||||
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
|
||||
|
||||
const std::string name_scope_{"squash"};
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,179 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/naive_executor.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
// Appends one op of the given `type` to block 0 of `prog`.
//
// Every op gets the "use_mkldnn" and "name" attributes. Inputs/outputs are
// only wired for the op types this test needs:
//   conv2d               -> Input, optional Filter and Bias, Output
//   quantize, dequantize -> Input, Output and the "Scale" attribute
// Any other type is left without inputs/outputs.
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn,
           float scale = 0) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("use_mkldnn", use_mkldnn);
  op->SetAttr("name", name);

  const bool is_conv = (type == "conv2d");
  const bool is_quant_like = (type == "quantize" || type == "dequantize");
  if (!is_conv && !is_quant_like) return;

  // All supported ops take their first input as "Input".
  op->SetInput("Input", {inputs[0]});
  if (is_conv) {
    if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
    if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
  }
  op->SetOutput("Output", {outputs[0]});
  if (is_quant_like) op->SetAttr("Scale", scale);
}
|
||||
|
||||
// (a,w1,b1)->Conv1->d
|
||||
// d->Dequant->e
|
||||
// e->Quant->f
|
||||
// (f,w2,b2)->Conv2->i
|
||||
ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
|
||||
ProgramDesc prog;
|
||||
for (auto& v : std::initializer_list<std::string>(
|
||||
{"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
|
||||
auto* var = prog.MutableBlock(0)->Var(v);
|
||||
if (v.find("w") == 0 || v.find("b") == 0) {
|
||||
var->SetPersistable(true);
|
||||
}
|
||||
}
|
||||
|
||||
SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn);
|
||||
SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
|
||||
SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
|
||||
SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn);
|
||||
return prog;
|
||||
}
|
||||
|
||||
// Variable names of the branching test program; iterated by
// BuildProgramDesc2 (to declare the variables) and by MainTest (to create a
// tensor holder for each of them in the scope).
static const std::initializer_list<std::string> variable_names{
|
||||
"a", "b", "c", "d", "e", "f", "g", "h"};
|
||||
// a->Conv1->b
|
||||
// b->Dequant->c
|
||||
//
|
||||
// c->Quant1->d and d->Conv2->e
|
||||
//
|
||||
// c->Conv3->f
|
||||
//
|
||||
// c->Quant2->g and g->Conv4->h
|
||||
//
|
||||
ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2,
|
||||
float scale3) {
|
||||
ProgramDesc prog;
|
||||
for (auto& v : variable_names) {
|
||||
prog.MutableBlock(0)->Var(v);
|
||||
}
|
||||
|
||||
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn);
|
||||
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
|
||||
|
||||
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
|
||||
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn);
|
||||
|
||||
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);
|
||||
|
||||
SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
|
||||
SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn);
|
||||
|
||||
return prog;
|
||||
}
|
||||
|
||||
// Creates (or fetches) variable `var_name` in `scope` as a LoDTensor and
// requests FP32 storage for it on `place`.
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                      const char* var_name) {
  auto* holder = scope->Var(var_name)->GetMutable<LoDTensor>();
  holder->mutable_data(place, proto::VarType::FP32,
                       ::paddle::memory::Allocator::kDefault, 1);
}
|
||||
|
||||
void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
|
||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
||||
|
||||
// Init scope, as it is used in pass
|
||||
auto place = paddle::platform::CPUPlace();
|
||||
NaiveExecutor exe{place};
|
||||
Scope scope;
|
||||
exe.CreateVariables(prog, 0, true, &scope);
|
||||
|
||||
for (auto& v : variable_names) {
|
||||
InitTensorHolder(&scope, place, v.c_str());
|
||||
}
|
||||
|
||||
graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
|
||||
|
||||
auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
|
||||
|
||||
int original_nodes_num = graph->Nodes().size();
|
||||
|
||||
graph = pass->Apply(std::move(graph));
|
||||
|
||||
int current_nodes_num = graph->Nodes().size();
|
||||
|
||||
EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
|
||||
}
|
||||
|
||||
// Equal dequantize/quantize scales: the pair is squashed to nothing.
TEST(CpuQuantizeSquashPass, equal_scales) {
  const float scale = 1.2345f;
  // Remove 4 nodes: Dequant, Quant, e, f
  const int remove_nodes = 4;

  // Same expectation with use_mkldnn on, then off.
  for (bool use_mkldnn : {true, false}) {
    MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
  }
}
|
||||
|
||||
// Different scales: the pair is replaced by one requantize op.
TEST(CpuQuantizeSquashPass, inequal_scales) {
  const float scale1 = 1.2345f;
  const float scale2 = 21.0f;
  // Remove 3 nodes: Dequant, Quant, e
  // Insert 1 node: requantize
  const int remove_nodes = 2;

  // Same expectation with use_mkldnn on, then off.
  for (bool use_mkldnn : {true, false}) {
    MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
  }
}
|
||||
|
||||
// Dequantize output feeding three branches (equal-scale quantize, fp32 conv,
// different-scale quantize):
//   delete both quantize ops,
//   bypass dequantize in both quantized branches,
//   insert requantize on one branch.
TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) {
  const float scale = 1.2345f;
  const float scale2 = 21.0f;
  // Remove 3 nodes: Quant1, Quant2, g
  // Insert 1 node: requantize
  const int remove_nodes = 2;

  // Same expectation with use_mkldnn on, then off.
  for (bool use_mkldnn : {true, false}) {
    MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2),
             remove_nodes);
  }
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
// References cpu_quantize_squash_pass from this test binary; presumably this
// is what makes PassRegistry::Instance().Get("cpu_quantize_squash_pass") in
// MainTest resolvable -- confirm against USE_PASS in ir/pass.h.
USE_PASS(cpu_quantize_squash_pass);
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue