@@ -1223,22 +1223,50 @@ def cross_entropy(input,
            ignore_index=ignore_index,
            axis=axis)
        if weight is not None:
            weight_gather = core.ops.gather_nd(
                weight, label)  # trans weight from class to sample, shape: N
            input_shape = list(label.shape)
            weight_gather_reshape = reshape(weight_gather, shape=input_shape)
            out = core.ops.elementwise_mul(out, weight_gather_reshape)
        if reduction == "sum":
            # Because of softmax_with_cross_entropy op's inner logic, the loss of a
            # sample whose class_index == ignore_index is already 0 in the out tensor,
            # so summing over all elements directly is correct.
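            # e.g. (illustrative values): with ignore_index=-1 and label=[0, -1, 2],
            # out is [l0, 0, l2], so the reduced sum is simply l0 + l2.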
            return core.ops.reduce_sum(out, 'reduce_all', True)
        elif reduction == "mean":
            # 1. if weight is None:
            #    numerator: summing all losses directly is ok because of
            #               softmax_with_cross_entropy's inner logic
            #    denominator: the number of samples whose class_index != ignore_index
            # 2. else:
            #    numerator: the weighted sum of the losses
            #    denominator: the sum of the weights of samples whose class_index != ignore_index
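            # Illustrative example (hypothetical values, not produced by the op):
            # with ignore_index=2, class weight=[0.5, 2.0, 1.0] and label=[0, 2, 1],
            # the ignored sample's loss is 0 and its weight is masked out, so the
            # result is (0.5 * loss_0 + 2.0 * loss_2) / (0.5 + 2.0),
            # i.e. out_sum / weight_sum as computed below.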
            if ignore_index != -100:
                out_sum = core.ops.reduce_sum(out, 'reduce_all', True)
                # for each label[i], set 1 or 0 according to ignore_index:
                # mask[i] = 0, if label[i] == ignore_index
                # mask[i] = 1, otherwise
                mask = (label != ignore_index)
                if weight is None:
                    mask = paddle.cast(mask, dtype=out_sum.dtype)
                    count = core.ops.reduce_sum(mask, 'reduce_all', True)
                    ret = out_sum / count
                else:
                    mask = paddle.cast(mask, weight_gather_reshape.dtype)
                    weight_ignored = core.ops.elementwise_mul(
                        mask, weight_gather_reshape)
                    weight_sum = core.ops.reduce_sum(weight_ignored,
                                                     'reduce_all', True)
                    ret = out_sum / weight_sum
                return ret
            elif weight is not None:
                out_sum = core.ops.reduce_sum(out, 'reduce_all', True)
                total_weight = core.ops.reduce_sum(weight_gather_reshape,
                                                   'reduce_all', True)
                return out_sum / total_weight
            else:
                return core.ops.mean(out)
        else:
            if input_dims - 1 == label_dims:
                out = paddle.squeeze(out, axis=axis)
@@ -1258,7 +1286,8 @@ def cross_entropy(input,
            fluid.data_feeder.check_variable_and_dtype(
                weight, 'weight', ['float32', 'float64'], 'softmax_cross_entropy')
            weight_name = name if reduction == 'none' else None
            weight_gather = paddle.gather_nd(
                weight, label)  # trans weight from class to sample, shape: N
            input_shape = list(label.shape)
            weight_gather_reshape = reshape(weight_gather, shape=input_shape)
            out = paddle.multiply(out, weight_gather_reshape, name=weight_name)
@@ -1266,12 +1295,29 @@ def cross_entropy(input,
        if reduction == "sum":
            return paddle.sum(out, name=name)
        elif reduction == "mean":
            if ignore_index != -100:
                out_sum = paddle.sum(out, name=name)
                # for each label[i], set 1 or 0 according to ignore_index:
                # mask[i] = 0, if label[i] == ignore_index
                # mask[i] = 1, otherwise
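                # e.g. (illustrative values): with ignore_index=-1 and label=[3, -1, 7],
                # mask is [1, 0, 1], so ignored samples drop out of the denominator.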
                mask = (label != ignore_index)
                if weight is None:
                    mask = paddle.cast(mask, dtype=out_sum.dtype)
                    count = paddle.sum(mask, name=name)
                    ret = out_sum / count
                else:
                    mask = paddle.cast(mask, weight_gather_reshape.dtype)
                    weight_ignored = paddle.multiply(mask, weight_gather_reshape)
                    weight_sum = paddle.sum(weight_ignored, name=name)
                    ret = out_sum / weight_sum
                return ret
            elif weight is not None:
                out_sum = paddle.sum(out, name=name)
                total_weight = paddle.sum(weight_gather_reshape)
                return out_sum / total_weight
            else:
                return paddle.mean(out, name=name)
        else:
            if input_dims - 1 == label_dims:
                out = paddle.squeeze(out, axis=axis)