QAT int8 MKL-DNN transformation pass (#17819)
parent 377f9e6142
commit 90ebce9ead
@@ -0,0 +1,229 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from .... import core
from ....framework import IrGraph
from ....framework import IrNode

__all__ = ['TransformForMkldnnPass']


class TransformForMkldnnPass(object):
    """
    Convert a QuantizationFreezePass-generated IrGraph into an MKL-DNN
    supported INT8 IrGraph. The following transformations are performed
    by this pass:

    1. Convert the int8-range weights with float32 data type, which are
       generated by the QuantizationFreezePass, to float32-range weights
       with float32 data type by using the corresponding scales. This
       conversion is needed because the MKL-DNN INT8 conv2d kernel currently
       only accepts float32 weights as input and performs the weight
       quantization inside the kernel.
    2. Create a new conv2d op with the converted weights, link its output
       to the fake_dequantize_max_abs op's output, and set the conv2d
       attribute "force_fp32_output" to true.
    3. Transform each fake_quantize_xx op into a quantize op.
    4. Remove the fake_dequantize_max_abs ops.
    """

    def __init__(self, scope=None, place=None):
        """
        Args:
            scope(fluid.Scope): scope is used to initialize the new parameters.
            place(fluid.CPUPlace): place is used to initialize the new
                parameters.

        Examples:
        .. code-block:: python
            # The original graph will be rewritten.
            import paddle.fluid as fluid
            from paddle.fluid.contrib.slim.quantization \
                import TransformForMkldnnPass
            from paddle.fluid.framework import IrGraph
            from paddle.fluid import core

            graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
            place = fluid.CPUPlace()
            mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(),
                                                 place)
            mkldnn_pass.apply(graph)
        """

        self._scope = scope
        self._place = place

        self.quantize_type = [
            'fake_quantize_moving_average_abs_max',
            'fake_quantize_range_abs_max'
        ]
        self.dequantize_type = ['fake_dequantize_max_abs']

        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        self._conv_ops = ['conv2d', 'depthwise_conv2d']

        self.InScale = {}
        self.max_range = {}
        self.conv_new_output = {}
        self.s8_max = 127
        # Temporary code for keeping the mul op as fake quantization.
        # TODO(Intel): Remove the following code when the mul INT8 MKL-DNN
        # kernel is enabled.
        self.mul_input_id = []
        self.mul_output_id = []
    def apply(self, graph):
        """
        Quantize the graph for running MKL-DNN INT8 inference. According
        to the activation quantization type, the pass transforms the fake
        quantize ops into quantize ops and removes the fake dequantize ops.

        Args:
            graph(IrGraph): the applied graph.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        ops = graph.all_op_nodes()

        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        # Collect the InScales and max_ranges needed to calculate the new
        # scales for the MKL-DNN INT8 conv2d.
        for op_node in ops:
            if op_node.name() in self.dequantize_type:
                input_name = op_node.input("X")[0]
                scale_name = op_node.input("Scale")[0]
                self.InScale[input_name] = self._load_param(self._scope,
                                                            scale_name)[0]
                self.max_range[input_name] = op_node.op().attr("max_range")
                self.conv_new_output[input_name] = op_node.output("Out")[0]
            # Temporary graph transform for keeping the mul op.
            # TODO(Intel): Remove the following code.
            elif op_node.name() in ['mul']:
                input_node = graph._find_node_by_name(op_node.inputs,
                                                      op_node.input('X')[0])
                output_node = graph._find_node_by_name(op_node.outputs,
                                                       op_node.output('Out')[0])
                self.mul_input_id.append(input_node.id())
                self.mul_output_id.append(output_node.id())

        for op_node in ops:
            if op_node.name() in self._conv_ops:
                self._transform_to_conv_mkldnn(graph, op_node)
            elif op_node.name() in self.quantize_type:
                self._transform_to_quantize_mkldnn(graph, op_node)
            elif op_node.name() in self.dequantize_type:
                self._remove_fake_dequantize_op(graph, op_node)
        self._remove_unused_var_nodes(graph)
        return graph

    def _transform_to_conv_mkldnn(self, graph, op_node):
        weight_name = op_node.input("Filter")[0]
        output_name = op_node.output("Output")[0]
        # Convert int8-range weights back to fp32-range weights.
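        # Assuming abs_max weight quantization, the freeze pass stored
        # w_int8 = round(w_fp32 * s8_max / weight_scale) and recorded
        # max_range = s8_max**2 / weight_scale, so multiplying by
        # s8_max / max_range (== weight_scale / s8_max) restores w_fp32.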
        weight = self._load_param(self._scope, weight_name)
        w_fp32 = np.divide(
            np.multiply(weight, self.s8_max), self.max_range[output_name])
        w_fp32 = w_fp32.reshape(weight.shape)
        self._restore_var(weight_name, w_fp32)
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("Input")[0])
        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)

        # Set the fake_dequantize_max_abs op's output as the new output of
        # the conv2d.
        output_var_node = graph._find_node_by_name(
            graph.all_var_nodes(), self.conv_new_output[output_name])
        attrs = {
            name: op_node.op().attr(name)
            for name in op_node.op().attr_names()
        }

        conv_op_node = graph.create_op_node(
            op_type='conv2d',
            attrs=attrs,
            inputs={'Input': input_var_node,
                    'Filter': weight_var_node},
            outputs={'Output': output_var_node})

        # Use the QAT scales to calculate the scales for the MKL-DNN INT8
        # conv2d.
        scale_in = self.s8_max / self.InScale[output_name]
        scale_w = []
        scale_w.append(self.max_range[output_name] / self.s8_max)
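        # Assuming max_range = s8_max**2 / weight_scale as sketched above,
        # scale_w reduces to s8_max / weight_scale, mirroring
        # scale_in = s8_max / InScale. For example, with weight_scale = 2.0
        # (max_range = 127**2 / 2.0) and InScale = 6.35:
        #   scale_in = 127 / 6.35 = 20.0
        #   scale_w  = (127**2 / 2.0) / 127 = 63.5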

        conv_op_node.set_attr("Scale_weights", scale_w)
        conv_op_node.set_attr("Scale_in", scale_in)
        conv_op_node.set_attr("Scale_out", 1.0)
        conv_op_node.set_attr("use_mkldnn", 1)
        conv_op_node.set_attr("force_fp32_output", 1)
        graph.link_to(input_var_node, conv_op_node)
        graph.link_to(weight_var_node, conv_op_node)
        graph.link_to(conv_op_node, output_var_node)
        graph.safe_remove_nodes(op_node)

    def _transform_to_quantize_mkldnn(self, graph, op_node):
        """
        Transform a fake_quantize_xx op into a quantize MKL-DNN op in the
        graph.
        """
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        output_var_node = graph._find_node_by_name(op_node.outputs,
                                                   op_node.output("Out")[0])
        if output_var_node.id() in self.mul_input_id:
            return
        else:
            scale_in = self.s8_max / self._load_param(
                self._scope, op_node.input("InScale")[0])[0]
            quant_op_node = graph.create_op_node(
                op_type='quantize',
                attrs={
                    'data_format': 'MKLDNNLAYOUT',
                    'use_mkldnn': 1,
                    'Scale': scale_in,
                    'is_negative_input': 1
                },
                inputs={'Input': input_var_node},
                outputs={'Output': output_var_node})
            graph.link_to(input_var_node, quant_op_node)
            graph.link_to(quant_op_node, output_var_node)
            graph.safe_remove_nodes(op_node)

    def _remove_fake_dequantize_op(self, graph, op_node):
        input_var_node = graph._find_node_by_name(op_node.inputs,
                                                  op_node.input("X")[0])
        if input_var_node.id() in self.mul_output_id:
            return
        else:
            graph.safe_remove_nodes(op_node)

    def _load_param(self, scope, param_name):
        return np.array(scope.find_var(param_name).get_tensor())

    def _restore_var(self, name, array):
        tensor = self._scope.find_var(name).get_tensor()
        tensor.set(array, self._place)

    def _remove_unused_var_nodes(self, graph):
        all_used_vars = set()
        ops = graph.all_op_nodes()
        for op_node in ops:
            for input_node in op_node.inputs:
                all_used_vars.add(input_node)
            for output_node in op_node.outputs:
                all_used_vars.add(output_node)

        all_used_vars = {n.node for n in all_used_vars}
        all_unused_vars = {
            n
            for n in filter(lambda node: node.node not in all_used_vars,
                            graph.all_var_nodes())
        }
        graph.safe_remove_nodes(all_unused_vars)
@@ -0,0 +1,193 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import random
import numpy as np
import paddle.fluid as fluid
import six
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid import core

os.environ["CPU_NUM"] = "1"


def conv_net(img, label):
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        act="relu")
    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    return avg_loss


class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
    def setUp(self):
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            # The mul INT8 op is under internal testing.
            # TODO(Intel): Update this when the mul op is merged.
            #'mul': ['X', 'Y']
        }

    def check_program(self, program):
        for block in program.blocks:
            for op in block.ops:
                if op.type in self.quantizable_op_and_inputs:
                    for arg_name in op.output_arg_names:
                        # Check that the quantizable op's output is linked to
                        # a fake_dequantize op's output.
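                        # The MKL-DNN pass relinks the conv2d output to the
                        # old fake_dequantize output variable, whose name the
                        # freeze pass is expected to suffix with
                        # '.dequantized'.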
                        self.assertTrue(arg_name.endswith('.dequantized'))

    def isinteger(self, x):
        return np.equal(np.mod(x, 1), 0)

    def build_program(self, main, startup, is_test, seed):
        main.random_seed = seed
        startup.random_seed = seed
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                img = fluid.layers.data(
                    name='image', shape=[1, 28, 28], dtype='float32')
                label = fluid.layers.data(
                    name='label', shape=[1], dtype='int64')
                loss = conv_net(img, label)
                if not is_test:
                    opt = fluid.optimizer.Adam(learning_rate=0.001)
                    opt.minimize(loss)
        return [img, label], loss

    def mkldnn_based_freeze_graph(self,
                                  use_cuda,
                                  seed,
                                  activation_quant_type,
                                  weight_quant_type='abs_max',
                                  for_ci=False):
        random.seed(0)
        np.random.seed(0)

        main = fluid.Program()
        startup = fluid.Program()
        test_program = fluid.Program()
        feeds, loss = self.build_program(main, startup, False, seed)
        self.build_program(test_program, startup, True, seed)
        test_program = test_program.clone(for_test=True)
        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup)
        # Apply the QAT QuantizationTransformPass.
        transform_pass = QuantizationTransformPass(
            scope=scope,
            place=place,
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type)
        transform_pass.apply(main_graph)
        transform_pass.apply(test_graph)

        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        quantized_test_program = test_graph.to_program()
        iters = 5
        batch_size = 8

        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=500),
            batch_size=batch_size)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)

        # Train the model for a few iterations to obtain meaningful weight
        # values.
        with fluid.scope_guard(scope):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])

        # Freeze the graph for inference; the fc/conv weights are still
        # stored with float type.
        freeze_pass = QuantizationFreezePass(
            scope=scope, place=place, weight_quantize_type=weight_quant_type)
        freeze_pass.apply(test_graph)

        # Transform the quantized graph for MKL-DNN INT8 inference.
        mkldnn_int8_pass = TransformForMkldnnPass(scope=scope, place=place)
        mkldnn_int8_pass.apply(test_graph)
        dev_name = '_cpu_'
        if not for_ci:
            marked_nodes = set()
            for op in test_graph.all_op_nodes():
                if op.name().find('quantize') > -1:
                    marked_nodes.add(op)
            test_graph.draw('.', 'test_mkldnn' + dev_name +
                            activation_quant_type + '_' + weight_quant_type,
                            marked_nodes)
        mkldnn_program = test_graph.to_program()
        w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
        # Check that the weights are no longer integer-valued: the MKL-DNN
        # pass rescales the int8-range weights back to float32 range, so
        # their sum should no longer be an integer.
        self.assertFalse(self.isinteger(np.sum(w_mkldnn)))

        # Check that the conv2d output is correctly linked to the
        # fake_dequantize op's output.
        self.check_program(mkldnn_program)
        if not for_ci:
            print('{}: {}'.format('w_mkldnn' + dev_name + activation_quant_type
                                  + '_' + weight_quant_type, np.sum(w_mkldnn)))

    def test_mkldnn_graph_cpu_static(self):
        with fluid.unique_name.guard():
            self.mkldnn_based_freeze_graph(
                False,
                seed=2,
                activation_quant_type='range_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)
            self.mkldnn_based_freeze_graph(
                False,
                seed=2,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='abs_max',
                for_ci=True)


if __name__ == '__main__':
    unittest.main()