|
|
|
@ -75,7 +75,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
|
|
|
|
|
BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(),
|
|
|
|
|
platform::errors::NotFound("The dequant output node is not found"));
|
|
|
|
|
platform::errors::NotFound("The dequant output node is not found."));
|
|
|
|
|
|
|
|
|
|
// check if dequantize op should be kept or removed, decrease the counter
|
|
|
|
|
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
|
|
|
|
@ -153,8 +153,9 @@ void CPUQuantizeSquashPass::OpRequantSquash(Graph* graph) const {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
any_op_output_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound("Operator before requantize operator "
|
|
|
|
|
"should have requantize input as output"));
|
|
|
|
|
platform::errors::NotFound("Operator before requantize operator(%s) "
|
|
|
|
|
"should have requantize input as output.",
|
|
|
|
|
requant_in->Name()));
|
|
|
|
|
|
|
|
|
|
float requant_scale_out =
|
|
|
|
|
BOOST_GET_CONST(float, requant_op->Op()->GetAttr("Scale_out"));
|
|
|
|
@ -195,10 +196,11 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
|
|
|
|
|
for (auto input_name : any_op->Op()->Input(name))
|
|
|
|
|
if (input_name == requant_out->Name()) any_op_input_name = name;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
any_op_input_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound("The operator after requantize operator "
|
|
|
|
|
"should have requantize output as input"));
|
|
|
|
|
PADDLE_ENFORCE_NE(any_op_input_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"The operator after requantize operator(%s) "
|
|
|
|
|
"should have requantize output as input.",
|
|
|
|
|
requant_out->Name()));
|
|
|
|
|
float requant_scale_in =
|
|
|
|
|
boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
|
|
|
|
|
|
|
|
|
@ -206,11 +208,14 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
|
|
|
|
|
if (any_op->Op()->Type() == "matmul")
|
|
|
|
|
scale_name = any_op_input_name == "X" ? "Scale_x" : "Scale_y";
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
requant_op->Op()->GetAttrIfExists<float>("Scale_out"),
|
|
|
|
|
any_op->Op()->GetAttrIfExists<float>(scale_name),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The operator after requantize should have input "
|
|
|
|
|
"scale equal to requantize output scale"));
|
|
|
|
|
"scale(%f) equal to requantize output scale(%f).",
|
|
|
|
|
any_op->Op()->GetAttrIfExists<float>(scale_name),
|
|
|
|
|
requant_op->Op()->GetAttrIfExists<float>("Scale_out")));
|
|
|
|
|
any_op->Op()->SetAttr(scale_name, requant_scale_in);
|
|
|
|
|
any_op->Op()->SetInput(any_op_input_name,
|
|
|
|
|
std::vector<std::string>({requant_in->Name()}));
|
|
|
|
@ -286,8 +291,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
|
|
|
|
|
auto* first_quant_out = first_quant_op->outputs[0];
|
|
|
|
|
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument(
|
|
|
|
|
"Quantize scale should not be equal 0"));
|
|
|
|
|
PADDLE_ENFORCE_NE(scale, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Quantize scale(%f) should not be equal 0.", scale));
|
|
|
|
|
|
|
|
|
|
for (int iter = prev_out->outputs.size() - 1; iter >= 0; iter--) {
|
|
|
|
|
auto quant_op = prev_out->outputs[iter];
|
|
|
|
@ -304,8 +310,9 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
last_op_input_name.empty(), true,
|
|
|
|
|
platform::errors::NotFound("Operator after quantize operator "
|
|
|
|
|
"should has quantize output as input"));
|
|
|
|
|
platform::errors::NotFound("Operator after quantize operator(%s) "
|
|
|
|
|
"should has quantize output as input.",
|
|
|
|
|
quant_out->Name()));
|
|
|
|
|
last_op->Op()->SetInput(
|
|
|
|
|
last_op_input_name,
|
|
|
|
|
std::vector<std::string>({first_quant_out->Name()}));
|
|
|
|
@ -345,10 +352,12 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Dequantize scale should have positive value"));
|
|
|
|
|
"Dequantize scale(%f) should have positive value.",
|
|
|
|
|
dequant_scale));
|
|
|
|
|
PADDLE_ENFORCE_GT(scale_scale, 0.0f,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Scale of scale op should have positive value"));
|
|
|
|
|
"Scale(%f) of scale op should have positive value.",
|
|
|
|
|
scale_scale));
|
|
|
|
|
|
|
|
|
|
dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale);
|
|
|
|
|
dequant_op->Op()->SetOutput(
|
|
|
|
@ -367,8 +376,8 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
|
|
|
|
|
void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
graph,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"The graph in function CPUQuantizeSquashPass::ApplyImpl is null"));
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The graph in function CPUQuantizeSquashPass::ApplyImpl is null."));
|
|
|
|
|
FusePassBase::Init("cpu_quantize_squash_pass", graph);
|
|
|
|
|
|
|
|
|
|
std::unordered_map<const Node*, int> nodes_keep_counter;
|
|
|
|
|