# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core.strategy import Strategy
from ....framework import Program, Variable, program_guard
from .... import Executor
import logging

__all__ = ['DistillationStrategy']

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class DistillationStrategy(Strategy):
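    """
    A training strategy that runs knowledge distillation between ``start_epoch``
    and ``end_epoch``: the teacher graph is merged into the student graph, each
    distiller appends a distillation loss, and the resulting loss is optimized
    by the distiller optimizer.

    Illustrative usage (a minimal sketch; ``my_distiller`` is a placeholder for
    any object that implements ``distiller_loss(graph)``):

        strategy = DistillationStrategy(
            distillers=[my_distiller], start_epoch=0, end_epoch=10)
    """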
    def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
        """
        Args:
            distillers(list): A list of distillers used to combine the student graph
                and the teacher graph by adding distillation losses.
            start_epoch(int): The epoch at which to merge the student graph and the
                teacher graph for distillation training. Default: 0.
            end_epoch(int): The epoch at which to finish distillation training. Default: 0.
        """
        super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
        self.distillers = distillers

    def restore_from_checkpoint(self, context):
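        """
        Rebuild the distillation graph when training resumes from a checkpoint
        whose epoch already lies inside the distillation stage.
        """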
        # load from checkpoint
        if context.epoch_id > 0:
            if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:
                _logger.info('Restore DistillationStrategy')
                self._create_distillation_graph(context)
                _logger.info('Restore DistillationStrategy finish.')

    def on_epoch_begin(self, context):
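        """
        At ``start_epoch``, build the distillation graph and install it as the
        graph to be optimized during the distillation stage.
        """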
        if self.start_epoch == context.epoch_id:
            _logger.info('DistillationStrategy::on_epoch_begin.')
            self._create_distillation_graph(context)
            _logger.info('DistillationStrategy set optimize_graph.')

    def _create_distillation_graph(self, context):
        """
        Build the distillation graph in three steps:
        step 1: Merge the student graph and the teacher graph into a distillation graph.
        step 2: Add distillation losses into the distillation graph by the distillers.
        step 3: Append backward ops and optimize ops to the distillation graph for training.
        """
        # step 1
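        # Freeze every teacher variable so that no gradients flow back into
        # the teacher, then merge the teacher program into a clone of the
        # student training graph. The original student loss stays reachable
        # under the 'student_loss' out-node.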
        teacher = context.teacher_graphs[0]
        for var in teacher.program.list_vars():
            var.stop_gradient = True
        graph = context.train_graph.clone()
        graph.merge(teacher)
        graph.out_nodes['student_loss'] = graph.out_nodes['loss']

        # step 2
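        # Each distiller takes the merged graph, appends its distillation loss,
        # and returns the graph, so multiple distillers can be chained. After
        # this loop, the graph's 'loss' out-node is the loss minimized in step 3.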
        for distiller in self.distillers:
            graph = distiller.distiller_loss(graph)

        # step 3
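        # Append backward and optimization ops for the distillation loss, run
        # the freshly-built startup program to initialize any new optimizer
        # state, and install the distillation graph as the graph to optimize,
        # keeping a backup of the previous one for restoration at end_epoch.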
        startup_program = Program()
        with program_guard(graph.program, startup_program):
            context.distiller_optimizer._name = 'distillation_optimizer'

            # The learning rate variable may be created in another program.
            # Update the optimizer so that the learning rate variable is
            # accessible from the current program.
            optimizer = context.distiller_optimizer
            if isinstance(optimizer._learning_rate, Variable):
                optimizer._learning_rate_map[
                    graph.program] = optimizer._learning_rate

            optimizer.minimize(graph.var(graph.out_nodes['loss'])._var)

        exe = Executor(context.place)
        exe.run(startup_program, scope=context.scope)

        # Back up the current optimize graph for fine-tuning after distillation.
        context.put('distillation_backup_optimize_graph',
                    context.optimize_graph)
        context.optimize_graph = graph

    def on_epoch_end(self, context):
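        """
        At the last distillation epoch, restore the backed-up optimize graph so
        that fine-tuning or the next strategy runs without the teacher graph.
        """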
        if context.epoch_id == (self.end_epoch - 1):
            _logger.info('DistillationStrategy::on_epoch_end.')
            # Restore optimize_graph for fine-tuning or another strategy in the next stage.
            context.optimize_graph = context.get(
                'distillation_backup_optimize_graph')
            _logger.info(
                'DistillationStrategy restored the backup optimize_graph.')