# Paddle/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import numpy as np
from .... import Executor
from .... import io
from .... import core, scope_guard
from ....compiler import CompiledProgram
from ....compiler import BuildStrategy
from ....framework import IrGraph, Variable, Program
from ....log_helper import get_logger
from ..core.strategy import Strategy
from .quantization_pass import *

__all__ = ['QuantizationStrategy']

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class QuantizationStrategy(Strategy):
"""
The strategy for Quantization.
"""

    def __init__(self,
                 start_epoch=0,
                 end_epoch=0,
                 float_model_save_path=None,
                 mobile_model_save_path=None,
                 int8_model_save_path=None,
                 activation_bits=8,
                 weight_bits=8,
                 activation_quantize_type='abs_max',
                 weight_quantize_type='abs_max',
                 save_in_nodes=None,
                 save_out_nodes=None):
"""
Args:
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
float_model_save_path(str): The path to save model with float weights.
None means it doesn't save float model. default: None.
mobile_model_save_path(str): The path to save model for paddle-mobile execution.
None means it doesn't save mobile model. default: None.
int8_model_save_path(str): The path to save model with int8_t weight.
None means it doesn't save int8 model. default: None.
activation_bits(int): quantization bit number for activation. default: 8.
weight_bits(int): quantization bit number for weights. The bias is not quantized.
default: 8.
activation_quantize_type(str): quantization type for activation,
now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
If use 'abs_max' mode, the quantization scale will be calculated
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained.
save_in_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
save_out_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
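
        Examples:
            A minimal construction sketch; the save paths below are hypothetical
            placeholders:

            .. code-block:: python

                strategy = QuantizationStrategy(
                    start_epoch=0,
                    end_epoch=10,
                    float_model_save_path='./output/float',
                    int8_model_save_path='./output/int8',
                    activation_quantize_type='moving_average_abs_max',
                    weight_quantize_type='abs_max')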
"""
        super(QuantizationStrategy, self).__init__(start_epoch, end_epoch)
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.float_model_save_path = float_model_save_path
        self.mobile_model_save_path = mobile_model_save_path
        self.int8_model_save_path = int8_model_save_path
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.activation_quantize_type = activation_quantize_type
        self.weight_quantize_type = weight_quantize_type
        self.save_out_nodes = save_out_nodes
        self.save_in_nodes = save_in_nodes

    def restore_from_checkpoint(self, context):
        """
        Restore the graph when the compression task is initialized from a checkpoint.
        """
        # The task was initialized from a checkpoint and the start epoch has
        # already passed, so the quantization ops must be re-inserted.
        if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
            _logger.info("Restore quantization task from checkpoint")
            self._modify_graph_for_quantization(context)
            _logger.info("Finish restoring quantization task from checkpoint")

    def _modify_graph_for_quantization(self, context):
        """
        Insert fake_quantize_op and fake_dequantize_op before training and testing.
        """
        train_ir_graph = IrGraph(
            core.Graph(context.optimize_graph.program.clone().desc),
            for_test=False)
        test_ir_graph = IrGraph(
            core.Graph(context.eval_graph.program.clone().desc), for_test=True)
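        # QuantizationTransformPass inserts fake_quantize/fake_dequantize ops,
        # configured with the bit widths and quantize types chosen in __init__.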
        transform_pass = QuantizationTransformPass(
            scope=context.scope,
            place=context.place,
            weight_bits=self.weight_bits,
            activation_bits=self.activation_bits,
            activation_quantize_type=self.activation_quantize_type,
            weight_quantize_type=self.weight_quantize_type)
        transform_pass.apply(train_ir_graph)
        transform_pass.apply(test_ir_graph)
        # Put persistables created by transform_pass into
        # context.optimize_graph.persistables for saving the checkpoint.
        program_persistables = set()
        for var in context.optimize_graph.program.list_vars():
            if var.persistable:
                program_persistables.add(var.name)
        program = Program()
        for var_node in train_ir_graph.all_persistable_nodes():
            if var_node.name() not in program_persistables:
                var_desc = var_node.var()
                var = program.global_block().create_var(
                    name=var_node.name(),
                    shape=var_desc.shape(),
                    dtype=var_desc.dtype(),
                    type=var_desc.type(),
                    lod_level=var_desc.lod_level())
                context.optimize_graph.persistables[var.name] = var
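        # Disable inplace reuse, memory optimization and all-reduce fusion,
        # which could otherwise rewrite the freshly quantized training graph.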
        build_strategy = BuildStrategy()
        build_strategy.enable_inplace = False
        build_strategy.memory_optimize = False
        build_strategy.fuse_all_reduce_ops = False
        # for quantization training
        context.optimize_graph.compiled_graph = CompiledProgram(
            train_ir_graph.graph).with_data_parallel(
                loss_name=context.optimize_graph.out_nodes['loss'],
                build_strategy=build_strategy)
        context.eval_graph.program = test_ir_graph.to_program()
        # for saving the inference model after training
        context.put('quantization_test_ir_graph_backup', test_ir_graph)

    def on_epoch_begin(self, context):
        """
        Insert fake_quantize_op and fake_dequantize_op before training and testing.
        """
        super(QuantizationStrategy, self).on_epoch_begin(context)
        if self.start_epoch == context.epoch_id:
            _logger.info('QuantizationStrategy::on_epoch_begin')
            self._modify_graph_for_quantization(context)
            _logger.info('Finish QuantizationStrategy::on_epoch_begin')

    def on_epoch_end(self, context):
        """
        Freeze the graph and save the inference models.
        """
        super(QuantizationStrategy, self).on_epoch_end(context)
        if context.epoch_id == self.end_epoch:
            _logger.info('QuantizationStrategy::on_epoch_end')
            test_ir_graph = context.get('quantization_test_ir_graph_backup')
            # freeze the graph after training
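            # QuantizationFreezePass rewrites the fake quant/dequant ops into
            # their inference form and quantizes the weights to the int8 range
            # (still stored as float until ConvertToInt8Pass runs).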
            freeze_pass = QuantizationFreezePass(
                scope=context.scope,
                place=context.place,
                weight_bits=self.weight_bits,
                activation_bits=self.activation_bits,
                weight_quantize_type=self.weight_quantize_type)
            freeze_pass.apply(test_ir_graph)
            # for other strategies
            context.eval_graph.program = test_ir_graph.to_program()
            if self.save_out_nodes is None:
                out_vars = [
                    context.eval_graph.var(var_name)._var
                    for var_name in context.eval_graph.out_nodes.values()
                ]
            else:
                out_vars = [
                    context.eval_graph.var(var_name)._var
                    for var_name in self.save_out_nodes
                ]
            if self.save_in_nodes is None:
                in_vars = list(context.eval_graph.in_nodes.values())
            else:
                in_vars = self.save_in_nodes
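            # save_inference_model prunes the graph to the subgraph between
            # in_vars and out_vars before serializing it.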
            # save float model
            if self.float_model_save_path:
                executor = Executor(context.place)
                with scope_guard(context.scope):
                    io.save_inference_model(
                        self.float_model_save_path,
                        in_vars,
                        out_vars,
                        executor,
                        main_program=test_ir_graph.to_program(),
                        model_filename='model',
                        params_filename='weights',
                        export_for_deployment=True)
            # save int8 model
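            # ConvertToInt8Pass casts the frozen weight tensors in the scope to
            # int8_t storage, shrinking the saved parameter files.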
            if self.int8_model_save_path:
                convert_int8_pass = ConvertToInt8Pass(
                    scope=context.scope, place=context.place)
                convert_int8_pass.apply(test_ir_graph)
                executor = Executor(context.place)
                with scope_guard(context.scope):
                    io.save_inference_model(
                        self.int8_model_save_path,
                        in_vars,
                        out_vars,
                        executor,
                        main_program=test_ir_graph.to_program(),
                        model_filename='model',
                        params_filename='weights',
                        export_for_deployment=True)
            # save mobile model
            if self.mobile_model_save_path:
                if not self.int8_model_save_path:
                    # convert the weights to int8_t type
                    convert_int8_pass = ConvertToInt8Pass(
                        scope=context.scope, place=context.place)
                    convert_int8_pass.apply(test_ir_graph)
                # adapt the graph for paddle-mobile inference
                mobile_pass = TransformForMobilePass()
                mobile_pass.apply(test_ir_graph)
                executor = Executor(context.place)
                with scope_guard(context.scope):
                    io.save_inference_model(
                        self.mobile_model_save_path,
                        in_vars,
                        out_vars,
                        executor,
                        main_program=test_ir_graph.to_program(),
                        model_filename='model',
                        params_filename='weights',
                        export_for_deployment=True)
            _logger.info('Finish QuantizationStrategy::on_epoch_end')
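

# ---------------------------------------------------------------------------
# Illustrative lifecycle sketch (an assumption, not part of the original
# module): the strategy is designed to be driven by the slim compression
# framework, which supplies `context` and invokes the hooks around each epoch:
#
#     strategy.on_epoch_begin(context)  # inserts fake quant ops at start_epoch
#     # ... run one epoch of training ...
#     strategy.on_epoch_end(context)    # freezes and saves models at end_epoch
# ---------------------------------------------------------------------------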