test=develop

test=develop
revert-16045-imperative_remove_desc
ceci3 6 years ago
parent 5f343b0e3a
commit d3656ff304

@ -10695,13 +10695,11 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
labels = reshape(labels, shape=[batch_size, 1], inplace=True)
labels = expand(labels, expand_times=[1, batch_size])
labels = control_flow.equal(
labels, transpose(
labels, perm=[1, 0])).astype('float32')
labels = equal(labels, transpose(labels, perm=[1, 0])).astype('float32')
labels = labels / reduce_sum(labels, dim=1, keep_dim=True)
l2loss = reduce_mean(reduce_sum(ops.square(anchor), 1)) \
+ reduce_mean(reduce_sum(ops.square(positive), 1))
l2loss = reduce_mean(reduce_sum(square(anchor), 1)) \
+ reduce_mean(reduce_sum(square(positive), 1))
l2loss = l2loss * Beta * l2_reg
similarity_matrix = matmul(

Loading…
Cancel
Save