You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py

694 lines
29 KiB

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['QatInt8MkldnnPass', 'Qat2Int8MkldnnPass']
class QatInt8MkldnnPass(object):
    """
    Convert an IrGraph generated by QuantizationFreezePass into an MKL-DNN
    supported INT8 IrGraph. The pass performs the following transformations:

        1. Convert int8 range weights with float32 data type, which are
           generated by the QuantizationFreezePass, to float32 range weights
           with float32 data type by using the corresponding scales. This
           conversion is needed because the MKL-DNN INT8 conv2d kernel and mul
           kernel now only support float32 weights input, hence weights
           quantization will happen inside the conv2d and mul INT8 kernel.
        2. Create the new conv2d or mul op with the converted weights, link
           its output to the fake_dequantize_max_abs op's output and set the
           new op's attribute "force_fp32_output" to true.
        3. Transform fake_quantize_xx op to quantize op.
        4. Remove fake_dequantize_max_abs op.
    """

    def __init__(self, _scope=None, _place=None):
        """
        Args:
            _scope(fluid.Scope): scope is used to initialize the new parameters.
            _place(fluid.CPUPlace): place is used to initialize the new parameters.

        Examples:
        .. code-block:: python

            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
            from paddle.fluid.framework import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
            place = fluid.CPUPlace()
            mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(), place)
            mkldnn_pass.apply(graph)
        """
        self._scope = _scope
        self._place = _place

        # Op types recognised by the pass.
        self._quantize_type = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max'
        ]
        self._dequantize_type = ['fake_dequantize_max_abs']
        self._quantize_dequantize_type = [
            'fake_quantize_dequantize_moving_average_abs_max'
        ]
        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._pool_ops = ['pool2d']

        # Per-tensor bookkeeping filled in by apply(). All three dicts are
        # keyed by the name of the "X" input of the fake op they came from.
        self._in_scale = {}
        self._max_range = {}
        self._new_output = {}
        self._s8_max = 127

    def apply(self, graph):
        """
        Quantize the graph for running MKL-DNN INT8 inference. According
        to activation quantization type, the graph will transform fake
        quantize ops to quantize ops and remove the fake dequantize ops.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            The same graph object, rewritten in place.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()

        # Collect the _in_scale and _max_range values used to calculate the
        # new scales for MKL-DNN INT8 conv2d and mul.
        for op_node in ops:
            op_type = op_node.name()
            if op_type in self._dequantize_type:
                input_name = op_node.input("X")[0]
                scale_name = op_node.input("Scale")[0]
                self._in_scale[input_name] = self._load_param(self._scope,
                                                              scale_name)[0]
                self._max_range[input_name] = op_node.op().attr("max_range")
                self._new_output[input_name] = op_node.output("Out")[0]
            if op_type in self._quantize_dequantize_type:
                input_name = op_node.input("X")[0]
                scale_name = op_node.input("InScale")[0]
                self._in_scale[input_name] = self._load_param(self._scope,
                                                              scale_name)[0]
                # NOTE(review): no max_range is recorded for these ops, so
                # _max_range only has entries for tensors that feed a fake
                # dequantize op.
                self._new_output[input_name] = op_node.output("Out")[0]

        # Rewrite the ops using the snapshot taken before any modification.
        for op_node in ops:
            op_type = op_node.name()
            if op_type in self._quantizable_ops:
                if op_type in self._conv_ops:
                    self._transform_to_conv_mkldnn(graph, op_node)
                elif op_type in self._pool_ops:
                    self._transform_to_pool_mkldnn(graph, op_node)
                else:
                    self._transform_to_mul_mkldnn(graph, op_node)
            elif op_type in self._quantize_type:
                self._transform_to_quantize_mkldnn(graph, op_node)
            elif op_type in self._dequantize_type:
                self._remove_fake_dequantize_op(graph, op_node)

        # Drop the variable nodes orphaned by the rewrites above.
        self._remove_unused_var_nodes(graph)
        return graph

    def _transform_to_pool_mkldnn(self, graph, op):
        """
        Placeholder for a pool2d transformation; it does not modify the graph.

        With the default op lists it is never called, because 'pool2d' is not
        part of self._quantizable_ops.
        """

    def _dequantize_weight(self, weight_name, output_name):
        """
        Rescale the weights stored in the scope under `weight_name` in place.

        The stored values are multiplied by self._s8_max and divided by the
        max_range recorded for `output_name`; the shape is preserved.

        Raises:
            KeyError: if no max_range was collected for `output_name`.
        """
        weight = self._load_param(self._scope, weight_name)
        w_fp32 = np.divide(
            np.multiply(weight, self._s8_max), self._max_range[output_name])
        self._restore_var(weight_name, w_fp32.reshape(weight.shape))

    def _transform_to_conv_mkldnn(self, graph, op_node):
        """
        Replace a conv op with a new 'conv2d' op configured for MKL-DNN INT8.

        The new op reuses the original input, filter and attributes, but writes
        to the output tensor recorded in self._new_output for the original
        output. The original op node is removed afterwards.
        """
        weight_name = op_node.input("Filter")[0]
        output_name = op_node.output("Output")[0]
        # Convert int8 range weights to fp32 range weights
        self._dequantize_weight(weight_name, output_name)

        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("Input")[0])
        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
        # Set fake_dequantize_max_abs's output as new output of conv2d
        output_var_node = graph._find_node_by_name(
            graph.all_var_nodes(), self._new_output[output_name])

        attrs = {
            name: op_node.op().attr(name)
            for name in op_node.op().attr_names()
        }
        conv_op_node = graph.create_op_node(
            op_type='conv2d',
            attrs=attrs,
            inputs={'Input': input_var_node,
                    'Filter': weight_var_node},
            outputs={'Output': output_var_node})

        # Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
        scale_in = self._s8_max / self._in_scale[output_name]
        scale_w = [self._max_range[output_name] / self._s8_max]
        conv_op_node.set_attr("Scale_weights", scale_w)
        conv_op_node.set_attr("Scale_in", scale_in)
        conv_op_node.set_attr("Scale_out", 1.0)
        conv_op_node.set_attr("use_mkldnn", 1)
        conv_op_node.set_attr("force_fp32_output", 1)

        graph.link_to(input_var_node, conv_op_node)
        graph.link_to(weight_var_node, conv_op_node)
        graph.link_to(conv_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _transform_to_mul_mkldnn(self, graph, op_node):
        """
        Replace a mul op with a new 'mul' op configured for MKL-DNN INT8.

        Mirrors _transform_to_conv_mkldnn, using the "X"/"Y"/"Out" slots and
        the scale_x/scale_y/scale_out attributes.
        """
        # For MKL-DNN INT8 mul, input Y should be the weights
        weight_name = op_node.input("Y")[0]
        output_name = op_node.output("Out")[0]
        # Convert int8 range weights to fp32 range weights
        self._dequantize_weight(weight_name, output_name)

        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
        # Set fake_dequantize_max_abs's output as new output of mul
        output_var_node = graph._find_node_by_name(
            graph.all_var_nodes(), self._new_output[output_name])

        attrs = {
            name: op_node.op().attr(name)
            for name in op_node.op().attr_names()
        }
        mul_op_node = graph.create_op_node(
            op_type='mul',
            attrs=attrs,
            inputs={'X': input_var_node,
                    'Y': weight_var_node},
            outputs={'Out': output_var_node})

        # Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
        scale_in = self._s8_max / self._in_scale[output_name]
        scale_w = [self._max_range[output_name] / self._s8_max]
        mul_op_node.set_attr("scale_y", scale_w)
        mul_op_node.set_attr("scale_x", scale_in)
        mul_op_node.set_attr("scale_out", 1.0)
        mul_op_node.set_attr("use_mkldnn", 1)
        mul_op_node.set_attr("force_fp32_output", 1)

        graph.link_to(input_var_node, mul_op_node)
        graph.link_to(weight_var_node, mul_op_node)
        graph.link_to(mul_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _transform_to_quantize_mkldnn(self, graph, op_node):
        """
        Transform fake_quantize_xx op to quantize mkldnn op in the graph.

        The new op's "Scale" is self._s8_max divided by the first element of
        the parameter named by the fake op's "InScale" input.
        """
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        output_var_node = graph._find_node_by_name(op_node.outputs,
                                                   op_node.output("Out")[0])
        scale_in = self._s8_max / self._load_param(
            self._scope, op_node.input("InScale")[0])[0]
        quant_op_node = graph.create_op_node(
            op_type='quantize',
            attrs={
                'data_format': 'MKLDNNLAYOUT',
                'use_mkldnn': 1,
                'Scale': scale_in,
                'is_negative_input': 1
            },
            inputs={'Input': input_var_node},
            outputs={'Output': output_var_node})
        graph.link_to(input_var_node, quant_op_node)
        graph.link_to(quant_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _remove_fake_dequantize_op(self, graph, op_node):
        """Remove a fake dequantize op node from the graph."""
        graph.safe_remove_nodes(op_node)

    def _load_param(self, scope, param_name):
        """Return the tensor stored in `scope` under `param_name` as a numpy array."""
        return np.array(scope.find_var(param_name).get_tensor())

    def _restore_var(self, name, array):
        """Overwrite the tensor stored in self._scope under `name` with `array`."""
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        """Remove every variable node that is neither an input nor an output of any op."""
        used_nodes = set()
        for op_node in graph.all_op_nodes():
            for var_node in op_node.inputs:
                used_nodes.add(var_node.node)
            for var_node in op_node.outputs:
                used_nodes.add(var_node.node)
        # Membership is tested on the wrapped `.node` objects, not on the
        # wrapper objects returned by the graph.
        unused_vars = {
            var_node
            for var_node in graph.all_var_nodes()
            if var_node.node not in used_nodes
        }
        graph.safe_remove_nodes(unused_vars)
class Qat2Int8MkldnnPass(object):
    """
    Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph.
    The pass consists of the following transformations:
        1. gather scale values from fake quantize/dequantize operators,
        2. extract FP32 inference model graph from the QAT graph, i.e.
            a. remove fake quantize/dequantize operators,
            b. dequantize conv2d and mul's weights,
        3. optimize the FP32 graph using standard FP32 optimization fuses
            (e.g. `conv2d`+`bn` -> `conv2d`),
        4. quantize the optimized FP32 graph using standard INT8v2 quantization
            passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
    """

    def __init__(self,
                 _quantized_ops,
                 _scope=None,
                 _place=None,
                 _core=None,
                 _debug=False):
        """
        Args:
            _quantized_ops: collection of op type names to quantize; it is
                used for membership tests and handed to the
                `cpu_quantize_placement_pass` as 'quantize_enabled_op_types'.
            _scope: scope holding the model parameters.
            _place: place used when writing parameters back to the scope.
            _core: module providing `get_pass` for the quantize placement
                pass; the module-level `core` is used when it is None.
            _debug: when True, the graph is drawn after every applied pass.
        """
        self._scope = _scope
        self._place = _place
        self._core = _core
        self._debug = _debug
        self._quantize_types = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max',
            'fake_quantize_dequantize_moving_average_abs_max'
        ]
        self._fake_quantize_types = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_dequantize_moving_average_abs_max'
        ]
        self._fake_dequantize_types = ['fake_dequantize_max_abs']
        self._quantized_ops = _quantized_ops
        # Ops whose output is given the same scale as their input (and vice
        # versa) by _propagate_scales; 'scale' gets special treatment there.
        self._scale_immutable_ops = [
            'transpose2', 'reshape2', 'pool2d', 'scale'
        ]
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._pool_ops = ['pool2d']
        self._mul_ops = ['mul']
        self._fc_ops = ['fc']
        # max_range of each fake dequantize op, keyed by its "X" input name.
        self._weight_scales = {}
        # Collect the Input and Output scales from Fake QAT models.
        # Maps a tensor name to a (use_unsigned_int, scale_tensor) tuple.
        self._var_quant_scales = {}
        self._max_range = {}
        self._s8_max = 127

    def apply(self, graph):
        """
        Run the complete QAT -> INT8 transformation on `graph`.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            The transformed graph.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'

        graph = self._gather_scales(graph)
        graph = self._remove_fake_ops(graph)
        graph = self._dequantize_weights(graph)
        graph = self._optimize_fp32_graph(graph)
        graph = self._compute_weight_scales(graph)
        graph = self._update_relu_output_scales(graph)
        graph = self._propagate_scales(graph)
        graph = self._set_dummy_fc_out_scales(graph)
        graph = self._quantize_fp32_graph(graph)
        graph = self._remove_unused_var_nodes(graph)
        return graph

    def apply_fp32(self, graph):
        """
        Run only the FP32 part of the transformation (scale gathering, fake op
        removal, weight dequantization and FP32 optimization) on `graph`.

        Args:
            graph(IrGraph): the applied graph.

        Returns:
            The transformed graph.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'

        graph = self._gather_scales(graph)
        graph = self._remove_fake_ops(graph)
        graph = self._dequantize_weights(graph)
        graph = self._optimize_fp32_graph(graph)
        graph = self._remove_unused_var_nodes(graph)
        return graph

    def _convert_scale2tensor(self, scale):
        """Wrap `scale` in a new core.LoDTensor placed on the CPU."""
        tensor = core.LoDTensor()
        tensor.set(scale, core.CPUPlace())
        return tensor

    def _is_conv_quantized(self):
        """Return True if any of the conv op types is in the quantized ops."""
        return any(op_type in self._quantized_ops for op_type in self._conv_ops)

    def _is_fc_quantized(self):
        """Return True if 'fc' is in the quantized ops."""
        return 'fc' in self._quantized_ops

    def _gather_scales(self, graph):
        """
        Collect scales from the fake quantize and fake dequantize ops.

        For fake quantize ops, the reciprocal of the first "InScale" element is
        stored in self._var_quant_scales under the op's "X" input name and
        under the scale name with ".scale" stripped. For fake dequantize ops,
        the "max_range" attribute is stored in self._weight_scales.
        """
        for op in graph.all_op_nodes():
            op_type = op.name()
            if op_type in self._quantize_types:
                bit_length = op.op().attr("bit_length")
                assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format(
                    bit_length)

                input_name = op.input("X")[0]
                scale_name = op.input("InScale")[0]
                # Gather new weights scale after folding batchnorm in convolution
                scale = np.array(1.0 / self._load_param(
                    self._scope, scale_name)[0]).astype(np.float64)
                lod_tensor = self._convert_scale2tensor(scale)
                use_unsigned_int = False
                self._var_quant_scales[input_name] = (use_unsigned_int,
                                                      lod_tensor)
                self._var_quant_scales[scale_name.replace(".scale", "")] = (
                    use_unsigned_int, lod_tensor)

            if op_type in self._fake_dequantize_types:
                input_name = op.input("X")[0]
                self._weight_scales[input_name] = op.op().attr("max_range")

        return graph

    def _propagate_scales(self, graph):
        """
        Copy known scales across the ops listed in self._scale_immutable_ops.

        The graph is re-scanned until no tensor is left waiting for a scale or
        a scan makes no further progress.
        """

        def _update_scale_op_in_scale(op, in_name, out_name):
            # Derive the input scale of a 'scale' op from its output scale
            # and the op's "scale" attribute.
            unsigned, tensor = self._var_quant_scales[out_name]
            scale = np.array(tensor) * op.op().attr("scale")
            new_tensor = self._convert_scale2tensor(scale.astype(np.float64))
            self._var_quant_scales[in_name] = (unsigned, new_tensor)

        def _update_scales(graph):
            waiting_for_scale = set()
            for op in graph.all_op_nodes():
                if op.name() not in self._scale_immutable_ops:
                    continue
                input_name = op.input("X")[0]
                output_name = op.output("Out")[0]
                tensor_names = [input_name, output_name]

                # Scale is not quantized, so if it doesn't have any scales
                # to propagate, its tensors won't be added to the waiting list.
                if all(name not in self._var_quant_scales for name in tensor_names) \
                    and op.name() != 'scale':
                    waiting_for_scale.update(tensor_names)
                    continue

                if input_name in self._var_quant_scales:
                    self._var_quant_scales[
                        output_name] = self._var_quant_scales[input_name]
                elif output_name in self._var_quant_scales:
                    if op.name() == 'scale':
                        _update_scale_op_in_scale(op, input_name, output_name)
                    else:
                        self._var_quant_scales[
                            input_name] = self._var_quant_scales[output_name]
            return waiting_for_scale

        waiting_for_scale = _update_scales(graph)
        waiting_for_scale_prev = set()
        while waiting_for_scale and \
                waiting_for_scale != waiting_for_scale_prev:
            waiting_for_scale_prev = waiting_for_scale
            waiting_for_scale = _update_scales(graph)

        return graph

    def _set_dummy_fc_out_scales(self, graph):
        '''
        For the output tensors of FC that do not have an assigned scale,
        assign a dummy scale (same scale as input), so that the quantize pass
        won't fail. In the end these scales aren't used, since FCs that
        have an unassigned output scale will have a force_fp32_output attr
        set to True.
        '''
        for op in graph.all_op_nodes():
            if op.name() in self._fc_ops:
                input_name = op.input("Input")[0]
                output_name = op.output("Out")[0]
                if input_name in self._var_quant_scales and \
                    output_name not in self._var_quant_scales:
                    # use input scale as a "dummy" scale
                    self._var_quant_scales[
                        output_name] = self._var_quant_scales[input_name]

        return graph

    def _load_param(self, scope, param_name):
        """Return the tensor stored in `scope` under `param_name` as a numpy array."""
        return np.array(scope.find_var(param_name).get_tensor())

    def _remove_fake_ops(self, graph):
        '''
        When FC isn't quantized:
            Remove fake (de)quantize ops that do not surround mul.
        When FC is quantized:
            Remove all fake (de)quantize ops.
        '''
        is_fc_quantized = self._is_fc_quantized()
        for op in graph.all_op_nodes():
            if op.name() in self._fake_quantize_types:
                op_out = graph._find_node_by_name(op.outputs,
                                                  op.output("Out")[0])
                # Only the first consumer of the fake quantize output is
                # inspected.
                next_op = op_out.outputs[0]
                if next_op.name() not in self._mul_ops or is_fc_quantized:
                    self._remove_fake_quantize(graph, op)

        for op in graph.all_op_nodes():
            if op.name() in self._fake_dequantize_types:
                op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
                # Only the first producer of the fake dequantize input is
                # inspected.
                prev_op = op_in.inputs[0]
                if prev_op.name() not in self._mul_ops or is_fc_quantized:
                    self._remove_fake_dequantize(graph, op)

        return graph

    def _remove_fake_quantize(self, graph, op):
        """
        Remove a fake quantize op, its output and both of its scale variables,
        and reconnect its consumers directly to its "X" input.
        """
        fake_quant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
        fake_quant_in_scale = graph._find_node_by_name(op.inputs,
                                                       op.input("InScale")[0])
        fake_quant_out = graph._find_node_by_name(op.outputs,
                                                  op.output("Out")[0])
        fake_quant_out_scale = graph._find_node_by_name(
            op.outputs, op.output("OutScale")[0])

        for next_op in fake_quant_out.outputs:
            self._swap_inputs(next_op, fake_quant_out, fake_quant_in)
            graph.link_to(fake_quant_in, next_op)
        graph.safe_remove_nodes(
            {op, fake_quant_in_scale, fake_quant_out, fake_quant_out_scale})

        return graph

    def _remove_fake_dequantize(self, graph, op):
        """
        Remove a fake dequantize op and its output variable, and reconnect its
        consumers directly to its "X" input.
        """
        fake_dequant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
        fake_dequant_out = graph._find_node_by_name(op.outputs,
                                                    op.output("Out")[0])

        for next_op in fake_dequant_out.outputs:
            self._swap_inputs(next_op, fake_dequant_out, fake_dequant_in)
            graph.link_to(fake_dequant_in, next_op)
        graph.safe_remove_nodes({op, fake_dequant_out})

        return graph

    def _swap_inputs(self, op, old_input, new_input):
        """
        Replace every occurrence of `old_input`'s name with `new_input`'s name
        in all input slots of `op`.
        """
        old_name = old_input.name()
        for input_name in op.op().input_names():
            current_args = op.input(input_name)
            if old_name in current_args:
                op.op().set_input(input_name, [
                    new_input.name() if arg == old_name else arg
                    for arg in current_args
                ])

    def _dequantize_weights(self, graph):
        """
        Dequantize the weights of all conv ops and, when FC is quantized, of
        all mul ops.
        """
        for op in graph.all_op_nodes():
            if op.name() in self._conv_ops:
                self._dequantize_conv_weights(graph, op)
            elif self._is_fc_quantized() and op.name() in self._mul_ops:
                self._dequantize_mul_weights(graph, op)
        return graph

    def _dequantize_op_weights(self, weight_name, output_name):
        """
        Rescale the weights stored in the scope under `weight_name` in place.

        The stored values are multiplied by self._s8_max and divided by the
        scale gathered for `output_name`; the shape is preserved.

        Raises:
            KeyError: if no scale was gathered for `output_name`.
        """
        scales = self._weight_scales[output_name]
        weight = self._load_param(self._scope, weight_name)
        w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
        self._restore_var(weight_name, w_fp32.reshape(weight.shape))

    def _dequantize_conv_weights(self, graph, op_node):
        """Convert the "Filter" weights of a conv op from int8 range to fp32 range."""
        self._dequantize_op_weights(
            op_node.input("Filter")[0], op_node.output("Output")[0])

    def _dequantize_mul_weights(self, graph, op_node):
        """Convert the "Y" weights of a mul op from int8 range to fp32 range."""
        self._dequantize_op_weights(
            op_node.input("Y")[0], op_node.output("Out")[0])

    def _restore_var(self, name, array):
        """Overwrite the tensor stored in self._scope under `name` with `array`."""
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_ctrl_vars(self, graph):
        """Remove all control variable nodes from the graph."""
        remove_ctr_vars = {
            node
            for node in graph.all_var_nodes() if node.is_ctrl_var()
        }
        graph.safe_remove_nodes(remove_ctr_vars)
        return graph

    def _optimize_fp32_graph(self, graph):
        """
        Apply the FP32 optimization passes: MKL-DNN placement, then the conv
        fuses when conv is quantized and the FC fuses when FC is quantized.
        """
        graph = self._remove_ctrl_vars(graph)
        graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                 ['mkldnn_enabled_op_types'], [set()])
        if self._is_conv_quantized():
            graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
            graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
            graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
            graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
            graph = self._apply_pass(graph,
                                     'conv_elementwise_add_mkldnn_fuse_pass')
            graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
            graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
        if self._is_fc_quantized():
            graph = self._apply_pass(graph, 'fc_fuse_pass',
                                     ['use_gpu', 'use_fc_padding'],
                                     [False, False])
            graph = self._apply_pass(graph, 'fc_mkldnn_pass')
        return graph

    def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
        """
        Fetch the pass called `pass_name`, set the given attributes on it and
        apply it to the graph.

        Args:
            graph(IrGraph): the graph to transform.
            pass_name(str): name handed to `core.get_pass`.
            attrs(list|None): names of the pass attributes to set.
            attr_values(list|None): values matching `attrs` one to one.

        Returns:
            The transformed graph, with unused variable nodes removed.
        """
        ir_pass = core.get_pass(pass_name)
        cpp_graph = graph.graph
        if not cpp_graph.has('__param_scope__'):
            cpp_graph.set_not_owned('__param_scope__', self._scope)
        if attrs:
            assert attr_values and len(attrs) == len(
                attr_values
            ), "Different number of pass attributes and their values."
            for attr, value in zip(attrs, attr_values):
                ir_pass.set(attr, value)
        ir_pass.apply(cpp_graph)
        if self._debug:
            graph.draw('.', 'qat_fp32_{}'.format(pass_name),
                       graph.all_op_nodes())
        self._remove_unused_var_nodes(graph)
        return graph

    def _remove_unused_var_nodes(self, graph):
        """Remove every variable node that is neither an input nor an output of any op."""
        used_nodes = set()
        for op_node in graph.all_op_nodes():
            for var_node in op_node.inputs:
                used_nodes.add(var_node.node)
            for var_node in op_node.outputs:
                used_nodes.add(var_node.node)
        # Membership is tested on the wrapped `.node` objects, not on the
        # wrapper objects returned by the graph.
        unused_vars = {
            var_node
            for var_node in graph.all_var_nodes()
            if var_node.node not in used_nodes
        }
        graph.safe_remove_nodes(unused_vars)
        return graph

    @staticmethod
    def _abs_max_reciprocal_scales(weights, axis):
        """
        Return 1 / max(|weights|) computed along `axis` of the weights
        flattened to 2-D as (weights.shape[0], -1), as a float64 array.

        Slices whose maximum is 0 get a scale of 0.0 instead of infinity.
        """
        flat = np.abs(weights.reshape(weights.shape[0], -1)).astype(np.float64)
        # Division by zero is expected for all-zero slices and handled below.
        with np.errstate(divide='ignore'):
            scales = 1.0 / np.amax(flat, axis=axis)
        scales[scales == np.inf] = 0.0
        return scales

    def _compute_weight_scales(self, graph):
        """
        Compute scales for the weights of conv ops ("Filter" input, reduced
        along axis 1) and fc ops ("W" input, reduced along axis 0) and store
        them in self._var_quant_scales as signed scales.
        """

        def _compute_var_scales(ops, w_name, axis):
            for op in graph.all_op_nodes():
                if op.op().type() in ops:
                    weight_var_name = op.input(w_name)[0]
                    weights = np.array(
                        self._load_param(self._scope, weight_var_name))
                    scales = self._abs_max_reciprocal_scales(weights, axis)
                    lod_tensor = self._convert_scale2tensor(scales)
                    use_unsigned_int = False
                    self._var_quant_scales[weight_var_name] = (use_unsigned_int,
                                                               lod_tensor)

        _compute_var_scales(self._conv_ops, "Filter", axis=1)
        _compute_var_scales(self._fc_ops, "W", axis=0)
        return graph

    def _find_avg_pooling_ids(self, graph):
        """
        Return the set of ids of pool ops whose "pooling_type" is "avg",
        or {-1} when there is none.
        """
        ids = {
            op.id()
            for op in graph.all_op_nodes()
            if op.name() in self._pool_ops and
            op.op().attr("pooling_type") == "avg"
        }
        # NOTE(review): -1 presumably keeps the pass attribute from being an
        # empty set -- confirm against cpu_quantize_placement_pass.
        return ids if ids else set([-1])

    def _update_relu_output_scales(self, graph):
        """
        Mark the output scales of conv/fc ops with a fused relu activation as
        unsigned.
        """

        def _update_scale(graph, ops, op_out_name, predicate):
            '''
            Sets the type of an output scale of a passed op type(s) to 'unsigned int8' if the
            predicate applied on op passes. Typically, the predicate checks if op's
            activation is set to relu.
            '''
            for op in graph.all_op_nodes():
                if op.name() in ops:
                    out_name = op.output(op_out_name)[0]
                    if out_name in self._var_quant_scales and predicate(op.op(
                    )):
                        _, tensor = self._var_quant_scales[out_name]
                        self._var_quant_scales[out_name] = (True, tensor)
            return graph

        if self._is_conv_quantized():
            # The explicit comparison with False is kept on purpose: a missing
            # (None) attribute must not satisfy the predicate.
            conv_predicate = lambda op: op.attr("fuse_activation") == 'relu' and \
                op.attr("fuse_residual_connection") == False
            graph = _update_scale(graph, self._conv_ops, "Output",
                                  conv_predicate)

        if self._is_fc_quantized():
            fc_predicate = lambda op: op.attr("activation_type") == 'relu'
            graph = _update_scale(graph, self._fc_ops, "Out", fc_predicate)

        return graph

    def _get_data_layout(self):
        """Return 'NHWC' when conv is quantized, 'NCHW' otherwise."""
        return 'NHWC' if self._is_conv_quantized() else 'NCHW'

    def _quantize_fp32_graph(self, graph):
        """
        Quantize the optimized FP32 graph: apply the quantize placement pass
        (excluding avg pooling ops), then `cpu_quantize_pass` and
        `cpu_quantize_squash_pass`.
        """
        # `_core` is optional in the constructor; fall back to the module-level
        # `core` (the one _apply_pass uses) instead of failing on None.
        core_module = self._core if self._core is not None else core
        ir_pass = core_module.get_pass('cpu_quantize_placement_pass')
        cpp_graph = graph.graph
        ir_pass.set('quantize_enabled_op_types', self._quantized_ops)
        ir_pass.set('quantize_excluded_op_ids',
                    self._find_avg_pooling_ids(graph))
        ir_pass.apply(cpp_graph)
        if self._debug:
            graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
                       graph.all_op_nodes())
        graph = self._apply_pass(
            graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
            [self._var_quant_scales, self._get_data_layout()])
        graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
        return graph