QAT int8 MKL-DNN transformation pass (#17819)
parent 377f9e6142
commit 90ebce9ead
@ -0,0 +1,229 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from .... import core
from ....framework import IrGraph
from ....framework import IrNode

__all__ = ['TransformForMkldnnPass']

class TransformForMkldnnPass(object):
    """
    Convert an IrGraph generated by the QuantizationFreezePass into an
    MKL-DNN-supported INT8 IrGraph. The following transformations are done
    in this pass:

    1. Convert the int8-range weights with float32 data type, which are
       generated by the QuantizationFreezePass, back to float32-range weights
       with float32 data type, using the corresponding scales. This conversion
       is needed because the MKL-DNN INT8 conv2d kernel currently only accepts
       float32 weights as input and performs the weight quantization inside
       the kernel.
    2. Create a new conv2d op with the converted weights, link its output to
       the fake_dequantize_abs_max op's output, and set the conv2d attribute
       "force_fp32_output" to true.
    3. Transform the fake_quantize_xx ops into quantize ops.
    4. Remove the fake_dequantize_abs_max ops.
    """

    def __init__(self, scope=None, place=None):
        """
        Args:
            scope(fluid.Scope): scope is used to initialize the new parameters.
            place(fluid.CPUPlace): place is used to initialize the new
                parameters.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import TransformForMkldnnPass
            from paddle.fluid.framework import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
            place = fluid.CPUPlace()
            mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
            mkldnn_pass.apply(graph)
        """

        self._scope = scope
        self._place = place

        self.quantize_type = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max'
        ]
        self.dequantize_type = ['fake_dequantize_max_abs']

        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']

        self.InScale = {}
        self.max_range = {}
        self.conv_new_output = {}
        self.s8_max = 127
        # Temporary code for keeping the mul op as fake quantization.
        # TODO(Intel): Remove the following code once the mul int8 MKL-DNN
        # kernel is enabled.
        self.mul_input_id = []
        self.mul_output_id = []

    def apply(self, graph):
        """
        Quantize the graph for running MKL-DNN INT8 inference. According to
        the activation quantization type, the pass transforms the fake
        quantize ops into quantize ops and removes the fake dequantize ops.

        Args:
            graph(IrGraph): the graph the pass is applied to.
        """

        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()

        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        # Collect the InScales and max_range to calculate the new scales for
        # the MKL-DNN INT8 conv2d ops
        for op_node in ops:
            if op_node.name() in self.dequantize_type:
                input_name = op_node.input("X")[0]
                scale_name = op_node.input("Scale")[0]
                self.InScale[input_name] = self._load_param(self._scope,
                                                            scale_name)[0]
                self.max_range[input_name] = op_node.op().attr("max_range")
                self.conv_new_output[input_name] = op_node.output("Out")[0]
            # Temporary graph transform for keeping the mul op untouched
            # TODO(Intel): Remove the following code
            elif op_node.name() in ['mul']:
                input_node = graph._find_node_by_name(op_node.inputs,
                                                      op_node.input('X')[0])
                output_node = graph._find_node_by_name(op_node.outputs,
                                                       op_node.output('Out')[0])
                self.mul_input_id.append(input_node.id())
                self.mul_output_id.append(output_node.id())

        for op_node in ops:
            if op_node.name() in self._conv_ops:
                self._transform_to_conv_mkldnn(graph, op_node)
            elif op_node.name() in self.quantize_type:
                self._transform_to_quantize_mkldnn(graph, op_node)
            elif op_node.name() in self.dequantize_type:
                self._remove_fake_dequantize_op(graph, op_node)
        self._remove_unused_var_nodes(graph)
        return graph

    def _transform_to_conv_mkldnn(self, graph, op_node):
        weight_name = op_node.input("Filter")[0]
        output_name = op_node.output("Output")[0]
        # Convert int8 range weights to fp32 range weights
        weight = self._load_param(self._scope, weight_name)
        w_fp32 = np.divide(
            np.multiply(weight, 127), self.max_range[output_name])
        w_fp32 = w_fp32.reshape(weight.shape)
        self._restore_var(weight_name, w_fp32)
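        # A worked sketch of the conversion above (illustrative numbers, not
        # taken from the source): the frozen weights hold int8-range values
        # in an fp32 tensor. With s8_max = 127 and, say, max_range = 6350.0,
        # an entry of 635.0 becomes 635.0 * 127 / 6350.0 = 12.7, i.e. it is
        # mapped back into the fp32 weight range.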
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("Input")[0])
        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)

        # Set fake_dequantize_abs_max's output as the new output of conv2d
        output_var_node = graph._find_node_by_name(
            graph.all_var_nodes(), self.conv_new_output[output_name])
        attrs = {
            name: op_node.op().attr(name)
            for name in op_node.op().attr_names()
        }

        conv_op_node = graph.create_op_node(
            op_type='conv2d',
            attrs=attrs,
            inputs={'Input': input_var_node,
                    'Filter': weight_var_node},
            outputs={'Output': output_var_node})

        # Use the QAT scales to calculate the scales for MKL-DNN INT8 conv2d
        scale_in = self.s8_max / self.InScale[output_name]
        scale_w = []
        scale_w.append(self.max_range[output_name] / self.s8_max)

        conv_op_node.set_attr("Scale_weights", scale_w)
        conv_op_node.set_attr("Scale_in", scale_in)
        conv_op_node.set_attr("Scale_out", 1.0)
        conv_op_node.set_attr("use_mkldnn", 1)
        conv_op_node.set_attr("force_fp32_output", 1)
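        # How these attributes fit together (my reading of the pass, not a
        # claim from the source): Scale_in maps fp32 activations in
        # [-InScale, InScale] onto the signed int8 range [-127, 127],
        # Scale_weights is the factor the MKL-DNN kernel uses to quantize the
        # restored fp32-range weights internally, and Scale_out = 1.0 combined
        # with force_fp32_output keeps the conv2d output in fp32, matching
        # what the removed fake_dequantize_abs_max op used to produce.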
        graph.link_to(input_var_node, conv_op_node)
        graph.link_to(weight_var_node, conv_op_node)
        graph.link_to(conv_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _transform_to_quantize_mkldnn(self, graph, op_node):
        """
        Transform a fake_quantize_xx op into a quantize MKL-DNN op in the
        graph.
        """
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        output_var_node = graph._find_node_by_name(op_node.outputs,
                                                   op_node.output("Out")[0])
        if output_var_node.id() in self.mul_input_id:
            return
        else:
            scale_in = self.s8_max / self._load_param(
                self._scope, op_node.input("InScale")[0])[0]
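            # Illustrative numbers (mine, not from the source): if the stored
            # InScale tensor holds 2.54, then scale_in = 127 / 2.54 = 50.0,
            # i.e. the quantize op maps fp32 inputs from [-2.54, 2.54] onto
            # the signed int8 range [-127, 127].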
            quant_op_node = graph.create_op_node(
                op_type='quantize',
                attrs={
                    'data_format': 'MKLDNNLAYOUT',
                    'use_mkldnn': 1,
                    'Scale': scale_in,
                    'is_negative_input': 1
                },
                inputs={'Input': input_var_node},
                outputs={'Output': output_var_node})
            graph.link_to(input_var_node, quant_op_node)
            graph.link_to(quant_op_node, output_var_node)
            graph.safe_remove_nodes(op_node)

    def _remove_fake_dequantize_op(self, graph, op_node):
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        if input_var_node.id() in self.mul_output_id:
            return
        else:
            graph.safe_remove_nodes(op_node)

    def _load_param(self, scope, param_name):
        return np.array(scope.find_var(param_name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

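    # Added note (my reading): garbage-collect every variable node that no
    # remaining op reads or writes, i.e. nodes left dangling after the fake
    # quantize / dequantize ops were removed above.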
    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(lambda node: node.node not in all_used_vars,
                            graph.all_var_nodes())
        }
        graph.safe_remove_nodes(all_unused_vars)
@ -0,0 +1,193 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid import core

os.environ["CPU_NUM"] = "1"


def conv_net(img, label):
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        act="relu")
    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    return avg_loss


class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
    def setUp(self):
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            # The mul int8 op is under internal test
            # TODO: Update this when the mul op is merged
            #'mul': ['X', 'Y']
        }

    def check_program(self, program):
        for block in program.blocks:
            for op in block.ops:
                if op.type in self.quantizable_op_and_inputs:
                    for arg_name in op.output_arg_names:
                        # Check that the quantizable op's output is linked to
                        # fake_dequantize's output
                        self.assertTrue(arg_name.endswith('.dequantized'))

    def isinteger(self, x):
        return np.equal(np.mod(x, 1), 0)

    def build_program(self, main, startup, is_test, seed):
        main.random_seed = seed
        startup.random_seed = seed
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                img = fluid.layers.data(
                    name='image', shape=[1, 28, 28], dtype='float32')
                label = fluid.layers.data(
                    name='label', shape=[1], dtype='int64')
                loss = conv_net(img, label)
                if not is_test:
                    opt = fluid.optimizer.Adam(learning_rate=0.001)
                    opt.minimize(loss)
        return [img, label], loss

    def mkldnn_based_freeze_graph(self,
                                  use_cuda,
                                  seed,
                                  activation_quant_type,
                                  weight_quant_type='abs_max',
                                  for_ci=False):
        random.seed(0)
        np.random.seed(0)

        main = fluid.Program()
        startup = fluid.Program()
        test_program = fluid.Program()
        feeds, loss = self.build_program(main, startup, False, seed)
        self.build_program(test_program, startup, True, seed)
        test_program = test_program.clone(for_test=True)
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup)
        # Apply the QAT QuantizationTransformPass
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type)
        transform_pass.apply(main_graph)
        transform_pass.apply(test_graph)

        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        quantized_test_program = test_graph.to_program()
        iters = 5
        batch_size = 8

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)

        # Train the model to obtain the weight values
        with fluid.scope_guard(scope):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])

        # Freeze the graph for inference; the fc/conv weights are still of
        # float type at this point.
        freeze_pass = QuantizationFreezePass(
            scope=scope, place=place, weight_quantize_type=weight_quant_type)
        freeze_pass.apply(test_graph)

        # Transform the quantized graph for MKL-DNN INT8 inference
        mkldnn_int8_pass = TransformForMkldnnPass(scope=scope, place=place)
        mkldnn_int8_pass.apply(test_graph)

        dev_name = '_cpu_'
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_mkldnn' + dev_name +
                            activation_quant_type + '_' + weight_quant_type,
                            marked_nodes)
        mkldnn_program = test_graph.to_program()
        w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        # Check that the weights are no longer integer-valued, i.e. they
        # were converted back to the fp32 range
        self.assertFalse(self.isinteger(np.sum(w_mkldnn)))

        # Check that the conv2d output is correctly linked to
        # fake_dequantize's output
        self.check_program(mkldnn_program)
        if not for_ci:
            print('{}: {}'.format('w_mkldnn' + dev_name + activation_quant_type
                                  + '_' + weight_quant_type, np.sum(w_mkldnn)))

    def test_mkldnn_graph_cpu_static(self):
        with fluid.unique_name.guard():
            self.mkldnn_based_freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.mkldnn_based_freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)


if __name__ == '__main__':
    unittest.main()