|
|
@ -680,9 +680,9 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
|
|
|
|
# [0.99806249 0.9817672 0.94987036]
|
|
|
|
# [0.99806249 0.9817672 0.94987036]
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
w12 = sum(elementwise_mul(x1, x2), dim=dim)
|
|
|
|
w12 = sum(elementwise_mul(x1, x2), axis=dim)
|
|
|
|
w1 = sum(elementwise_mul(x1, x1), dim=dim)
|
|
|
|
w1 = sum(elementwise_mul(x1, x1), axis=dim)
|
|
|
|
w2 = sum(elementwise_mul(x2, x2), dim=dim)
|
|
|
|
w2 = sum(elementwise_mul(x2, x2), axis=dim)
|
|
|
|
n12 = sqrt(clamp(w1 * w2, min=eps * eps))
|
|
|
|
n12 = sqrt(clamp(w1 * w2, min=eps * eps))
|
|
|
|
cos_sim = w12 / n12
|
|
|
|
cos_sim = w12 / n12
|
|
|
|
return cos_sim
|
|
|
|
return cos_sim
|
|
|
|