|
|
|
@ -94,13 +94,13 @@ class TestMNISTTest(unittest.TestCase):
|
|
|
|
|
mnist = MNIST(mode='test', transform=transform)
|
|
|
|
|
self.assertTrue(len(mnist) == 10000)
|
|
|
|
|
|
|
|
|
|
for i in range(len(mnist)):
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
i = np.random.randint(0, len(mnist) - 1)
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMNISTTrain(unittest.TestCase):
|
|
|
|
@ -109,13 +109,13 @@ class TestMNISTTrain(unittest.TestCase):
|
|
|
|
|
mnist = MNIST(mode='train', transform=transform)
|
|
|
|
|
self.assertTrue(len(mnist) == 60000)
|
|
|
|
|
|
|
|
|
|
for i in range(len(mnist)):
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
i = np.random.randint(0, len(mnist) - 1)
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
|
|
|
|
|
# test cv2 backend
|
|
|
|
|
mnist = MNIST(mode='train', transform=transform, backend='cv2')
|
|
|
|
@ -140,13 +140,13 @@ class TestFASHIONMNISTTest(unittest.TestCase):
|
|
|
|
|
mnist = FashionMNIST(mode='test', transform=transform)
|
|
|
|
|
self.assertTrue(len(mnist) == 10000)
|
|
|
|
|
|
|
|
|
|
for i in range(len(mnist)):
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
i = np.random.randint(0, len(mnist) - 1)
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestFASHIONMNISTTrain(unittest.TestCase):
|
|
|
|
@ -155,13 +155,13 @@ class TestFASHIONMNISTTrain(unittest.TestCase):
|
|
|
|
|
mnist = FashionMNIST(mode='train', transform=transform)
|
|
|
|
|
self.assertTrue(len(mnist) == 60000)
|
|
|
|
|
|
|
|
|
|
for i in range(len(mnist)):
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
i = np.random.randint(0, len(mnist) - 1)
|
|
|
|
|
image, label = mnist[i]
|
|
|
|
|
self.assertTrue(image.shape[0] == 1)
|
|
|
|
|
self.assertTrue(image.shape[1] == 28)
|
|
|
|
|
self.assertTrue(image.shape[2] == 28)
|
|
|
|
|
self.assertTrue(label.shape[0] == 1)
|
|
|
|
|
self.assertTrue(0 <= int(label) <= 9)
|
|
|
|
|
|
|
|
|
|
# test cv2 backend
|
|
|
|
|
mnist = FashionMNIST(mode='train', transform=transform, backend='cv2')
|
|
|
|
|