|
|
|
@ -64,6 +64,10 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
|
|
|
|
|
op->SetOutput("Out", {outputs[0]});
|
|
|
|
|
op->SetAttr("scale", scale);
|
|
|
|
|
op->SetAttr("bias", bias);
|
|
|
|
|
} else if (type == "matmul") {
|
|
|
|
|
op->SetInput("X", {inputs[0]});
|
|
|
|
|
op->SetInput("Y", {inputs[1]});
|
|
|
|
|
op->SetOutput("Out", {outputs[0]});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -92,7 +96,7 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static const std::initializer_list<std::string> variable_names{
|
|
|
|
|
"a", "b", "c", "d", "e", "f", "g", "h"};
|
|
|
|
|
"a", "b", "c", "d", "e", "f", "g", "h", "x", "y"};
|
|
|
|
|
|
|
|
|
|
// a->Conv1->b
|
|
|
|
|
// b->Dequant(scale1)->c
|
|
|
|
@ -272,6 +276,21 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale,
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// {x,y}->Matmul->b
|
|
|
|
|
// b->Dequant->c
|
|
|
|
|
ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn,
|
|
|
|
|
float dequant_scale) {
|
|
|
|
|
ProgramDesc prog;
|
|
|
|
|
for (auto& v : variable_names) {
|
|
|
|
|
prog.MutableBlock(0)->Var(v);
|
|
|
|
|
}
|
|
|
|
|
SetOp(&prog, "matmul", "Matmul", {"x", "y"}, {"b"}, use_mkldnn);
|
|
|
|
|
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn,
|
|
|
|
|
dequant_scale);
|
|
|
|
|
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
|
|
|
|
|
const char* var_name) {
|
|
|
|
|
auto x = scope->Var(var_name);
|
|
|
|
@ -595,6 +614,17 @@ TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) {
|
|
|
|
|
scale_scale, bias),
|
|
|
|
|
"Dequant", "Scale", dequant_scale);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(CpuQuantizeSquashPass, matmul_with_dequant) {
|
|
|
|
|
auto dequant_scale = 1.2345f;
|
|
|
|
|
auto use_mkldnn = true;
|
|
|
|
|
// remove: matmul_out, dequant_op
|
|
|
|
|
auto remove_nodes = 2;
|
|
|
|
|
CountNodeTest(BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale),
|
|
|
|
|
remove_nodes);
|
|
|
|
|
IsForceFp32OutputTest(
|
|
|
|
|
BuildMatmulDequantProgramDesc(use_mkldnn, dequant_scale), "matmul", true);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ir
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|