|
|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
|
|
|
|
import six
|
|
|
|
|
from six.moves import cPickle as pickle
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle.io import Dataset
|
|
|
|
|
from paddle.dataset.common import _check_exists_and_download
|
|
|
|
|
|
|
|
|
|
@ -113,6 +114,8 @@ class Cifar10(Dataset):
|
|
|
|
|
# read dataset into memory
|
|
|
|
|
self._load_data()
|
|
|
|
|
|
|
|
|
|
self.dtype = paddle.get_default_dtype()
|
|
|
|
|
|
|
|
|
|
def _init_url_md5_flag(self):
|
|
|
|
|
self.data_url = CIFAR10_URL
|
|
|
|
|
self.data_md5 = CIFAR10_MD5
|
|
|
|
|
@ -142,7 +145,7 @@ class Cifar10(Dataset):
|
|
|
|
|
image = np.reshape(image, [3, 32, 32])
|
|
|
|
|
if self.transform is not None:
|
|
|
|
|
image = self.transform(image)
|
|
|
|
|
return image, label
|
|
|
|
|
return image.astype(self.dtype), np.array(label).astype('int64')
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return len(self.data)
|
|
|
|
|
|