|
|
|
@ -2001,9 +2001,15 @@ def nce(input,
|
|
|
|
|
sample_logits = helper.create_tmp_variable(dtype=input.dtype)
|
|
|
|
|
sample_labels = helper.create_tmp_variable(dtype=label.dtype)
|
|
|
|
|
|
|
|
|
|
attrs = {'num_total_classes': int(num_total_classes)}
|
|
|
|
|
if num_neg_samples is not None:
|
|
|
|
|
attrs['num_neg_samples'] = int(num_neg_samples)
|
|
|
|
|
if num_neg_samples is None:
|
|
|
|
|
num_neg_samples = 10
|
|
|
|
|
else:
|
|
|
|
|
num_neg_samples = int(num_neg_samples)
|
|
|
|
|
|
|
|
|
|
attrs = {
|
|
|
|
|
'num_total_classes': int(num_total_classes),
|
|
|
|
|
'num_neg_samples': num_neg_samples
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type='nce',
|
|
|
|
@ -2020,4 +2026,4 @@ def nce(input,
|
|
|
|
|
'SampleLabels': sample_labels
|
|
|
|
|
},
|
|
|
|
|
attrs=attrs)
|
|
|
|
|
return cost
|
|
|
|
|
return cost / (num_neg_samples + 1)
|
|
|
|
|