fix second grad output of BCEWithLogicLoss.

pull/12953/head
liuxiao93 4 years ago
parent c9ac0da266
commit 522aad27b6

@ -1090,19 +1090,25 @@ def get_bprop_ce_with_logits_loss(self):
add = P.Add() add = P.Add()
sub = P.Sub() sub = P.Sub()
size = P.Size() size = P.Size()
neg = P.Neg()
log = P.Log()
def bprop(predict, target, weight, pos_weight, out, dout): def bprop(predict, target, weight, pos_weight, out, dout):
sigmoid_input = sigmoid(predict) sigmoid_input = sigmoid(predict)
if pos_weight is not None: if pos_weight is not None:
t = mul(target, pos_weight) t = mul(target, pos_weight)
dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout) dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
grad_target = mul(sub(log(sub(1, sigmoid_input)), mul(pos_weight, log(sigmoid_input))), dout)
else: else:
dx = mul((sigmoid_input - target), dout) dx = mul((sigmoid_input - target), dout)
grad_target = mul(predict, neg(dout))
if weight is not None: if weight is not None:
dx = mul(dx, weight) dx = mul(dx, weight)
grad_target = mul(grad_target, weight)
if reduction == 'mean': if reduction == 'mean':
dx = dx / size(dx) dx = dx / size(dx)
return dx, zeros_like(target), zeros_like(weight), zeros_like(pos_weight) grad_target = grad_target / size(target)
return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)
return bprop return bprop

Loading…
Cancel
Save