From 283bdc5062be0ba14b0ae3ca6cc211ddaf25fd1c Mon Sep 17 00:00:00 2001
From: gongweibao <weibao.gong@gmail.com>
Date: Mon, 12 Jun 2017 10:29:35 +0800
Subject: [PATCH] Fix issues raised in helin's review comments

---
 paddle/parameter/tests/test_argument.cpp      |  2 +-
 python/paddle/v2/dataset/common.py            | 58 +++++++++++--------
 python/paddle/v2/dataset/tests/common_test.py | 26 +++++++--
 3 files changed, 56 insertions(+), 30 deletions(-)

diff --git a/paddle/parameter/tests/test_argument.cpp b/paddle/parameter/tests/test_argument.cpp
index 81fe4ee397..98ab013548 100644
--- a/paddle/parameter/tests/test_argument.cpp
+++ b/paddle/parameter/tests/test_argument.cpp
@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) {
     CHECK_EQ(outStart[3], 4);
     CHECK_EQ(outStart[4], 7);
 
-    CHECK_EQ(stridePositions->getSize(), 8);
+    CHECK_EQ(stridePositions->getSize(), 8UL);
     auto result = reversed ? strideResultReversed : strideResult;
     for (int i = 0; i < 8; i++) {
       CHECK_EQ(stridePositions->getData()[i], result[i]);
diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py
index 89675080e2..8023fa3cf8 100644
--- a/python/paddle/v2/dataset/common.py
+++ b/python/paddle/v2/dataset/common.py
@@ -151,9 +151,14 @@ def cluster_files_reader(files_pattern,
     return reader
 
 
-def convert(output_path, eader, num_shards, name_prefix):
+def convert(output_path,
+            reader,
+            num_shards,
+            name_prefix,
+            max_lines_to_shuffle=10000):
     import recordio
     import cPickle as pickle
+    import random
     """
     Convert data from reader to recordio format files.
 
@@ -161,35 +166,40 @@ def convert(output_path, eader, num_shards, name_prefix):
     :param reader: a data reader, from which the convert program will read data instances.
     :param num_shards: the number of shards that the dataset will be partitioned into.
     :param name_prefix: the name prefix of generated files.
+    :param max_lines_to_shuffle: the maximum number of lines to buffer and shuffle before writing.
     """
 
-    def open_needs(idx):
-        n = "%s/%s-%05d" % (output_path, name_prefix, idx)
-        w = recordio.writer(n)
-        f = open(n, "w")
-        idx += 1
+    assert num_shards >= 1
+    assert max_lines_to_shuffle >= 1
 
-        return w, f, idx
+    def open_writers():
+        w = []
+        for i in range(0, num_shards):
+            n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i,
+                                        num_shards - 1)
+            w.append(recordio.writer(n))
 
-    def close_needs(w, f):
-        if w is not None:
-            w.close()
+        return w
 
-        if f is not None:
-            f.close()
+    def close_writers(w):
+        for i in range(0, num_shards):
+            w[i].close()
 
-    idx = 0
-    w = None
-    f = None
+    def write_data(w, lines):
+        random.shuffle(lines)
+        for i, d in enumerate(lines):
+            d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL)
+            w[i % num_shards].write(d)
 
-    for i, d in enumerate(reader()):
-        if w is None:
-            w, f, idx = open_needs(idx)
-
-        w.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL))
+    w = open_writers()
+    lines = []
 
-        if i % num_shards == 0 and i >= num_shards:
-            close_needs(w, f)
-            w, f, idx = open_needs(idx)
+    for i, d in enumerate(reader()):
+        lines.append(d)
+        if len(lines) >= max_lines_to_shuffle:
+            write_data(w, lines)
+            lines = []
+            continue
 
-    close_needs(w, f)
+    write_data(w, lines)
+    close_writers(w)
diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py
index 3120026e1e..cfa194eba3 100644
--- a/python/paddle/v2/dataset/tests/common_test.py
+++ b/python/paddle/v2/dataset/tests/common_test.py
@@ -58,20 +58,36 @@ class TestCommon(unittest.TestCase):
             self.assertEqual(e, str("0"))
 
     def test_convert(self):
+        record_num = 10
+        num_shards = 4
+
         def test_reader():
             def reader():
-                for x in xrange(10):
+                for x in xrange(record_num):
                     yield x
 
             return reader
 
         path = tempfile.mkdtemp()
-
         paddle.v2.dataset.common.convert(path,
-                                         test_reader(), 4, 'random_images')
+                                         test_reader(), num_shards,
+                                         'random_images')
 
-        files = glob.glob(temp_path + '/random_images-*')
-        self.assertEqual(len(files), 3)
+        files = glob.glob(path + '/random_images-*')
+        self.assertEqual(len(files), num_shards)
+
+        recs = []
+        for i in range(0, num_shards):
+            n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
+            r = recordio.reader(n)
+            while True:
+                d = r.read()
+                if d is None:
+                    break
+                recs.append(d)
+
+        recs.sort()
+        self.assertEqual(len(recs), record_num)
 
 
 if __name__ == '__main__':