@@ -15,16 +15,16 @@ limitations under the License. */
 #ifndef PADDLE_NO_PYTHON
 #include <gtest/gtest.h>
 #include <fstream>
-#include "paddle/utils/Util.h"
-#include "paddle/utils/PythonUtil.h"
 #include "paddle/gserver/dataproviders/DataProvider.h"
+#include "paddle/utils/PythonUtil.h"
+#include "paddle/utils/Util.h"

 P_DEFINE_string(train_list, "unittest.list", "file list for unittest");

 namespace paddle {
 namespace unittest {
 namespace pydp2 {
-extern void setOnPoolFilledHook(const std::function<void(size_t)>& func);
+extern void setOnPoolFilledHook(const std::function<void(size_t)> &func);
 extern void clearOnPoolFilledHook();

 } // namespace pydp2
@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();

 const paddle::real epsilon = 1e-5;

-static inline int64_t readDataBatch(paddle::DataBatch* batch,
-                                    const std::string& funcName,
+static inline int64_t readDataBatch(paddle::DataBatch *batch,
+                                    const std::string &funcName,
                                     int64_t batchSize = 65535) {
   paddle::DataConfig config;
   config.set_type("py2");
@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
   paddle::DataBatch batch;
   int64_t num = provider->getNextBatchInternal(100000, &batch);
   ASSERT_EQ(num, 200);
-  auto& mat = batch.getStreams()[0].value;
+  auto &mat = batch.getStreams()[0].value;
   ASSERT_EQ((size_t)mat->getWidth(), (size_t)20);
   for (size_t i = 0; i < 200; ++i) {
     for (size_t j = 0; j < 20; ++j) {
@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
   CHECK(csm != nullptr);
   for (int i = 0; i < 200; ++i) {
     CHECK_EQ(csm->getColNum(i), (size_t)10);
-    int* cols = csm->getRowCols(i);
+    int *cols = csm->getRowCols(i);
     for (int j = 0; j < 10; ++j) {
       CHECK_EQ(cols[j], (i + 1) * (j + 1));
     }
@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
   CHECK(csm != nullptr);
   for (int i = 0; i < 200; ++i) {
     CHECK_EQ(csm->getColNum(i), (size_t)10);
-    int* cols = csm->getRowCols(i);
-    real* dat = csm->getRowValues(i);
+    int *cols = csm->getRowCols(i);
+    real *dat = csm->getRowValues(i);
     for (int j = 0; j < 10; ++j) {
       EXPECT_EQ(cols[j], (i + 1) * (j + 1));
       EXPECT_EQ(dat[j], real(j) / real(i + 1));
@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
 TEST(PyDataProvider2, index_seq) {
   paddle::DataBatch batch;
   CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200);
-  auto& arg = batch.getStreams()[0];
+  auto &arg = batch.getStreams()[0];
   CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2);
   size_t tmp = 0;
   for (size_t i = 0; i < 200; ++i) { // CHECK DATA CORRECT
@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
 TEST(PyDataProvider2, index_sub_seq) {
   paddle::DataBatch batch;
   ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200);
-  auto& arg = batch.getStreams()[0];
+  auto &arg = batch.getStreams()[0];
   size_t tmp = 0;
   for (size_t i = 0; i < 200; ++i) {
     for (size_t j = 0; j < i + 1; ++j) {
@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
     }
   });
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
     if (realBatchSize) {
       totalData -= realBatchSize;
     } else {
@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
   provider->reset();
   constexpr size_t batchSize = 100;
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
     if (realBatchSize) {
       CHECK_LE(realBatchSize, batchSize);
     } else {
@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
   provider->reset();
   constexpr size_t batchSize = 100;
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
     if (!realBatchSize) {
       break;
     }
-    ASSERT_EQ(batch.getStreams().size(), (size_t)2);
-    for (size_t i = 0; i < realBatchSize; ++i) {
+    ASSERT_EQ(batch.getStreams().size(), static_cast<size_t>(2));
+    for (int64_t i = 0; i < realBatchSize; ++i) {
       ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0);
       ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1);
     }
@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
       paddle::DataProvider::create(config, false));
   provider->reset();
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(100, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(100, &batch);
     if (!realBatchSize) {
       break;
     } else {
-      auto& ivec = batch.getStream(0).ids;
+      auto &ivec = batch.getStream(0).ids;
       for (size_t i = 0; i < ivec->getSize(); ++i) {
         CHECK_LT(ivec->getData()[i], 10);
       }
@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
   provider.reset();
 }

-int main(int argc, char** argv) {
+TEST(PyDataProvider2, minPoolSizeWithCache) {
+  paddle::DataConfig config;
+  config.set_type("py2");
+  config.set_files(FLAGS_train_list.c_str());
+  config.set_load_data_module("test_PyDataProvider2");
+  config.set_load_data_object("test_min_pool_size_with_cache");
+  config.set_async_load_data(true);
+
+  std::unique_ptr<paddle::DataProvider> provider(
+      paddle::DataProvider::create(config, false));
+
+  paddle::DataBatch batch;
+
+  for (int i = 0; i < 10; ++i) {
+    provider->reset();
+    int64_t sum = 0;
+    while (int64_t actualNum = provider->getNextBatch(100, &batch)) {
+      sum += actualNum;
+    }
+    ASSERT_EQ(1 << 20, sum);
+  }
+}
+
+int main(int argc, char **argv) {
   testing::InitGoogleTest(&argc, argv);
   paddle::initMain(argc, argv);
   paddle::initPython(argc, argv);