|
|
|
@ -263,7 +263,7 @@ class TestSampleLogitsOpV2(OpTest):
|
|
|
|
|
'remove_accidental_hits': remove_accidental_hits,
|
|
|
|
|
'seed': seed
|
|
|
|
|
}
|
|
|
|
|
self.inputs = {'Logits': logits, 'Label': label}
|
|
|
|
|
self.inputs = {'Logits': logits, 'Label': label.astype(np.int64)}
|
|
|
|
|
|
|
|
|
|
def set_data(self, num_classes, num_samples, seed, remove_accidental_hits):
|
|
|
|
|
label = np.array([[6, 12, 15, 5, 1], [0, 9, 4, 1, 10],
|
|
|
|
@ -347,7 +347,7 @@ class TestSampleLogitsOpV3(OpTest):
|
|
|
|
|
'remove_accidental_hits': remove_accidental_hits,
|
|
|
|
|
'seed': seed
|
|
|
|
|
}
|
|
|
|
|
self.inputs = {'Logits': logits, 'Label': label}
|
|
|
|
|
self.inputs = {'Logits': logits, 'Label': label.astype(np.int64)}
|
|
|
|
|
|
|
|
|
|
def set_data(self, num_classes, num_samples, seed, remove_accidental_hits):
|
|
|
|
|
label = [52, 2, 2, 17, 96, 2, 17, 96, 37, 2]
|
|
|
|
|