diff --git a/python/paddle/vision/datasets/cifar.py b/python/paddle/vision/datasets/cifar.py index 631892ee4d..c531f3d0e4 100644 --- a/python/paddle/vision/datasets/cifar.py +++ b/python/paddle/vision/datasets/cifar.py @@ -19,6 +19,7 @@ import numpy as np import six from six.moves import cPickle as pickle +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -113,6 +114,8 @@ class Cifar10(Dataset): # read dataset into memory self._load_data() + self.dtype = paddle.get_default_dtype() + def _init_url_md5_flag(self): self.data_url = CIFAR10_URL self.data_md5 = CIFAR10_MD5 @@ -142,7 +145,7 @@ class Cifar10(Dataset): image = np.reshape(image, [3, 32, 32]) if self.transform is not None: image = self.transform(image) - return image, label + return image.astype(self.dtype), np.array(label).astype('int64') def __len__(self): return len(self.data) diff --git a/python/paddle/vision/datasets/flowers.py b/python/paddle/vision/datasets/flowers.py index 1c0f41123e..2251333fd8 100644 --- a/python/paddle/vision/datasets/flowers.py +++ b/python/paddle/vision/datasets/flowers.py @@ -21,6 +21,7 @@ import numpy as np import scipy.io as scio from PIL import Image +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -104,6 +105,8 @@ class Flowers(Dataset): # read dataset into memory self._load_anno() + self.dtype = paddle.get_default_dtype() + def _load_anno(self): self.name2mem = {} self.data_tar = tarfile.open(self.data_file) @@ -124,7 +127,7 @@ class Flowers(Dataset): if self.transform is not None: image = self.transform(image) - return image, label.astype('int64') + return image.astype(self.dtype), label.astype('int64') def __len__(self): return len(self.indexes) diff --git a/python/paddle/vision/datasets/folder.py b/python/paddle/vision/datasets/folder.py index 8a3053abef..19d913504b 100644 --- a/python/paddle/vision/datasets/folder.py +++ b/python/paddle/vision/datasets/folder.py @@ -15,6 +15,7 @@ import os import sys +import paddle from paddle.io import Dataset from paddle.utils import try_import @@ -143,6 +144,8 @@ class DatasetFolder(Dataset): self.samples = samples self.targets = [s[1] for s in samples] + self.dtype = paddle.get_default_dtype() + def _find_classes(self, dir): """ Finds the class folders in a dataset. diff --git a/python/paddle/vision/datasets/mnist.py b/python/paddle/vision/datasets/mnist.py index 597d404644..16c39e56ef 100644 --- a/python/paddle/vision/datasets/mnist.py +++ b/python/paddle/vision/datasets/mnist.py @@ -19,6 +19,7 @@ import gzip import struct import numpy as np +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -95,6 +96,8 @@ class MNIST(Dataset): # read dataset into memory self._parse_dataset() + self.dtype = paddle.get_default_dtype() + def _parse_dataset(self, buffer_size=100): self.images = [] self.labels = [] @@ -145,7 +148,7 @@ class MNIST(Dataset): image = np.reshape(image, [1, 28, 28]) if self.transform is not None: image = self.transform(image) - return image, label + return image.astype(self.dtype), label.astype('int64') def __len__(self): return len(self.labels) diff --git a/python/paddle/vision/datasets/voc2012.py b/python/paddle/vision/datasets/voc2012.py index ae14ea3016..5fc9d7c381 100644 --- a/python/paddle/vision/datasets/voc2012.py +++ b/python/paddle/vision/datasets/voc2012.py @@ -19,6 +19,7 @@ import tarfile import numpy as np from PIL import Image +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -96,6 +97,8 @@ class VOC2012(Dataset): # read dataset into memory self._load_anno() + self.dtype = paddle.get_default_dtype() + def _load_anno(self): self.name2mem = {} self.data_tar = tarfile.open(self.data_file) @@ -127,7 +130,7 @@ class VOC2012(Dataset): label = np.array(label) if self.transform is not None: data = self.transform(data) - return data, label + return data.astype(self.dtype), label.astype(self.dtype) def __len__(self): return len(self.data)