|
|
|
|
@ -59,6 +59,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
|
|
|
|
|
inputs.size()));
|
|
|
|
|
op->SetInput("W", {inputs[1]});
|
|
|
|
|
op->SetOutput("Out", outputs);
|
|
|
|
|
op->SetAttr("Scale_out", scale);
|
|
|
|
|
} else if (type == "scale") {
|
|
|
|
|
op->SetInput("X", {inputs[0]});
|
|
|
|
|
op->SetOutput("Out", {outputs[0]});
|
|
|
|
|
@ -68,6 +69,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
|
|
|
|
|
op->SetInput("X", {inputs[0]});
|
|
|
|
|
op->SetInput("Y", {inputs[1]});
|
|
|
|
|
op->SetOutput("Out", {outputs[0]});
|
|
|
|
|
op->SetAttr("Scale_out", scale);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -96,7 +98,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static const std::initializer_list<std::string> variable_names{
|
|
|
|
|
"a", "b", "c", "d", "e", "f", "g", "h", "x", "y"};
|
|
|
|
|
"a", "b", "c", "d", "e", "f", "g", "h", "x", "y", "w1"};
|
|
|
|
|
|
|
|
|
|
// a->Conv1->b
|
|
|
|
|
// b->Dequant(scale1)->c
|
|
|
|
|
@ -125,23 +127,30 @@ ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// a->Conv1->b->Requant(scale1)->c
|
|
|
|
|
// d->Conv2->e->Requant(scale2)->f
|
|
|
|
|
// {c,f}->Concat
|
|
|
|
|
ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
float scale1, float scale2) {
|
|
|
|
|
// a->Conv->b->Requant(scale1)->c
|
|
|
|
|
// d->Fc->e->Requant(scale2)->f
|
|
|
|
|
// {x,y}->Matmul->g->Requant(scale3)->h
|
|
|
|
|
// {c,f,h}->Concat
|
|
|
|
|
ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, float conv_scale,
|
|
|
|
|
float fc_scale, float matmul_scale,
|
|
|
|
|
float requant_scale1,
|
|
|
|
|
float requant_scale2,
|
|
|
|
|
float requant_scale3) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
for (auto& v : variable_names) {
|
|
|
|
|
prog.MutableBlock(0)->Var(v);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
|
|
|
|
|
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
|
|
|
|
|
SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2);
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn);
|
|
|
|
|
SetOp(&prog, "conv2d", "Conv", {"a"}, {"b"}, use_mkldnn, conv_scale);
|
|
|
|
|
SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn,
|
|
|
|
|
requant_scale1);
|
|
|
|
|
SetOp(&prog, "fc", "Fc", {"d", "w1"}, {"e"}, use_mkldnn, fc_scale);
|
|
|
|
|
SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn,
|
|
|
|
|
requant_scale2);
|
|
|
|
|
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"g"}, use_mkldnn, matmul_scale);
|
|
|
|
|
SetOp(&prog, "requantize", "Requant3", {"g"}, {"h"}, use_mkldnn,
|
|
|
|
|
requant_scale3);
|
|
|
|
|
SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn);
|
|
|
|
|
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
@ -412,27 +421,28 @@ TEST(CpuQuantizeSquashPass, unequal_scales) {
|
|
|
|
|
"Conv1", "Scale_out", scale2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// a->Conv1->b->Requant->c
|
|
|
|
|
// d->Conv2->e->Requant->f
|
|
|
|
|
// {c,f}->Concat
|
|
|
|
|
TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
|
|
|
|
|
// Delete both requantize op
|
|
|
|
|
auto scale_out = 1.0f;
|
|
|
|
|
auto scale = 1.2345f;
|
|
|
|
|
// a->Conv->b->Requant->c
|
|
|
|
|
// d->Fc->e->Requant->f
|
|
|
|
|
// {x,y}->Matmul->g->Requant->h
|
|
|
|
|
// {c,f,h}->Concat
|
|
|
|
|
TEST(CpuQuantizeSquashPass, op_requantize_squash) {
|
|
|
|
|
// Delete all requantize op
|
|
|
|
|
auto conv_scale = 0.234f;
|
|
|
|
|
auto fc_scale = 1.234f;
|
|
|
|
|
auto matmul_scale = 2.234f;
|
|
|
|
|
auto requant_scale1 = 2.234f;
|
|
|
|
|
auto requant_scale2 = 3.234f;
|
|
|
|
|
auto requant_scale3 = 4.234f;
|
|
|
|
|
auto use_mkldnn = true;
|
|
|
|
|
// Remove 4 nodes: b, Requant1, e, Requant2
|
|
|
|
|
auto remove_nodes = 4;
|
|
|
|
|
CountNodeTest(
|
|
|
|
|
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
|
|
|
|
|
remove_nodes);
|
|
|
|
|
|
|
|
|
|
// check equal scale conv->scale_out and requant->scale_out
|
|
|
|
|
EqualScaleTest(
|
|
|
|
|
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
|
|
|
|
|
"Conv1", "Scale_out", scale);
|
|
|
|
|
EqualScaleTest(
|
|
|
|
|
BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
|
|
|
|
|
"Conv2", "Scale_out", scale);
|
|
|
|
|
// Remove 4 nodes: b, Requant1, e, Requant2, g, Requant3
|
|
|
|
|
auto remove_nodes = 6;
|
|
|
|
|
auto program_desc =
|
|
|
|
|
BuildOpRequantProgramDesc(use_mkldnn, conv_scale, fc_scale, matmul_scale,
|
|
|
|
|
requant_scale1, requant_scale2, requant_scale3);
|
|
|
|
|
CountNodeTest(program_desc, remove_nodes);
|
|
|
|
|
EqualScaleTest(program_desc, "Conv", "Scale_out", requant_scale1);
|
|
|
|
|
EqualScaleTest(program_desc, "Fc", "Scale_out", requant_scale2);
|
|
|
|
|
EqualScaleTest(program_desc, "Matmul", "Scale_out", requant_scale3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// from
|
|
|
|
|
|