Fix distillation for soft label. (#16538)

test=develop
revert-16555-model_data_cryption_link_all_lib
whs 6 years ago committed by GitHub
parent 3e6aa498d6
commit 73c4f2b7b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,7 +19,7 @@ from .... import Program
from .... import program_guard
from .... import regularizer
__all__ = ['FSPDistiller', 'L2Distiller']
__all__ = ['FSPDistiller', 'L2Distiller', 'SoftLabelDistiller']
class L2Distiller(object):
@ -186,3 +186,91 @@ class FSPDistillerPass(object):
def _fsp_matrix(self, fea_map_0, fea_map_1):
return layers.fsp_matrix(fea_map_0, fea_map_1)
class SoftLabelDistiller(object):
"""
Combine two layers from student net and teacher net by softmax_with_cross_entropy loss.
And add the loss into the total loss using for distillation training.
"""
def __init__(self,
student_feature_map=None,
teacher_feature_map=None,
student_temperature=1.0,
teacher_temperature=1.0,
distillation_loss_weight=1):
"""
Args:
student_feature_map(str): The name of feature map from student network.
teacher_feature_map(str): The name of feature map from teacher network.
It's shape should be the same with student network.
student_temperature(float): Temperature used to divide student_feature_map before softmax_with_cross_entropy. default: 1.0
teacher_temperature(float): Temperature used to divide teacher_feature_map before softmax_with_cross_entropy. default: 1.0
distillation_loss_weight(float): The weight of the l2-loss.
"""
self.student_feature_map = student_feature_map
self.teacher_feature_map = teacher_feature_map
self.distillation_loss_weight = distillation_loss_weight
self.student_temperature = student_temperature
self.teacher_temperature = teacher_temperature
def distiller_loss(self, graph):
"""
Modify graph inplace to add softmax_with_cross_entropy loss.
Args:
graph(GraphWrapper): The graph to be modified.
Returns:
GraphWrapper: The modified graph.
"""
distiller_pass = SoftLabelDistillerPass(
self.student_feature_map, self.teacher_feature_map,
self.student_temperature, self.teacher_temperature,
self.distillation_loss_weight)
dis_graph = distiller_pass.apply(graph)
return dis_graph
class SoftLabelDistillerPass(object):
def __init__(self,
student_feature_map,
teacher_feature_map,
student_temperature,
teacher_temperature,
distillation_loss_weight=1):
"""
Args:
student_feature_map(str): The name of feature map from student network.
teacher_feature_map(str): The name of feature map from teacher network.
It's shape should be the same with student network.
student_temperature(float): Temperature used to divide student_feature_map before softmax_with_cross_entropy.
teacher_temperature(float): Temperature used to divide teacher_feature_map before softmax_with_cross_entropy.
distillation_loss_weight(float): The weight of the l2-loss.
"""
self.student_feature_map = student_feature_map
self.teacher_feature_map = teacher_feature_map
self.student_temperature = student_temperature
self.teacher_temperature = teacher_temperature
self.distillation_loss_weight = distillation_loss_weight
def apply(self, graph):
ret_graph = graph
with program_guard(ret_graph.program):
student_feature_map = ret_graph.var(self.student_feature_map)._var
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
s_fea = student_feature_map / self.student_temperature
t_fea = teacher_feature_map / self.distillation_loss_weight
t_fea.stop_gradient = True
ce_loss = layers.softmax_with_cross_entropy(
s_fea, t_fea, soft_label=True)
distillation_loss = ce_loss * self.distillation_loss_weight
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
loss = distillation_loss + student_loss
ret_graph.out_nodes[
'soft_label_loss_' + self.student_feature_map + "_" +
self.teacher_feature_map] = distillation_loss.name
ret_graph.out_nodes['loss'] = loss.name
return ret_graph

@ -33,10 +33,17 @@ distillers:
teacher_feature_map: 'teacher.tmp_2'
student_feature_map: 'student.tmp_2'
distillation_loss_weight: 1
soft_label_distiller:
class: 'SoftLabelDistiller'
student_temperature: 1.0
teacher_temperature: 1.0
teacher_feature_map: 'teacher.tmp_1'
student_feature_map: 'student.tmp_1'
distillation_loss_weight: 0.001
strategies:
distillation_strategy:
class: 'DistillationStrategy'
distillers: ['fsp_distiller', 'l2_distiller']
distillers: ['fsp_distiller', 'l2_distiller', 'soft_label_distiller']
start_epoch: 0
end_epoch: 1
compressor:

Loading…
Cancel
Save