commit
ad5f0e6018
@ -1,32 +0,0 @@
|
|||||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "paddle/fluid/framework/ir/graph.h"
|
|
||||||
#include "paddle/fluid/framework/ir/pass.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
namespace details {
|
|
||||||
|
|
||||||
// Pass declaration; the implementation lives in the corresponding .cc file.
// NOTE(review): from the name this presumably inserts eager-deletion logic so
// variables are freed as soon as they are no longer needed — confirm against
// the implementation file.
class EagerDeletionPass : public ir::Pass {
 protected:
  // Takes ownership of the graph, transforms it, and returns ownership.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
|
|
||||||
|
|
||||||
} // namespace details
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
@ -0,0 +1,62 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||||
|
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||||
|
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||||
|
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
// Collects every while / while_grad operator per scope index and hands each
// group to PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp so that eager
// deletion can be applied safely to them.
class WhileOpEagerDeletionPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override {
    auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);

    // Per scope index: .first holds while ops, .second holds while_grad ops.
    std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
                                         std::vector<OperatorBase *>>>
        target_ops;
    for (auto *op_handle : op_handles) {
      auto *compute = dynamic_cast<ComputationOpHandle *>(op_handle);
      if (compute == nullptr) {
        continue;  // Only computation op handles carry an OperatorBase.
      }
      const auto &op_type = compute->Name();
      if (op_type == "while") {
        target_ops[compute->GetScopeIdx()].first.emplace_back(
            compute->GetOp());
      } else if (op_type == "while_grad") {
        target_ops[compute->GetScopeIdx()].second.emplace_back(
            compute->GetOp());
      }
    }

    for (auto &scope_and_ops : target_ops) {
      operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
          scope_and_ops.second.first, scope_and_ops.second.second);
    }
    return graph;
  }
};
|
||||||
|
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
// Register the pass so it can be looked up by name from the pass registry.
REGISTER_PASS(while_op_eager_deletion_pass,
              paddle::framework::details::WhileOpEagerDeletionPass);
|
@ -0,0 +1,146 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
// implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/platform/enforce.h"
|
||||||
|
#include "paddle/fluid/string/pretty_log.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace ir {
|
||||||
|
|
||||||
|
using string::PrettyLogDetail;
|
||||||
|
|
||||||
|
// For every dequantize output variable, count how many operators consume it.
// Squash() later uses these counts to decide whether a dequantize op may be
// removed or must be kept for its remaining consumers.
void CPUQuantizeSquashPass::FindNodesToKeep(
    Graph* graph,
    std::unordered_map<const Node*, int>* nodes_keep_counter) const {
  GraphPatternDetector detector;
  patterns::DequantAny deq_any_pattern{detector.mutable_pattern(),
                                       "deqant_any"};
  deq_any_pattern();

  int match_count = 0;
  detector(graph, [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern);
    // operator[] value-initializes a missing counter to 0, so this single
    // increment is equivalent to "insert 1 if absent, else add 1".
    ++(*nodes_keep_counter)[dequant_out];
    ++match_count;
  });
  AddStatis(match_count);
}
|
||||||
|
|
||||||
|
/*
 * Squash each matched dequantize->quantize pair:
 *  - equal scales: remove the pair entirely and rewire the consumer op to
 *    read the dequantize input directly;
 *  - different scales: replace the pair with a single requantize op.
 * The dequantize op itself is kept alive while other consumers of its output
 * remain; remaining-consumer counts come from FindNodesToKeep via
 * nodes_keep_counter and are decremented here.
 */
void CPUQuantizeSquashPass::Squash(
    Graph* graph,
    std::unordered_map<const Node*, int>* nodes_keep_counter) const {
  GraphPatternDetector gpd;
  patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
  squash_pattern();

  int found_squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // Fixed log message: the matched pair is dequantize->quantize, not
    // requantize->quantize.
    VLOG(4) << "squash dequantize-quantize ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);

    auto* next_op_desc = next_op->Op();
    float dequant_scale = boost::get<float>(dequant_op->Op()->GetAttr("Scale"));
    float quant_scale = boost::get<float>(quant_op->Op()->GetAttr("Scale"));
    PADDLE_ENFORCE(
        nodes_keep_counter->find(dequant_out) != nodes_keep_counter->end(),
        "The dequantize output node must have been counted by "
        "FindNodesToKeep before Squash runs.");

    // check if dequantize op should be kept or removed, decrease the counter
    bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;

    if (dequant_scale == quant_scale) {
      // squash dequantize-quantize to nothing
      auto quant_out_var_name = quant_out->Name();
      auto next_op_inputs = next_op_desc->InputNames();
      // Rewire the consumer input that used to read quant_out so it reads the
      // dequantize input directly.
      for (const auto& name : next_op_inputs) {
        auto var_name = next_op_desc->Input(name)[0];
        if (var_name.compare(quant_out_var_name) == 0) {
          next_op_desc->SetInput(
              name, std::vector<std::string>({dequant_in->Name()}));
          break;
        }
      }

      if (keep_dequant)
        GraphSafeRemoveNodes(graph, {quant_op, quant_out});
      else
        GraphSafeRemoveNodes(graph,
                             {dequant_op, quant_op, dequant_out, quant_out});

      IR_NODE_LINK_TO(dequant_in, next_op);

      found_squash_count++;
    } else {
      // squash dequantize-quantize to requantize op
      OpDesc desc;
      desc.SetType("requantize");
      desc.SetInput("Input", std::vector<std::string>({dequant_in->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({quant_out->Name()}));
      desc.SetAttr("Scale_in", dequant_scale);
      desc.SetAttr("Scale_out", quant_scale);

      auto requant_op = g->CreateOpNode(&desc);

      if (keep_dequant)
        GraphSafeRemoveNodes(graph, {quant_op});
      else
        GraphSafeRemoveNodes(graph, {dequant_op, quant_op, dequant_out});

      IR_NODE_LINK_TO(dequant_in, requant_op);
      IR_NODE_LINK_TO(requant_op, quant_out);

      found_squash_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_squash_count);
  PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
                  found_squash_count);
}
|
||||||
|
|
||||||
|
std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
|
||||||
|
std::unique_ptr<ir::Graph> graph) const {
|
||||||
|
PADDLE_ENFORCE(graph.get());
|
||||||
|
FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
|
||||||
|
|
||||||
|
std::unordered_map<const Node*, int> nodes_keep_counter;
|
||||||
|
FindNodesToKeep(graph.get(), &nodes_keep_counter);
|
||||||
|
Squash(graph.get(), &nodes_keep_counter);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
// Register the pass under the name used for PassRegistry lookups.
REGISTER_PASS(cpu_quantize_squash_pass,
              paddle::framework::ir::CPUQuantizeSquashPass);
|
@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||||
|
#include "paddle/fluid/framework/ir/graph.h"
|
||||||
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||||
|
#include "paddle/fluid/framework/ir/pass.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace ir {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Squash dequantize->quantize pair pattern into requantize op
|
||||||
|
*/
|
||||||
|
class CPUQuantizeSquashPass : public FusePassBase {
 public:
  virtual ~CPUQuantizeSquashPass() {}

 protected:
  // Pass entry point: runs FindNodesToKeep, then Squash, on the given graph.
  // Takes and returns ownership of the graph.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  /*
   * For each dequantize's output find the number of operators it is an input to
   */
  void FindNodesToKeep(
      Graph* graph,
      std::unordered_map<const Node*, int>* nodes_keep_counter) const;

  /*
   * Squash dequantize-quantize ops pairs into requantize or nothing
   */
  void Squash(Graph* graph,
              std::unordered_map<const Node*, int>* nodes_keep_counter) const;

  // NOTE(review): not referenced by the visible implementation — presumably
  // consumed by FusePassBase helpers; confirm before removing.
  const std::string name_scope_{"squash"};
};
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,179 @@
|
|||||||
|
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "paddle/fluid/framework/naive_executor.h"
|
||||||
|
#include "paddle/fluid/platform/place.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace ir {
|
||||||
|
|
||||||
|
// Appends an op of |type| to block 0 of |prog| and wires its inputs/outputs
// the way each op type under test expects. |scale| is only meaningful for
// quantize/dequantize ops.
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn,
           float scale = 0) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("use_mkldnn", use_mkldnn);
  op->SetAttr("name", name);

  if (type == "conv2d") {
    op->SetInput("Input", {inputs[0]});
    // Filter and Bias are optional; set them only when provided.
    if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
    if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
    op->SetOutput("Output", {outputs[0]});
  } else if (type == "quantize" || type == "dequantize") {
    // Both (de)quantize ops share the same single-input/single-output shape.
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("Scale", scale);
  }
}
|
||||||
|
|
||||||
|
// (a,w1,b1)->Conv1->d
|
||||||
|
// d->Dequant->e
|
||||||
|
// e->Quant->f
|
||||||
|
// (f,w2,b2)->Conv2->i
|
||||||
|
// Builds the linear graph:
//   (a,w1,b1)->Conv1->d ; d->Dequant->e ; e->Quant->f ; (f,w2,b2)->Conv2->i
ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
  ProgramDesc prog;
  for (const auto& name : std::initializer_list<std::string>{
           "a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"}) {
    auto* var = prog.MutableBlock(0)->Var(name);
    // Weights ("w*") and biases ("b*") are persistable parameters.
    if (name.find("w") == 0 || name.find("b") == 0) {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn);
  SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
  SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn);
  return prog;
}
|
||||||
|
|
||||||
|
// Variable names of the BuildProgramDesc2 graph; MainTest also initializes a
// tensor for each of these in the scope before running the pass.
static const std::initializer_list<std::string> variable_names{
    "a", "b", "c", "d", "e", "f", "g", "h"};
|
||||||
|
// a->Conv1->b
|
||||||
|
// b->Dequant->c
|
||||||
|
//
|
||||||
|
// c->Quant1->d and d->Conv2->e
|
||||||
|
//
|
||||||
|
// c->Conv3->f
|
||||||
|
//
|
||||||
|
// c->Quant2->g and g->Conv4->h
|
||||||
|
//
|
||||||
|
// Builds a graph whose dequantize output "c" fans out to three consumers:
//   a->Conv1->b->Dequant->c
//   branch 1: c->Quant1->d->Conv2->e   (scale2)
//   branch 2: c->Conv3->f              (fp32, no quantize)
//   branch 3: c->Quant2->g->Conv4->h   (scale3)
ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2,
                              float scale3) {
  ProgramDesc prog;
  for (const auto& name : variable_names) {
    prog.MutableBlock(0)->Var(name);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn);
  SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);

  // Branch 1: quantize with scale2, then convolve.
  SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn);

  // Branch 2: consume the dequantized value directly (fp32).
  SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);

  // Branch 3: quantize with a different scale, then convolve.
  SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
  SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn);

  return prog;
}
|
||||||
|
|
||||||
|
// Creates |var_name| in |scope| and allocates a 1-element FP32 tensor for it,
// so the pass runs against a scope whose variables all hold real tensors.
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                      const char* var_name) {
  auto x = scope->Var(var_name);
  auto tensor = x->GetMutable<LoDTensor>();
  tensor->mutable_data(place, proto::VarType::FP32,
                       ::paddle::memory::Allocator::kDefault, 1);
}
|
||||||
|
|
||||||
|
// Runs cpu_quantize_squash_pass on |prog| and checks that the node count
// shrinks by exactly |removed_nodes_num| (inserted nodes count negatively).
// NOTE(review): tensors are initialized for `variable_names` (the
// BuildProgramDesc2 set) even when |prog| comes from BuildProgramDesc, whose
// variables differ — presumably harmless; confirm if a test starts failing.
void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  // Init scope, as it is used in pass
  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  exe.CreateVariables(prog, 0, true, &scope);

  for (auto& v : variable_names) {
    InitTensorHolder(&scope, place, v.c_str());
  }

  // The pass reads parameters through the scope attached to the graph.
  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));

  auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");

  int original_nodes_num = graph->Nodes().size();

  graph = pass->Apply(std::move(graph));

  int current_nodes_num = graph->Nodes().size();

  EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
}
|
||||||
|
|
||||||
|
// Equal scales: the Dequant/Quant pair cancels out completely.
TEST(CpuQuantizeSquashPass, equal_scales) {
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // Remove 4 nodes: Dequant, Quant, e, f
  auto remove_nodes = 4;
  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);

  // Same node-count expectation with the use_mkldnn attribute flipped.
  use_mkldnn = !use_mkldnn;
  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
}
|
||||||
|
|
||||||
|
// Different scales: the pair is replaced by a single requantize op.
TEST(CpuQuantizeSquashPass, inequal_scales) {
  auto scale1 = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 3 nodes: Dequant, Quant, e
  // Insert 1 node: requantize
  auto remove_nodes = 2;
  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);

  // Same node-count expectation with the use_mkldnn attribute flipped.
  use_mkldnn = !use_mkldnn;
  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
}
|
||||||
|
|
||||||
|
// Fan-out case: the dequantize output feeds an equal-scale quantize, an
// inequal-scale quantize, and an fp32 consumer at the same time.
TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) {
  // Delete both quantize ops,
  // bypass dequantize in both branches,
  // insert requantize on one branch
  auto scale = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 3 nodes: Quant1, Quant2, g
  // Insert 1 node: requantize
  auto remove_nodes = 2;
  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);

  // Same node-count expectation with the use_mkldnn attribute flipped.
  use_mkldnn = !use_mkldnn;
  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
}
|
||||||
|
|
||||||
|
} // namespace ir
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
// Ensures the pass's registration object file is linked into this test binary.
USE_PASS(cpu_quantize_squash_pass);
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue