use paddle.get_default_dtype in vision datasets. test=develop (#27426)

revert-27520-disable_pr
Kaipeng Deng 5 years ago committed by GitHub
parent fc61efd736
commit 4bd7aa2566
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,6 +19,7 @@ import numpy as np
import six
from six.moves import cPickle as pickle
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
@ -113,6 +114,8 @@ class Cifar10(Dataset):
# read dataset into memory
self._load_data()
self.dtype = paddle.get_default_dtype()
def _init_url_md5_flag(self):
self.data_url = CIFAR10_URL
self.data_md5 = CIFAR10_MD5
@ -142,7 +145,7 @@ class Cifar10(Dataset):
image = np.reshape(image, [3, 32, 32])
if self.transform is not None:
image = self.transform(image)
return image, label
return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self):
return len(self.data)

@ -21,6 +21,7 @@ import numpy as np
import scipy.io as scio
from PIL import Image
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
@ -104,6 +105,8 @@ class Flowers(Dataset):
# read dataset into memory
self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
@ -124,7 +127,7 @@ class Flowers(Dataset):
if self.transform is not None:
image = self.transform(image)
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
return len(self.indexes)

@ -15,6 +15,7 @@
import os
import sys
import paddle
from paddle.io import Dataset
from paddle.utils import try_import
@ -143,6 +144,8 @@ class DatasetFolder(Dataset):
self.samples = samples
self.targets = [s[1] for s in samples]
self.dtype = paddle.get_default_dtype()
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.

@ -19,6 +19,7 @@ import gzip
import struct
import numpy as np
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
@ -95,6 +96,8 @@ class MNIST(Dataset):
# read dataset into memory
self._parse_dataset()
self.dtype = paddle.get_default_dtype()
def _parse_dataset(self, buffer_size=100):
self.images = []
self.labels = []
@ -145,7 +148,7 @@ class MNIST(Dataset):
image = np.reshape(image, [1, 28, 28])
if self.transform is not None:
image = self.transform(image)
return image, label
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
return len(self.labels)

@ -19,6 +19,7 @@ import tarfile
import numpy as np
from PIL import Image
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
@ -96,6 +97,8 @@ class VOC2012(Dataset):
# read dataset into memory
self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
@ -127,7 +130,7 @@ class VOC2012(Dataset):
label = np.array(label)
if self.transform is not None:
data = self.transform(data)
return data, label
return data.astype(self.dtype), label.astype(self.dtype)
def __len__(self):
return len(self.data)

Loading…
Cancel
Save