@@ -27,9 +27,9 @@ class TransformForMkldnnPass(object):
     1. Convert int8 range weights with float32 data type, which are generated by
        the QuantizationFreezePass, to float32 range weights with float32 data type
        by using the corresponding scales. This conversion is because MKL-DNN INT8
-       conv2d kernel now only supports float32 weights input, will do weights
-       quantization inside the conv2d kernel.
-    2. Create the new conv2d op with the converted weights and link its output
+       conv2d and mul kernels currently support only float32 weights input, so
+       weights quantization happens inside the conv2d and mul INT8 kernels.
+    2. Create the new conv2d or mul op with the converted weights and link its output
        to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32
        _output" as true
     3. Transform fake_quantize_xx op to quantize op
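
A minimal numpy sketch of the step-1 weight conversion described above, outside
the patch itself. The s8_max constant (127) and the max_range attribute come
from the pass; the tensor and max_range values here are made up for illustration:

    import numpy as np

    s8_max = 127
    max_range = 508.0  # hypothetical "max_range" attr read from fake_dequantize_abs_max

    # QuantizationFreezePass stores weights as float32 tensors holding int8-range values.
    w_int8_range = np.array([-127.0, -64.0, 0.0, 64.0, 127.0], dtype=np.float32)

    # Same arithmetic as the pass: w_fp32 = weight * s8_max / max_range,
    # yielding float32-range weights the MKL-DNN kernel can re-quantize internally.
    w_fp32 = np.divide(np.multiply(w_int8_range, s8_max), max_range)
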
@@ -73,13 +73,8 @@ class TransformForMkldnnPass(object):
         self.InScale = {}
         self.max_range = {}
-        self.conv_new_output = {}
+        self.new_output = {}
         self.s8_max = 127
-        # Temporary code for keeping the mul op as fake quantization
-        #TODO Intel: Remove the following code when mul int8 mkldnn
-        # kernel enabled
-        self.mul_input_id = []
-        self.mul_output_id = []
 
     def apply(self, graph):
         """
@@ -97,7 +92,7 @@ class TransformForMkldnnPass(object):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         # Collect the InScales and max_range to calculate the new scales for MKL-DNN
-        # INT8 conv2d
+        # INT8 conv2d and mul
         for op_node in ops:
             if op_node.name() in self.dequantize_type:
                 input_name = op_node.input("X")[0]
@@ -105,20 +100,14 @@ class TransformForMkldnnPass(object):
                 self.InScale[input_name] = self._load_param(self._scope,
                                                             scale_name)[0]
                 self.max_range[input_name] = op_node.op().attr("max_range")
-                self.conv_new_output[input_name] = op_node.output("Out")[0]
-            # Temporary graph transform on keeping the mul op
-            # TODO Intel: Remove following code
-            elif op_node.name() in ['mul']:
-                input_node = graph._find_node_by_name(op_node.inputs,
-                                                      op_node.input('X')[0])
-                output_node = graph._find_node_by_name(op_node.outputs,
-                                                       op_node.output('Out')[0])
-                self.mul_input_id.append(input_node.id())
-                self.mul_output_id.append(output_node.id())
+                self.new_output[input_name] = op_node.output("Out")[0]
 
         for op_node in ops:
-            if op_node.name() in self._conv_ops:
-                self._transform_to_conv_mkldnn(graph, op_node)
+            if op_node.name() in self._quantizable_ops:
+                if op_node.name() in self._conv_ops:
+                    self._transform_to_conv_mkldnn(graph, op_node)
+                else:
+                    self._transform_to_mul_mkldnn(graph, op_node)
             elif op_node.name() in self.quantize_type:
                 self._transform_to_quantize_mkldnn(graph, op_node)
             elif op_node.name() in self.dequantize_type:
@@ -132,7 +121,7 @@ class TransformForMkldnnPass(object):
         # Convert int8 range weights to fp32 range weights
         weight = self._load_param(self._scope, weight_name)
         w_fp32 = np.divide(
-            np.multiply(weight, 127), self.max_range[output_name])
+            np.multiply(weight, self.s8_max), self.max_range[output_name])
         w_fp32 = w_fp32.reshape(weight.shape)
         self._restore_var(weight_name, w_fp32)
         input_var_node = graph._find_node_by_name(op_node.inputs,
@@ -140,8 +129,8 @@ class TransformForMkldnnPass(object):
         weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
 
         # Set fake_dequantize_abs_max's output as new output of conv2d
-        output_var_node = graph._find_node_by_name(
-            graph.all_var_nodes(), self.conv_new_output[output_name])
+        output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
+                                                   self.new_output[output_name])
         attrs = {
             name: op_node.op().attr(name)
             for name in op_node.op().attr_names()
@@ -157,7 +146,7 @@ class TransformForMkldnnPass(object):
         # Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
         scale_in = self.s8_max / self.InScale[output_name]
-        scale_w = []
-        scale_w.append(self.max_range[output_name] / self.s8_max)
+        scale_w = [self.max_range[output_name] / self.s8_max]
 
         conv_op_node.set_attr("Scale_weights", scale_w)
         conv_op_node.set_attr("Scale_in", scale_in)
@@ -169,6 +158,49 @@ class TransformForMkldnnPass(object):
         graph.link_to(conv_op_node, output_var_node)
         graph.safe_remove_nodes(op_node)
 
+    def _transform_to_mul_mkldnn(self, graph, op_node):
+        # For MKL-DNN INT8 mul, input Y should be the weights
+        weight_name = op_node.input("Y")[0]
+        output_name = op_node.output("Out")[0]
+        # Convert int8 range weights to fp32 range weights
+        weight = self._load_param(self._scope, weight_name)
+        w_fp32 = np.divide(
+            np.multiply(weight, self.s8_max), self.max_range[output_name])
+        w_fp32 = w_fp32.reshape(weight.shape)
+        self._restore_var(weight_name, w_fp32)
+        input_var_node = graph._find_node_by_name(op_node.inputs,
+                                                  op_node.input("X")[0])
+        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
+
+        # Set fake_dequantize_abs_max's output as new output of mul
+        output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
+                                                   self.new_output[output_name])
+        attrs = {
+            name: op_node.op().attr(name)
+            for name in op_node.op().attr_names()
+        }
+
+        mul_op_node = graph.create_op_node(
+            op_type='mul',
+            attrs=attrs,
+            inputs={'X': input_var_node,
+                    'Y': weight_var_node},
+            outputs={'Out': output_var_node})
+
+        # Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
+        scale_in = self.s8_max / self.InScale[output_name]
+        scale_w = [self.max_range[output_name] / self.s8_max]
+
+        mul_op_node.set_attr("scale_y", scale_w)
+        mul_op_node.set_attr("scale_x", scale_in)
+        mul_op_node.set_attr("scale_out", 1.0)
+        mul_op_node.set_attr("use_mkldnn", 1)
+        mul_op_node.set_attr("force_fp32_output", 1)
+        graph.link_to(input_var_node, mul_op_node)
+        graph.link_to(weight_var_node, mul_op_node)
+        graph.link_to(mul_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)
+
     def _transform_to_quantize_mkldnn(self, graph, op_node):
         """
         Transform fake_quantize_xx op to quantize mkldnn op in the graph.
@@ -177,32 +210,26 @@ class TransformForMkldnnPass(object):
                                                   op_node.input("X")[0])
         output_var_node = graph._find_node_by_name(op_node.outputs,
                                                    op_node.output("Out")[0])
-        if output_var_node.id() in self.mul_input_id:
-            return
-        else:
-            scale_in = self.s8_max / self._load_param(
-                self._scope, op_node.input("InScale")[0])[0]
-            quant_op_node = graph.create_op_node(
-                op_type='quantize',
-                attrs={
-                    'data_format': 'MKLDNNLAYOUT',
-                    'use_mkldnn': 1,
-                    'Scale': scale_in,
-                    'is_negative_input': 1
-                },
-                inputs={'Input': input_var_node},
-                outputs={'Output': output_var_node})
-            graph.link_to(input_var_node, quant_op_node)
-            graph.link_to(quant_op_node, output_var_node)
-            graph.safe_remove_nodes(op_node)
+        scale_in = self.s8_max / self._load_param(
+            self._scope, op_node.input("InScale")[0])[0]
+        quant_op_node = graph.create_op_node(
+            op_type='quantize',
+            attrs={
+                'data_format': 'MKLDNNLAYOUT',
+                'use_mkldnn': 1,
+                'Scale': scale_in,
+                'is_negative_input': 1
+            },
+            inputs={'Input': input_var_node},
+            outputs={'Output': output_var_node})
+        graph.link_to(input_var_node, quant_op_node)
+        graph.link_to(quant_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)
 
     def _remove_fake_dequantize_op(self, graph, op_node):
         input_var_node = graph._find_node_by_name(op_node.inputs,
                                                   op_node.input("X")[0])
-        if input_var_node.id() in self.mul_output_id:
-            return
-        else:
-            graph.safe_remove_nodes(op_node)
+        graph.safe_remove_nodes(op_node)
 
     def _load_param(self, scope, param_name):
         return np.array(scope.find_var(param_name).get_tensor())
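
Finally, a hedged usage sketch of driving this pass over a QAT-frozen inference
graph. The (scope, place) constructor signature and the IrGraph plumbing are
assumptions based on the surrounding slim quantization passes, not part of this
patch; freezed_program stands in for a program already processed by
QuantizationFreezePass:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph

    place = fluid.CPUPlace()
    scope = fluid.global_scope()
    graph = IrGraph(core.Graph(freezed_program.desc), for_test=True)

    mkldnn_pass = TransformForMkldnnPass(scope=scope, place=place)
    mkldnn_pass.apply(graph)
    mkldnn_program = graph.to_program()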