|
|
|
@ -264,11 +264,14 @@ class SoftLabelDistillerPass(object):
|
|
|
|
|
|
|
|
|
|
student_feature_map = ret_graph.var(self.student_feature_map)._var
|
|
|
|
|
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
|
|
|
|
|
s_fea = student_feature_map / self.student_temperature
|
|
|
|
|
t_fea = teacher_feature_map / self.teacher_temperature
|
|
|
|
|
s_fea = layers.softmax(student_feature_map /
|
|
|
|
|
self.student_temperature)
|
|
|
|
|
t_fea = layers.softmax(teacher_feature_map /
|
|
|
|
|
self.teacher_temperature)
|
|
|
|
|
t_fea.stop_gradient = True
|
|
|
|
|
ce_loss = layers.softmax_with_cross_entropy(
|
|
|
|
|
s_fea, t_fea, soft_label=True)
|
|
|
|
|
ce_loss = layers.reduce_mean(
|
|
|
|
|
layers.cross_entropy(
|
|
|
|
|
s_fea, t_fea, soft_label=True))
|
|
|
|
|
distillation_loss = ce_loss * self.distillation_loss_weight
|
|
|
|
|
student_loss = 0
|
|
|
|
|
if 'loss' in ret_graph.out_nodes:
|
|
|
|
|