|
|
|
@ -161,6 +161,36 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// a->Conv1->b
|
|
|
|
|
// b->Dequant1(Scale1)->c
|
|
|
|
|
// c->Concat
|
|
|
|
|
ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
float scale) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
for (auto& v : variable_names) {
|
|
|
|
|
prog.MutableBlock(0)->Var(v);
|
|
|
|
|
}
|
|
|
|
|
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
|
|
|
|
|
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
|
|
|
|
|
SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// a->Conv1->b
|
|
|
|
|
// b->Dequant1(Scale1)->c
|
|
|
|
|
// b->Conv2->d
|
|
|
|
|
ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
float scale) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
for (auto& v : variable_names) {
|
|
|
|
|
prog.MutableBlock(0)->Var(v);
|
|
|
|
|
}
|
|
|
|
|
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
|
|
|
|
|
SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
|
|
|
|
|
SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
|
|
|
|
|
const char* var_name) {
|
|
|
|
|
auto x = scope->Var(var_name);
|
|
|
|
@ -217,6 +247,7 @@ void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
|
|
|
|
|
void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
|
|
|
|
|
float scale_out) {
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
|
|
|
|
|
PrepareGraph(&graph, prog);
|
|
|
|
|
RegisterPass(&graph);
|
|
|
|
|
|
|
|
|
@ -238,6 +269,7 @@ TEST(CpuQuantizeSquashPass, equal_scales) {
|
|
|
|
|
auto use_mkldnn = true;
|
|
|
|
|
// Remove 4 nodes: Dequant, Quant, e, f
|
|
|
|
|
auto remove_nodes = 4;
|
|
|
|
|
|
|
|
|
|
CountNodeTest(
|
|
|
|
|
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
|
|
|
|
|
remove_nodes);
|
|
|
|
@ -253,6 +285,7 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
|
|
|
|
|
auto use_mkldnn = true;
|
|
|
|
|
// Remove 4 nodes: Dequant, Quant, e, d
|
|
|
|
|
auto remove_nodes = 4;
|
|
|
|
|
|
|
|
|
|
CountNodeTest(
|
|
|
|
|
BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
|
|
|
|
|
remove_nodes);
|
|
|
|
@ -280,6 +313,7 @@ TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
|
|
|
|
|
// Remove 3 nodes: Quant1, c, Quant2,
|
|
|
|
|
// Insert 1 node: Requant
|
|
|
|
|
auto remove_nodes = 2;
|
|
|
|
|
|
|
|
|
|
CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
|
|
|
|
|
scale, scale2),
|
|
|
|
|
remove_nodes);
|
|
|
|
@ -322,6 +356,7 @@ TEST(CpuQuantizeSquashPass,
|
|
|
|
|
// Remove 3 nodes: Dequant1, c, Quant
|
|
|
|
|
// Insert 1 node: Requant
|
|
|
|
|
auto remove_nodes = 2;
|
|
|
|
|
|
|
|
|
|
CountNodeTest(
|
|
|
|
|
BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
|
|
|
|
|
remove_nodes);
|
|
|
|
@ -345,6 +380,27 @@ TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
|
|
|
|
|
remove_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// a->Conv1->c->Concat
|
|
|
|
|
TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
|
|
|
|
|
auto scale_out = 1.0f;
|
|
|
|
|
auto scale = 1.2345f;
|
|
|
|
|
auto use_mkldnn = true;
|
|
|
|
|
// remove 2 nodes: Dequant1, c
|
|
|
|
|
auto remove_nodes = 2;
|
|
|
|
|
CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
|
|
|
|
|
remove_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
|
|
|
|
|
auto scale_out = 1.0f;
|
|
|
|
|
auto scale = 1.2345f;
|
|
|
|
|
auto use_mkldnn = true;
|
|
|
|
|
// nothing change
|
|
|
|
|
auto remove_nodes = 0;
|
|
|
|
|
CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
|
|
|
|
|
remove_nodes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|