|
|
@ -177,7 +177,10 @@ void CPUQuantizeSquashPass::ConvDequantSquash(Graph* graph) const {
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, conv_dequant_pattern);
|
|
|
|
|
|
|
|
|
|
|
|
// if conv2d has one output
|
|
|
|
// if conv2d has one output
|
|
|
|
if (conv_out->outputs.size() == 1) {
|
|
|
|
// and there is no fuse residual connection
|
|
|
|
|
|
|
|
// because residual fusion does not support force output with fp32
|
|
|
|
|
|
|
|
if (conv_out->outputs.size() == 1 &&
|
|
|
|
|
|
|
|
!(conv_op->Op()->GetAttrIfExists<bool>("fuse_residual_connection"))) {
|
|
|
|
conv_op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
conv_op->Op()->SetAttr("force_fp32_output", true);
|
|
|
|
conv_op->Op()->SetOutput("Output",
|
|
|
|
conv_op->Op()->SetOutput("Output",
|
|
|
|
std::vector<std::string>({dequant_out->Name()}));
|
|
|
|
std::vector<std::string>({dequant_out->Name()}));
|
|
|
|