fix test_cosine_similarity_api failed (#26467)

revert-24895-update_cub
Chen Weihang 5 years ago committed by GitHub
parent a7cd61fdd1
commit 7e71ae92bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -680,9 +680,9 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
# [0.99806249 0.9817672 0.94987036]
"""
w12 = sum(elementwise_mul(x1, x2), dim=dim)
w1 = sum(elementwise_mul(x1, x1), dim=dim)
w2 = sum(elementwise_mul(x2, x2), dim=dim)
w12 = sum(elementwise_mul(x1, x2), axis=dim)
w1 = sum(elementwise_mul(x1, x1), axis=dim)
w2 = sum(elementwise_mul(x2, x2), axis=dim)
n12 = sqrt(clamp(w1 * w2, min=eps * eps))
cos_sim = w12 / n12
return cos_sim

Loading…
Cancel
Save