|
|
|
@ -35,7 +35,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
|
|
|
|
|
o = sample_out[i]
|
|
|
|
|
cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
|
|
|
|
|
out[samples[i][0]] += cost * samples[i][3]
|
|
|
|
|
return (out, np.array(sample_out).reshape(
|
|
|
|
|
return (out[:, np.newaxis], np.array(sample_out).reshape(
|
|
|
|
|
batch_size, num_sample_class + num_true_class),
|
|
|
|
|
np.array(sample_labels).reshape(batch_size,
|
|
|
|
|
num_sample_class + num_true_class))
|
|
|
|
@ -43,16 +43,16 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
|
|
|
|
|
|
|
|
|
|
class TestNCE(OpTest):
|
|
|
|
|
def generate_data(self, dim, batch_size, num_classes, num_true_class,
|
|
|
|
|
num_sampled_classes):
|
|
|
|
|
num_neg_samples):
|
|
|
|
|
input = np.random.randn(batch_size, dim).astype(np.float32)
|
|
|
|
|
weight = np.random.randn(num_classes, dim).astype(np.float32)
|
|
|
|
|
bias = np.random.randn(num_classes).astype(np.float32)
|
|
|
|
|
sample_weight = np.random.randn(batch_size).astype(np.float32)
|
|
|
|
|
labels = np.random.randint(0, num_classes, (batch_size, num_true_class))
|
|
|
|
|
self.attrs = {
|
|
|
|
|
'num_classes': num_classes,
|
|
|
|
|
'num_sampled_classes': num_sampled_classes,
|
|
|
|
|
'sampled_labels': range(num_sampled_classes)
|
|
|
|
|
'num_total_classes': num_classes,
|
|
|
|
|
'num_neg_samples': num_neg_samples,
|
|
|
|
|
'custom_neg_classes': range(num_neg_samples)
|
|
|
|
|
}
|
|
|
|
|
self.inputs = {
|
|
|
|
|
'Input': input,
|
|
|
|
@ -68,8 +68,8 @@ class TestNCE(OpTest):
|
|
|
|
|
def compute(self):
|
|
|
|
|
out = nce(self.inputs['Input'], self.inputs['Weight'],
|
|
|
|
|
self.inputs['Bias'], self.inputs['SampleWeight'],
|
|
|
|
|
self.inputs['Label'], self.attrs['num_classes'],
|
|
|
|
|
self.attrs['num_sampled_classes'])
|
|
|
|
|
self.inputs['Label'], self.attrs['num_total_classes'],
|
|
|
|
|
self.attrs['num_neg_samples'])
|
|
|
|
|
self.outputs = {
|
|
|
|
|
'Cost': out[0],
|
|
|
|
|
'SampleLogits': out[1],
|
|
|
|
|