@@ -20,7 +20,7 @@ from .... import io
 from .... import core
 from ....compiler import CompiledProgram
 from ....compiler import BuildStrategy
-from ....framework import IrGraph
+from ....framework import IrGraph, Variable, Program
 from ..core.strategy import Strategy
 from .quantization_pass import *
@@ -88,41 +88,76 @@ class QuantizationStrategy(Strategy):
         self.save_out_nodes = save_out_nodes
         self.save_in_nodes = save_in_nodes
 
+    def on_compression_begin(self, context):
+        """
+        Restore the graph when the compression task is initialized from a checkpoint.
+        """
+        # The task is initialized from a checkpoint and has missed the start epoch.
+        if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
+            _logger.info("Restore quantization task from checkpoint")
+            self._modify_graph_for_quantization(context)
+            _logger.info("Finish restoring quantization task from checkpoint")
+
+    def _modify_graph_for_quantization(self, context):
+        """
+        Insert fake_quantize_op and fake_dequantize_op before training and testing.
+        """
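+        # Build the IR graphs from clones of the programs held by the context.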
+        train_ir_graph = IrGraph(
+            core.Graph(context.optimize_graph.program.clone().desc),
+            for_test=False)
+        test_ir_graph = IrGraph(
+            core.Graph(context.eval_graph.program.clone().desc), for_test=True)
+        transform_pass = QuantizationTransformPass(
+            scope=context.scope,
+            place=context.place,
+            weight_bits=self.weight_bits,
+            activation_bits=self.activation_bits,
+            activation_quantize_type=self.activation_quantize_type,
+            weight_quantize_type=self.weight_quantize_type)
+        transform_pass.apply(train_ir_graph)
+        transform_pass.apply(test_ir_graph)
+        # Put persistables created by transform_pass into context.optimize_graph.persistables
+        # for saving checkpoint.
+        program_persistables = set()
+        for var in context.optimize_graph.program.list_vars():
+            if var.persistable:
+                program_persistables.add(var.name)
+
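+        # A scratch Program is used only to create Variable objects describing
+        # the persistable nodes added by the transform pass.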
+        program = Program()
+        for var_node in train_ir_graph.all_persistable_nodes():
+            if var_node.name() not in program_persistables:
+                var_desc = var_node.var()
+                var = program.global_block().create_var(
+                    name=var_node.name(),
+                    shape=var_desc.shape(),
+                    dtype=var_desc.dtype(),
+                    type=var_desc.type(),
+                    lod_level=var_desc.lod_level())
+                context.optimize_graph.persistables[var.name] = var
+
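+        # Recompile both graphs with inplace execution and memory optimization
+        # disabled.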
+        build_strategy = BuildStrategy()
+        build_strategy.enable_inplace = False
+        build_strategy.memory_optimize = False
+        # for quantization training
+        context.optimize_graph.compiled_graph = CompiledProgram(
+            train_ir_graph.graph).with_data_parallel(
+                loss_name=context.optimize_graph.out_nodes['loss'],
+                build_strategy=build_strategy)
+        # for evaluation. The program compiled from an IR graph must use data parallel.
+        context.eval_graph.compiled_graph = CompiledProgram(
+            test_ir_graph.graph).with_data_parallel(
+                build_strategy=build_strategy)
+        # for saving the inference model after training
+        context.put('quantization_test_ir_graph_backup', test_ir_graph)
+
     def on_epoch_begin(self, context):
         """
         Insert fake_quantize_op and fake_dequantize_op before trainging and testing.
         """
-        super(QuantizationStrategy, self).on_compression_begin(context)
+        super(QuantizationStrategy, self).on_epoch_begin(context)
         if self.start_epoch == context.epoch_id:
             _logger.info('QuantizationStrategy::on_epoch_begin')
-            train_ir_graph = IrGraph(
-                core.Graph(context.optimize_graph.program.desc), for_test=False)
-            test_ir_graph = IrGraph(
-                core.Graph(context.eval_graph.program.desc), for_test=True)
-            transform_pass = QuantizationTransformPass(
-                scope=context.scope,
-                place=context.place,
-                weight_bits=self.weight_bits,
-                activation_bits=self.activation_bits,
-                activation_quantize_type=self.activation_quantize_type,
-                weight_quantize_type=self.weight_quantize_type)
-            transform_pass.apply(train_ir_graph)
-            transform_pass.apply(test_ir_graph)
-
-            build_strategy = BuildStrategy()
-            build_strategy.enable_inplace = False
-            build_strategy.memory_optimize = False
-            # for quantization training
-            context.optimize_graph.compiled_graph = CompiledProgram(
-                train_ir_graph.graph).with_data_parallel(
-                    loss_name=context.optimize_graph.out_nodes['loss'],
-                    build_strategy=build_strategy)
-            # for evaluation. And program compiled from ir graph must be with data parallel.
-            context.eval_graph.compiled_graph = CompiledProgram(
-                test_ir_graph.graph).with_data_parallel(
-                    build_strategy=build_strategy)
-            # for saving inference model after training
-            context.put('quantization_test_ir_graph_backup', test_ir_graph)
+            self._modify_graph_for_quantization(context)
             _logger.info('Finish QuantizationStrategy::on_epoch_begin')
 
     def on_epoch_end(self, context):