soft_label_distiller fix, test=develop (#20645)

revert-20712-fix_depthwise_conv
Bai Yifan 6 years ago committed by GitHub
parent 003f369bb2
commit ffec9195e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -264,11 +264,14 @@ class SoftLabelDistillerPass(object):
student_feature_map = ret_graph.var(self.student_feature_map)._var
teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
s_fea = student_feature_map / self.student_temperature
t_fea = teacher_feature_map / self.teacher_temperature
s_fea = layers.softmax(student_feature_map /
self.student_temperature)
t_fea = layers.softmax(teacher_feature_map /
self.teacher_temperature)
t_fea.stop_gradient = True
ce_loss = layers.softmax_with_cross_entropy(
s_fea, t_fea, soft_label=True)
ce_loss = layres.reduce_mean(
layers.cross_entropy(
s_fea, t_fea, soft_label=True))
distillation_loss = ce_loss * self.distillation_loss_weight
student_loss = 0
if 'loss' in ret_graph.out_nodes:

Loading…
Cancel
Save