|
|
@ -305,7 +305,8 @@ class TestSampleLogitsOpV2(OpTest):
|
|
|
|
out = sample_logits(self.inputs["Logits"], self.inputs["Label"],
|
|
|
|
out = sample_logits(self.inputs["Logits"], self.inputs["Label"],
|
|
|
|
self.attrs["num_samples"], self.attrs["seed"],
|
|
|
|
self.attrs["num_samples"], self.attrs["seed"],
|
|
|
|
self.attrs["remove_accidental_hits"], True,
|
|
|
|
self.attrs["remove_accidental_hits"], True,
|
|
|
|
self.fetched_samples, self.probabilities)
|
|
|
|
self.fetched_samples.astype(np.int64),
|
|
|
|
|
|
|
|
self.probabilities)
|
|
|
|
self.outputs = {
|
|
|
|
self.outputs = {
|
|
|
|
'SampledLogits': out[0],
|
|
|
|
'SampledLogits': out[0],
|
|
|
|
'Samples': out[1],
|
|
|
|
'Samples': out[1],
|
|
|
@ -365,7 +366,6 @@ class TestSampleLogitsOpV3(OpTest):
|
|
|
|
batch_size, num_true = label.shape
|
|
|
|
batch_size, num_true = label.shape
|
|
|
|
use_custom_samples = False
|
|
|
|
use_custom_samples = False
|
|
|
|
|
|
|
|
|
|
|
|
#import pdb; pdb.set_trace()
|
|
|
|
|
|
|
|
num_sampled_classes = num_samples + num_true
|
|
|
|
num_sampled_classes = num_samples + num_true
|
|
|
|
logits = np.random.randn(batch_size, num_classes)
|
|
|
|
logits = np.random.randn(batch_size, num_classes)
|
|
|
|
|
|
|
|
|
|
|
@ -391,7 +391,8 @@ class TestSampleLogitsOpV3(OpTest):
|
|
|
|
out = sample_logits(self.inputs["Logits"], self.inputs["Label"],
|
|
|
|
out = sample_logits(self.inputs["Logits"], self.inputs["Label"],
|
|
|
|
self.attrs["num_samples"], self.attrs["seed"],
|
|
|
|
self.attrs["num_samples"], self.attrs["seed"],
|
|
|
|
self.attrs["remove_accidental_hits"], True,
|
|
|
|
self.attrs["remove_accidental_hits"], True,
|
|
|
|
self.fetched_samples, self.probabilities)
|
|
|
|
self.fetched_samples.astype(np.int64),
|
|
|
|
|
|
|
|
self.probabilities)
|
|
|
|
self.outputs = {
|
|
|
|
self.outputs = {
|
|
|
|
'SampledLogits': out[0],
|
|
|
|
'SampledLogits': out[0],
|
|
|
|
'Samples': out[1],
|
|
|
|
'Samples': out[1],
|
|
|
|