|
|
|
@ -25,17 +25,18 @@ class TestSamplingIdOp(OpTest):
|
|
|
|
|
self.op_type = "sampling_id"
|
|
|
|
|
self.use_mkldnn = False
|
|
|
|
|
self.init_kernel_type()
|
|
|
|
|
X = np.random.random((3, 4)).astype('float32')
|
|
|
|
|
self.inputs = {"X": X}
|
|
|
|
|
Y = np.random.random(3).astype('float32')
|
|
|
|
|
self.outputs = {'Out': Y}
|
|
|
|
|
self.X = np.random.random((8, 4)).astype('float32')
|
|
|
|
|
self.inputs = {"X": self.X}
|
|
|
|
|
self.Y = np.random.random(8).astype('float32')
|
|
|
|
|
self.outputs = {'Out': self.Y}
|
|
|
|
|
self.attrs = {'use_mkldnn': self.use_mkldnn}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
self.check_output_customized(self.verify_output)
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
|
self.check_grad(['X'], 'Out')
|
|
|
|
|
def verify_output(self, outs):
|
|
|
|
|
out = np.array(outs[0])
|
|
|
|
|
self.assertEqual(len(out), len(self.Y))
|
|
|
|
|
|
|
|
|
|
def init_kernel_type(self):
|
|
|
|
|
pass
|
|
|
|
|