Fix cross_entropy bug with the axis parameter in log_softmax (#27311)

chajchaj committed 5 years ago via GitHub
parent d28162b97f
commit fef94eac4e

@@ -26,7 +26,7 @@ def stable_softmax(x):
     return exps / np.sum(exps)


-def log_softmax(x, axis=-1):
+def log_softmax(x, axis=1):
     softmax_out = np.apply_along_axis(stable_softmax, axis, x)
     return np.log(softmax_out)
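
For context, a minimal NumPy sketch of why the default axis matters (the helpers mirror the test code above; the rank-3 shape is illustrative). For inputs with more than two dimensions, axis=1 and axis=-1 select different dimensions, so the helper's old default (axis=-1) did not match the axis=1 convention that cross_entropy uses:

import numpy as np

def stable_softmax(x):
    # Subtract the max for numerical stability before exponentiating.
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

def log_softmax(x, axis=1):
    softmax_out = np.apply_along_axis(stable_softmax, axis, x)
    return np.log(softmax_out)

# With a rank-3 input, axis=1 (the class/channel dimension) and
# axis=-1 (the last dimension) are different axes, so results differ.
x = np.random.randn(2, 3, 4)
print(np.allclose(log_softmax(x, axis=1), log_softmax(x, axis=-1)))  # False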

@@ -1093,7 +1093,7 @@ def cross_entropy(input,
             " 'none', but received %s, which is not allowed." % reduction)

     #step 1. log_softmax
-    log_softmax_out = paddle.nn.functional.log_softmax(input)
+    log_softmax_out = paddle.nn.functional.log_softmax(input, axis=1)
     if weight is not None and not isinstance(weight, Variable):
         raise ValueError(
             "The weight' is not a Variable, please convert to Variable.")
