@@ -19,7 +19,7 @@ from .... import Program
 from .... import program_guard
 from .... import regularizer
 
-__all__ = ['FSPDistiller', 'L2Distiller']
+__all__ = ['FSPDistiller', 'L2Distiller', 'SoftLabelDistiller']
 
 
 class L2Distiller(object):
@@ -186,3 +186,91 @@ class FSPDistillerPass(object):
 
     def _fsp_matrix(self, fea_map_0, fea_map_1):
         return layers.fsp_matrix(fea_map_0, fea_map_1)
+
+
+class SoftLabelDistiller(object):
+    """
+    Combine two layers from the student net and the teacher net with a
+    softmax_with_cross_entropy loss, and add this loss to the total loss
+    used for distillation training.
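+
+    Example (an illustrative sketch, not part of this module: the feature-map
+    names and the `graph` GraphWrapper are assumed to be created by the
+    surrounding training script):
+
+        distiller = SoftLabelDistiller(
+            student_feature_map='student_fc_out',
+            teacher_feature_map='teacher_fc_out',
+            student_temperature=1.0,
+            teacher_temperature=2.0,
+            distillation_loss_weight=0.5)
+        graph = distiller.distiller_loss(graph)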
+    """
+
+    def __init__(self,
+                 student_feature_map=None,
+                 teacher_feature_map=None,
+                 student_temperature=1.0,
+                 teacher_temperature=1.0,
+                 distillation_loss_weight=1):
+        """
+        Args:
+            student_feature_map(str): The name of the feature map from the student network.
+            teacher_feature_map(str): The name of the feature map from the teacher network.
+                                      Its shape should be the same as that of student_feature_map.
+            student_temperature(float): Temperature by which student_feature_map is divided
+                                        before softmax_with_cross_entropy. Default: 1.0.
+            teacher_temperature(float): Temperature by which teacher_feature_map is divided
+                                        before softmax_with_cross_entropy. Default: 1.0.
+            distillation_loss_weight(float): The weight of the soft-label distillation loss.
+                                             Default: 1.
+        """
+        self.student_feature_map = student_feature_map
+        self.teacher_feature_map = teacher_feature_map
+        self.distillation_loss_weight = distillation_loss_weight
+        self.student_temperature = student_temperature
+        self.teacher_temperature = teacher_temperature
+
+    def distiller_loss(self, graph):
+        """
+        Modify the graph in place to add a softmax_with_cross_entropy loss.
+        Args:
+            graph(GraphWrapper): The graph to be modified.
+        Returns:
+            GraphWrapper: The modified graph.
+        """
+        distiller_pass = SoftLabelDistillerPass(
+            self.student_feature_map, self.teacher_feature_map,
+            self.student_temperature, self.teacher_temperature,
+            self.distillation_loss_weight)
+        dis_graph = distiller_pass.apply(graph)
+        return dis_graph
+
+
+class SoftLabelDistillerPass(object):
+    def __init__(self,
+                 student_feature_map,
+                 teacher_feature_map,
+                 student_temperature,
+                 teacher_temperature,
+                 distillation_loss_weight=1):
+        """
+        Args:
+            student_feature_map(str): The name of the feature map from the student network.
+            teacher_feature_map(str): The name of the feature map from the teacher network.
+                                      Its shape should be the same as that of student_feature_map.
+            student_temperature(float): Temperature by which student_feature_map is divided
+                                        before softmax_with_cross_entropy.
+            teacher_temperature(float): Temperature by which teacher_feature_map is divided
+                                        before softmax_with_cross_entropy.
+            distillation_loss_weight(float): The weight of the soft-label distillation loss.
+                                             Default: 1.
+        """
+        self.student_feature_map = student_feature_map
+        self.teacher_feature_map = teacher_feature_map
+        self.student_temperature = student_temperature
+        self.teacher_temperature = teacher_temperature
+        self.distillation_loss_weight = distillation_loss_weight
+
+    def apply(self, graph):
+        ret_graph = graph
+        with program_guard(ret_graph.program):
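+            # Look up the underlying framework variables for the two feature
+            # maps registered by name on the graph.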
+            student_feature_map = ret_graph.var(self.student_feature_map)._var
+            teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
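+            # Soften both logit maps by their respective temperatures; a
+            # higher temperature produces a smoother target distribution.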
+            s_fea = student_feature_map / self.student_temperature
+            t_fea = teacher_feature_map / self.teacher_temperature
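+            # The teacher's soft targets are fixed: block backpropagation
+            # through the teacher branch.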
+            t_fea.stop_gradient = True
+            ce_loss = layers.softmax_with_cross_entropy(
+                s_fea, t_fea, soft_label=True)
+            distillation_loss = ce_loss * self.distillation_loss_weight
+            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
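+            # Total training loss: the weighted soft-label loss plus the
+            # student's original task loss.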
+            loss = distillation_loss + student_loss
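+
+            # Expose both the distillation loss and the combined loss as graph
+            # outputs so they can be fetched during training.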
+            ret_graph.out_nodes[
+                'soft_label_loss_' + self.student_feature_map + "_" +
+                self.teacher_feature_map] = distillation_loss.name
+            ret_graph.out_nodes['loss'] = loss.name
+        return ret_graph