|  |  | @ -15,16 +15,16 @@ limitations under the License. */ | 
			
		
	
		
		
			
				
					
					|  |  |  | #ifndef PADDLE_NO_PYTHON |  |  |  | #ifndef PADDLE_NO_PYTHON | 
			
		
	
		
		
			
				
					
					|  |  |  | #include <gtest/gtest.h> |  |  |  | #include <gtest/gtest.h> | 
			
		
	
		
		
			
				
					
					|  |  |  | #include <fstream> |  |  |  | #include <fstream> | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/utils/Util.h" |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/utils/PythonUtil.h" |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  | #include "paddle/gserver/dataproviders/DataProvider.h" |  |  |  | #include "paddle/gserver/dataproviders/DataProvider.h" | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | #include "paddle/utils/PythonUtil.h" | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | #include "paddle/utils/Util.h" | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | P_DEFINE_string(train_list, "unittest.list", "file list for unittest"); |  |  |  | P_DEFINE_string(train_list, "unittest.list", "file list for unittest"); | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | namespace paddle { |  |  |  | namespace paddle { | 
			
		
	
		
		
			
				
					
					|  |  |  | namespace unittest { |  |  |  | namespace unittest { | 
			
		
	
		
		
			
				
					
					|  |  |  | namespace pydp2 { |  |  |  | namespace pydp2 { | 
			
		
	
		
		
			
				
					
					|  |  |  | extern void setOnPoolFilledHook(const std::function<void(size_t)>& func); |  |  |  | extern void setOnPoolFilledHook(const std::function<void(size_t)> &func); | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  | extern void clearOnPoolFilledHook(); |  |  |  | extern void clearOnPoolFilledHook(); | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | }  // namespace pydp2
 |  |  |  | }  // namespace pydp2
 | 
			
		
	
	
		
		
			
				
					|  |  | @ -33,8 +33,8 @@ extern void clearOnPoolFilledHook(); | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | const paddle::real epsilon = 1e-5; |  |  |  | const paddle::real epsilon = 1e-5; | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | static inline int64_t readDataBatch(paddle::DataBatch* batch, |  |  |  | static inline int64_t readDataBatch(paddle::DataBatch *batch, | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                     const std::string& funcName, |  |  |  |                                     const std::string &funcName, | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                                     int64_t batchSize = 65535) { |  |  |  |                                     int64_t batchSize = 65535) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   paddle::DataConfig config; |  |  |  |   paddle::DataConfig config; | 
			
		
	
		
		
			
				
					
					|  |  |  |   config.set_type("py2"); |  |  |  |   config.set_type("py2"); | 
			
		
	
	
		
		
			
				
					|  |  | @ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   paddle::DataBatch batch; |  |  |  |   paddle::DataBatch batch; | 
			
		
	
		
		
			
				
					
					|  |  |  |   int64_t num = provider->getNextBatchInternal(100000, &batch); |  |  |  |   int64_t num = provider->getNextBatchInternal(100000, &batch); | 
			
		
	
		
		
			
				
					
					|  |  |  |   ASSERT_EQ(num, 200); |  |  |  |   ASSERT_EQ(num, 200); | 
			
		
	
		
		
			
				
					
					|  |  |  |   auto& mat = batch.getStreams()[0].value; |  |  |  |   auto &mat = batch.getStreams()[0].value; | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |   ASSERT_EQ((size_t)mat->getWidth(), (size_t)20); |  |  |  |   ASSERT_EQ((size_t)mat->getWidth(), (size_t)20); | 
			
		
	
		
		
			
				
					
					|  |  |  |   for (size_t i = 0; i < 200; ++i) { |  |  |  |   for (size_t i = 0; i < 200; ++i) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     for (size_t j = 0; j < 20; ++j) { |  |  |  |     for (size_t j = 0; j < 20; ++j) { | 
			
		
	
	
		
		
			
				
					|  |  | @ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   CHECK(csm != nullptr); |  |  |  |   CHECK(csm != nullptr); | 
			
		
	
		
		
			
				
					
					|  |  |  |   for (int i = 0; i < 200; ++i) { |  |  |  |   for (int i = 0; i < 200; ++i) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     CHECK_EQ(csm->getColNum(i), (size_t)10); |  |  |  |     CHECK_EQ(csm->getColNum(i), (size_t)10); | 
			
		
	
		
		
			
				
					
					|  |  |  |     int* cols = csm->getRowCols(i); |  |  |  |     int *cols = csm->getRowCols(i); | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     for (int j = 0; j < 10; ++j) { |  |  |  |     for (int j = 0; j < 10; ++j) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       CHECK_EQ(cols[j], (i + 1) * (j + 1)); |  |  |  |       CHECK_EQ(cols[j], (i + 1) * (j + 1)); | 
			
		
	
		
		
			
				
					
					|  |  |  |     } |  |  |  |     } | 
			
		
	
	
		
		
			
				
					|  |  | @ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   CHECK(csm != nullptr); |  |  |  |   CHECK(csm != nullptr); | 
			
		
	
		
		
			
				
					
					|  |  |  |   for (int i = 0; i < 200; ++i) { |  |  |  |   for (int i = 0; i < 200; ++i) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     CHECK_EQ(csm->getColNum(i), (size_t)10); |  |  |  |     CHECK_EQ(csm->getColNum(i), (size_t)10); | 
			
		
	
		
		
			
				
					
					|  |  |  |     int* cols = csm->getRowCols(i); |  |  |  |     int *cols = csm->getRowCols(i); | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     real* dat = csm->getRowValues(i); |  |  |  |     real *dat = csm->getRowValues(i); | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     for (int j = 0; j < 10; ++j) { |  |  |  |     for (int j = 0; j < 10; ++j) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       EXPECT_EQ(cols[j], (i + 1) * (j + 1)); |  |  |  |       EXPECT_EQ(cols[j], (i + 1) * (j + 1)); | 
			
		
	
		
		
			
				
					
					|  |  |  |       EXPECT_EQ(dat[j], real(j) / real(i + 1)); |  |  |  |       EXPECT_EQ(dat[j], real(j) / real(i + 1)); | 
			
		
	
	
		
		
			
				
					|  |  | @ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) { | 
			
		
	
		
		
			
				
					
					|  |  |  | TEST(PyDataProvider2, index_seq) { |  |  |  | TEST(PyDataProvider2, index_seq) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   paddle::DataBatch batch; |  |  |  |   paddle::DataBatch batch; | 
			
		
	
		
		
			
				
					
					|  |  |  |   CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200); |  |  |  |   CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200); | 
			
		
	
		
		
			
				
					
					|  |  |  |   auto& arg = batch.getStreams()[0]; |  |  |  |   auto &arg = batch.getStreams()[0]; | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |   CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2); |  |  |  |   CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2); | 
			
		
	
		
		
			
				
					
					|  |  |  |   size_t tmp = 0; |  |  |  |   size_t tmp = 0; | 
			
		
	
		
		
			
				
					
					|  |  |  |   for (size_t i = 0; i < 200; ++i) {  // CHECK DATA CORRECT
 |  |  |  |   for (size_t i = 0; i < 200; ++i) {  // CHECK DATA CORRECT
 | 
			
		
	
	
		
		
			
				
					|  |  | @ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) { | 
			
		
	
		
		
			
				
					
					|  |  |  | TEST(PyDataProvider2, index_sub_seq) { |  |  |  | TEST(PyDataProvider2, index_sub_seq) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   paddle::DataBatch batch; |  |  |  |   paddle::DataBatch batch; | 
			
		
	
		
		
			
				
					
					|  |  |  |   ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200); |  |  |  |   ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200); | 
			
		
	
		
		
			
				
					
					|  |  |  |   auto& arg = batch.getStreams()[0]; |  |  |  |   auto &arg = batch.getStreams()[0]; | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |   size_t tmp = 0; |  |  |  |   size_t tmp = 0; | 
			
		
	
		
		
			
				
					
					|  |  |  |   for (size_t i = 0; i < 200; ++i) { |  |  |  |   for (size_t i = 0; i < 200; ++i) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     for (size_t j = 0; j < i + 1; ++j) { |  |  |  |     for (size_t j = 0; j < i + 1; ++j) { | 
			
		
	
	
		
		
			
				
					|  |  | @ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     } |  |  |  |     } | 
			
		
	
		
		
			
				
					
					|  |  |  |   }); |  |  |  |   }); | 
			
		
	
		
		
			
				
					
					|  |  |  |   while (true) { |  |  |  |   while (true) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); |  |  |  |     int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     if (realBatchSize) { |  |  |  |     if (realBatchSize) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       totalData -= realBatchSize; |  |  |  |       totalData -= realBatchSize; | 
			
		
	
		
		
			
				
					
					|  |  |  |     } else { |  |  |  |     } else { | 
			
		
	
	
		
		
			
				
					|  |  | @ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   provider->reset(); |  |  |  |   provider->reset(); | 
			
		
	
		
		
			
				
					
					|  |  |  |   constexpr size_t batchSize = 100; |  |  |  |   constexpr size_t batchSize = 100; | 
			
		
	
		
		
			
				
					
					|  |  |  |   while (true) { |  |  |  |   while (true) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); |  |  |  |     int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     if (realBatchSize) { |  |  |  |     if (realBatchSize) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       CHECK_LE(realBatchSize, batchSize); |  |  |  |       CHECK_LE(realBatchSize, batchSize); | 
			
		
	
		
		
			
				
					
					|  |  |  |     } else { |  |  |  |     } else { | 
			
		
	
	
		
		
			
				
					|  |  | @ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   provider->reset(); |  |  |  |   provider->reset(); | 
			
		
	
		
		
			
				
					
					|  |  |  |   constexpr size_t batchSize = 100; |  |  |  |   constexpr size_t batchSize = 100; | 
			
		
	
		
		
			
				
					
					|  |  |  |   while (true) { |  |  |  |   while (true) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); |  |  |  |     int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     if (!realBatchSize) { |  |  |  |     if (!realBatchSize) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       break; |  |  |  |       break; | 
			
		
	
		
		
			
				
					
					|  |  |  |     } |  |  |  |     } | 
			
		
	
		
		
			
				
					
					|  |  |  |     ASSERT_EQ(batch.getStreams().size(), (size_t)2); |  |  |  |     ASSERT_EQ(batch.getStreams().size(), static_cast<size_t>(2)); | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |     for (size_t i = 0; i < realBatchSize; ++i) { |  |  |  |     for (int64_t i = 0; i < realBatchSize; ++i) { | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |       ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0); |  |  |  |       ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0); | 
			
		
	
		
		
			
				
					
					|  |  |  |       ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1); |  |  |  |       ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1); | 
			
		
	
		
		
			
				
					
					|  |  |  |     } |  |  |  |     } | 
			
		
	
	
		
		
			
				
					|  |  | @ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       paddle::DataProvider::create(config, false)); |  |  |  |       paddle::DataProvider::create(config, false)); | 
			
		
	
		
		
			
				
					
					|  |  |  |   provider->reset(); |  |  |  |   provider->reset(); | 
			
		
	
		
		
			
				
					
					|  |  |  |   while (true) { |  |  |  |   while (true) { | 
			
		
	
		
		
			
				
					
					|  |  |  |     size_t realBatchSize = provider->getNextBatchInternal(100, &batch); |  |  |  |     int64_t realBatchSize = provider->getNextBatchInternal(100, &batch); | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     if (!realBatchSize) { |  |  |  |     if (!realBatchSize) { | 
			
		
	
		
		
			
				
					
					|  |  |  |       break; |  |  |  |       break; | 
			
		
	
		
		
			
				
					
					|  |  |  |     } else { |  |  |  |     } else { | 
			
		
	
		
		
			
				
					
					|  |  |  |       auto& ivec = batch.getStream(0).ids; |  |  |  |       auto &ivec = batch.getStream(0).ids; | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |       for (size_t i = 0; i < ivec->getSize(); ++i) { |  |  |  |       for (size_t i = 0; i < ivec->getSize(); ++i) { | 
			
		
	
		
		
			
				
					
					|  |  |  |         CHECK_LT(ivec->getData()[i], 10); |  |  |  |         CHECK_LT(ivec->getData()[i], 10); | 
			
		
	
		
		
			
				
					
					|  |  |  |       } |  |  |  |       } | 
			
		
	
	
		
		
			
				
					|  |  | @ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   provider.reset(); |  |  |  |   provider.reset(); | 
			
		
	
		
		
			
				
					
					|  |  |  | } |  |  |  | } | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | int main(int argc, char** argv) { |  |  |  | TEST(PyDataProvider2, minPoolSizeWithCache) { | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   paddle::DataConfig config; | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   config.set_type("py2"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   config.set_files(FLAGS_train_list.c_str()); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   config.set_load_data_module("test_PyDataProvider2"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   config.set_load_data_object("test_min_pool_size_with_cache"); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   config.set_async_load_data(true); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   std::unique_ptr<paddle::DataProvider> provider( | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       paddle::DataProvider::create(config, false)); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   paddle::DataBatch batch; | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   for (int i = 0; i < 10; ++i) { | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     provider->reset(); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     int64_t sum = 0; | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     while (int64_t actualNum = provider->getNextBatch(100, &batch)) { | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |       sum += actualNum; | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     } | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     ASSERT_EQ(1 << 20, sum); | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |   } | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | } | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | int main(int argc, char **argv) { | 
			
		
	
		
		
			
				
					
					|  |  |  |   testing::InitGoogleTest(&argc, argv); |  |  |  |   testing::InitGoogleTest(&argc, argv); | 
			
		
	
		
		
			
				
					
					|  |  |  |   paddle::initMain(argc, argv); |  |  |  |   paddle::initMain(argc, argv); | 
			
		
	
		
		
			
				
					
					|  |  |  |   paddle::initPython(argc, argv); |  |  |  |   paddle::initPython(argc, argv); | 
			
		
	
	
		
		
			
				
					|  |  | 
 |