|
|
|
@ -74,6 +74,14 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
|
|
|
|
|
op->SetInput("Input", {inputs[0]});
|
|
|
|
|
op->SetOutput("Output", {outputs[0]});
|
|
|
|
|
op->SetAttr("Scale", 1.0f);
|
|
|
|
|
} else if (type == "matmul") {
|
|
|
|
|
op->SetInput("X", {inputs[0]});
|
|
|
|
|
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
|
|
|
|
|
op->SetOutput("Out", {outputs[0]});
|
|
|
|
|
op->SetAttr("use_quantizer", use_quantizer);
|
|
|
|
|
op->SetAttr("Scale_x", 1.0f);
|
|
|
|
|
op->SetAttr("Scale_y", 1.0f);
|
|
|
|
|
op->SetAttr("Scale_out", 1.0f);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -513,6 +521,89 @@ TEST(CPUQuantizePass, check_scales) {
|
|
|
|
|
MainTestCheckScales(BuildProgramDescCheckScalesConv(), var_names, "a");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Names of every variable that appears in the matmul test programs below.
static const std::initializer_list<std::string> variable_names_matmul = {
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
};
|
|
|
|
|
|
|
|
|
|
// Builds a program in which both matmul inputs are produced by dequantize ops:
//
//   a -> Dequantize1 -> b
//   c -> Dequantize2 -> d
//   (b, d) -> Matmul -> e -> Dropout -> f
//
// The 7-argument SetOp call for matmul passes `true` as the last argument,
// which SetOp stores in the op's "use_quantizer" attribute.
ProgramDesc BuildProgramDescMatmul() {
  ProgramDesc prog;
  // Declare the variables this graph actually uses. This must be the matmul
  // list (the same one MainTestMatmul hands to PreparePass), not the list
  // belonging to another test.
  for (const auto& v : variable_names_matmul) {
    prog.MutableBlock(0)->Var(v);
  }
  SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
  SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
  SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, true);
  SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, false);

  return prog;
}
|
|
|
|
|
|
|
|
|
|
// Builds a program in which only one matmul input comes from a dequantize op;
// the other comes from a dropout op:
//
//   a -> Dropout    -> b
//   c -> Dequantize -> d
//   (b, d) -> Matmul -> e -> Dropout -> f
//
// Used by the matmul_not_quantized test, which expects the pass to leave this
// graph unchanged.
ProgramDesc BuildProgramDescMatmulNotQuantized() {
  ProgramDesc prog;
  // Declare the variables this graph actually uses. This must be the matmul
  // list (the same one MainTestMatmul hands to PreparePass), not the list
  // belonging to another test.
  for (const auto& v : variable_names_matmul) {
    prog.MutableBlock(0)->Var(v);
  }
  SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, false);
  SetOp(&prog, "dequantize", "Dequantize", {"c"}, {"d"}, true);
  SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, true);
  SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, false);

  return prog;
}
|
|
|
|
|
|
|
|
|
|
void MainTestMatmul(const ProgramDesc& prog, int matmul_count, int quant_count,
|
|
|
|
|
int dequant_count, int added_nodes_count, float scale) {
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
int original_nodes_num, current_nodes_num;
|
|
|
|
|
PreparePass(&graph, prog, variable_names_matmul, &original_nodes_num,
|
|
|
|
|
¤t_nodes_num);
|
|
|
|
|
|
|
|
|
|
int quantize_nodes_count = 0;
|
|
|
|
|
int dequantize_nodes_count = 0;
|
|
|
|
|
int matmul_nodes_count = 0;
|
|
|
|
|
for (auto* node : graph->Nodes()) {
|
|
|
|
|
if (node->IsOp()) {
|
|
|
|
|
auto* op = node->Op();
|
|
|
|
|
if (op->Type() == "matmul") {
|
|
|
|
|
matmul_nodes_count++;
|
|
|
|
|
auto op_name = boost::get<std::string>(op->GetAttr("name"));
|
|
|
|
|
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_x")), scale)
|
|
|
|
|
<< "Scale_x for node '" + op_name + "'.";
|
|
|
|
|
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_y")), scale)
|
|
|
|
|
<< "Scale_y for node '" + op_name + "'.";
|
|
|
|
|
EXPECT_EQ(boost::get<float>(op->GetAttr("Scale_out")), scale)
|
|
|
|
|
<< "Scale_out for node '" + op_name + "'.";
|
|
|
|
|
} else if (op->Type() == "quantize") {
|
|
|
|
|
quantize_nodes_count++;
|
|
|
|
|
} else if (op->Type() == "dequantize") {
|
|
|
|
|
dequantize_nodes_count++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
EXPECT_EQ(matmul_nodes_count, matmul_count);
|
|
|
|
|
EXPECT_EQ(quantize_nodes_count, quant_count);
|
|
|
|
|
EXPECT_EQ(dequantize_nodes_count, dequant_count);
|
|
|
|
|
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Matmul fed by two dequantize ops: the test expects two quantize ops and one
// extra dequantize op to appear around the matmul.
TEST(CpuQuantizePass, matmul) {
  constexpr int kExpectedMatmulOps = 1;
  constexpr int kExpectedQuantizeOps = 2;
  // Two dequantize ops already in the program plus the one counted below.
  constexpr int kExpectedDequantizeOps = 3;
  // Added nodes: 2 Quant + 2 IN + 1 DeQuant + 1 OUT
  constexpr int kExpectedAddedNodes = 6;
  constexpr float kExpectedScale = 2.0f * 127;

  MainTestMatmul(BuildProgramDescMatmul(), kExpectedMatmulOps,
                 kExpectedQuantizeOps, kExpectedDequantizeOps,
                 kExpectedAddedNodes, kExpectedScale);
}
|
|
|
|
|
|
|
|
|
|
// Matmul with only one input coming from a dequantize op: the test expects the
// graph to be left exactly as it was built.
TEST(CpuQuantizePass, matmul_not_quantized) {
  constexpr int kExpectedMatmulOps = 1;
  constexpr int kExpectedQuantizeOps = 0;
  // Only the single dequantize op that the program itself contains.
  constexpr int kExpectedDequantizeOps = 1;
  // No nodes are expected to be added.
  constexpr int kExpectedAddedNodes = 0;
  // SetOp initializes every matmul scale attribute to 1.0f.
  constexpr float kExpectedScale = 1.0f;

  MainTestMatmul(BuildProgramDescMatmulNotQuantized(), kExpectedMatmulOps,
                 kExpectedQuantizeOps, kExpectedDequantizeOps,
                 kExpectedAddedNodes, kExpectedScale);
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|