You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py

104 lines
4.3 KiB

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core.strategy import Strategy
from ....framework import Program, Variable, program_guard
from .... import Executor
import logging
__all__ = ['DistillationStrategy']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class DistillationStrategy(Strategy):
def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
"""
Args:
distillers(list): A list of distiller used to combine student graph and teacher graph
by adding some loss.
start_epoch(int): The epoch when to merge student graph and teacher graph for
distillation training. default: 0
end_epoch(int): The epoch when to finish distillation training. default: 0
"""
super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
self.distillers = distillers
def restore_from_checkpoint(self, context):
# load from checkpoint
if context.epoch_id > 0:
if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:
_logger.info('Restore DistillationStrategy')
self._create_distillation_graph(context)
_logger.info('Restore DistillationStrategy finish.')
def on_epoch_begin(self, context):
if self.start_epoch == context.epoch_id:
_logger.info('DistillationStrategy::on_epoch_begin.')
self._create_distillation_graph(context)
_logger.info('DistillationStrategy set optimize_graph.')
def _create_distillation_graph(self, context):
"""
step 1: Merge student graph and teacher graph into distillation graph.
step 2: Add loss into distillation graph by distillers.
step 3: Append backward ops and optimize ops into distillation graph for training.
"""
# step 1
teacher = context.teacher_graphs[0]
for var in teacher.program.list_vars():
var.stop_gradient = True
graph = context.train_graph.clone()
graph.merge(teacher)
graph.out_nodes['student_loss'] = graph.out_nodes['loss']
# step 2
for distiller in self.distillers:
graph = distiller.distiller_loss(graph)
# step 3
startup_program = Program()
with program_guard(graph.program, startup_program):
context.distiller_optimizer._name = 'distillation_optimizer'
# The learning rate variable may be created in other program.
# Update information in optimizer to make
# learning rate variable being accessible in current program.
optimizer = context.distiller_optimizer
if isinstance(optimizer._learning_rate, Variable):
optimizer._learning_rate_map[
graph.program] = optimizer._learning_rate
optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)
exe = Executor(context.place)
exe.run(startup_program, scope=context.scope)
# backup graph for fine-tune after distillation
context.put('distillation_backup_optimize_graph',
context.optimize_graph)
context.optimize_graph = graph
def on_epoch_end(self, context):
if context.epoch_id == (self.end_epoch - 1):
_logger.info('DistillationStrategy::on_epoch_end.')
# restore optimize_graph for fine-tune or other strategy in next stage.
context.optimize_graph = context.get(
'distillation_backup_optimize_graph')
_logger.info(
'DistillationStrategy set context.optimize_graph to None.')