|
|
|
@ -88,13 +88,15 @@ class L2DistillerPass(object):
|
|
|
|
|
layers.square(student_feature_map - teacher_feature_map))
|
|
|
|
|
|
|
|
|
|
distillation_loss = l2loss * self.distillation_loss_weight
|
|
|
|
|
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
|
|
|
|
|
student_loss = 0
|
|
|
|
|
if 'loss' in ret_graph.out_nodes:
|
|
|
|
|
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
|
|
|
|
|
loss = distillation_loss + student_loss
|
|
|
|
|
|
|
|
|
|
ret_graph.out_nodes['loss'] = loss.name
|
|
|
|
|
ret_graph.out_nodes[
|
|
|
|
|
'l2loss_' + self.student_feature_map + "_" +
|
|
|
|
|
self.teacher_feature_map] = distillation_loss.name
|
|
|
|
|
ret_graph.out_nodes['loss'] = loss.name
|
|
|
|
|
return ret_graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -176,12 +178,14 @@ class FSPDistillerPass(object):
|
|
|
|
|
losses.append(l2_loss)
|
|
|
|
|
distillation_loss = layers.sum(
|
|
|
|
|
losses) * self.distillation_loss_weight
|
|
|
|
|
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
|
|
|
|
|
student_loss = 0
|
|
|
|
|
if 'loss' in ret_graph.out_nodes:
|
|
|
|
|
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
|
|
|
|
|
loss = distillation_loss + student_loss
|
|
|
|
|
|
|
|
|
|
ret_graph.out_nodes['loss'] = loss.name
|
|
|
|
|
ret_graph.out_nodes[
|
|
|
|
|
'fsp_distillation_loss'] = distillation_loss.name
|
|
|
|
|
ret_graph.out_nodes['loss'] = loss.name
|
|
|
|
|
return ret_graph
|
|
|
|
|
|
|
|
|
|
def _fsp_matrix(self, fea_map_0, fea_map_1):
|
|
|
|
@ -261,16 +265,18 @@ class SoftLabelDistillerPass(object):
|
|
|
|
|
student_feature_map = ret_graph.var(self.student_feature_map)._var
|
|
|
|
|
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
|
|
|
|
|
s_fea = student_feature_map / self.student_temperature
|
|
|
|
|
t_fea = teacher_feature_map / self.distillation_loss_weight
|
|
|
|
|
t_fea = teacher_feature_map / self.teacher_temperature
|
|
|
|
|
t_fea.stop_gradient = True
|
|
|
|
|
ce_loss = layers.softmax_with_cross_entropy(
|
|
|
|
|
s_fea, t_fea, soft_label=True)
|
|
|
|
|
distillation_loss = ce_loss * self.distillation_loss_weight
|
|
|
|
|
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
|
|
|
|
|
student_loss = 0
|
|
|
|
|
if 'loss' in ret_graph.out_nodes:
|
|
|
|
|
student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
|
|
|
|
|
loss = distillation_loss + student_loss
|
|
|
|
|
|
|
|
|
|
ret_graph.out_nodes['loss'] = loss.name
|
|
|
|
|
ret_graph.out_nodes[
|
|
|
|
|
'soft_label_loss_' + self.student_feature_map + "_" +
|
|
|
|
|
self.teacher_feature_map] = distillation_loss.name
|
|
|
|
|
ret_graph.out_nodes['loss'] = loss.name
|
|
|
|
|
return ret_graph
|
|
|
|
|