diff --git a/python/paddle/dataset/flowers.py b/python/paddle/dataset/flowers.py index 914dae348b..7d14cc5dc8 100644 --- a/python/paddle/dataset/flowers.py +++ b/python/paddle/dataset/flowers.py @@ -116,8 +116,8 @@ def reader_creator(data_file, for file in open(file_list): file = file.strip() batch = None - with open(file, 'r') as f: - batch = pickle.load(f) + with open(file, 'rb') as f: + batch = pickle.loads(f.read()) data = batch['data'] labels = batch['label'] for sample, label in zip(data, batch['label']): diff --git a/python/paddle/reader/tests/creator_test.py b/python/paddle/reader/tests/creator_test.py index c4238c12a7..567f38c96e 100644 --- a/python/paddle/reader/tests/creator_test.py +++ b/python/paddle/reader/tests/creator_test.py @@ -29,6 +29,7 @@ import os import unittest import numpy as np import paddle.reader.creator +import six class TestNumpyArray(unittest.TestCase): @@ -37,7 +38,7 @@ class TestNumpyArray(unittest.TestCase): x = np.array(l, np.int32) reader = paddle.reader.creator.np_array(x) for idx, e in enumerate(reader()): - self.assertItemsEqual(e, l[idx]) + six.assertCountEqual(e, l[idx]) class TestTextFile(unittest.TestCase):