|
|
|
@ -25,45 +25,9 @@ limitations under the License. */
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include "picojson.h"
|
|
|
|
|
|
|
|
|
|
// Forward declaration; the definition follows main() in this file.
// Compares the `value` matrix, `ids` and `strs` members of `expect` against
// those of `actual` using gtest macros. Members that are unset on `expect`
// are not compared.
void checkEqual(const paddle::Argument& expect, const paddle::Argument& actual);
|
|
|
|
|
// Forward declaration; defined later in this file.
// NOTE(review): presumably validates the slots held in `arguments` against the
// expected data in the JSON array `arr` -- the full definition is outside the
// reviewed range, so confirm against it.
void checkValue(std::vector<paddle::Argument>& arguments, picojson::array& arr);
|
|
|
|
|
// Directory prefix for the fixture file lists used by the tests below. The path
// is relative, so it is resolved against the process's working directory.
const std::string kDir = "./trainer/tests/pydata_provider_wrapper_dir/";
|
|
|
|
|
|
|
|
|
|
// Loads one batch through the Python data provider ("py") and one through the
// proto data provider ("proto"), then checks that both produce the same
// argument streams.
TEST(PyDataProviderWrapper, NoSequenceData) {
  // Batch produced by the Python wrapper (module testPyDataWrapper, object
  // processNonSequenceData).
  paddle::DataConfig conf;
  conf.set_type("py");
  conf.set_load_data_module(std::string("testPyDataWrapper"));
  conf.set_load_data_object(std::string("processNonSequenceData"));
  conf.set_async_load_data(false);
  conf.clear_files();
  conf.set_files(kDir + "test_pydata_provider_wrapper.list");

  paddle::DataProviderPtr provider(paddle::DataProvider::create(conf, false));
  // setSkipShuffle() is called on both providers -- presumably so that the
  // sample order of the two batches is comparable.
  provider->setSkipShuffle();
  provider->reset();
  paddle::DataBatch batchFromPy;
  provider->getNextBatch(100, &batchFromPy);

  // Reference batch produced by the proto provider.
  paddle::DataConfig conf2;
  conf2.set_type("proto");
  conf2.set_async_load_data(false);
  conf2.clear_files();
  conf2.set_files(kDir + "test_pydata_provider_wrapper.protolist");

  provider.reset(paddle::DataProvider::create(conf2, false));
  provider->setSkipShuffle();
  provider->reset();
  paddle::DataBatch batchFromProto;
  provider->getNextBatch(100, &batchFromProto);

  std::vector<paddle::Argument>& pyArguments = batchFromPy.getStreams();
  std::vector<paddle::Argument>& protoArguments = batchFromProto.getStreams();

  // Fatal on mismatch: the loop below indexes both vectors with the same
  // index, so continuing with different sizes would read out of bounds.
  ASSERT_EQ(pyArguments.size(), protoArguments.size());

  for (size_t i = 0; i < pyArguments.size(); ++i) {
    checkEqual(protoArguments[i], pyArguments[i]);
  }
}
|
|
|
|
|
|
|
|
|
|
TEST(PyDataProviderWrapper, SequenceData) {
|
|
|
|
|
paddle::DataConfig conf;
|
|
|
|
|
conf.set_type("py");
|
|
|
|
@ -148,66 +112,6 @@ int main(int argc, char** argv) {
|
|
|
|
|
return RUN_ALL_TESTS();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void checkEqual(const paddle::Argument& expect,
|
|
|
|
|
const paddle::Argument& actual) {
|
|
|
|
|
if (expect.value) {
|
|
|
|
|
EXPECT_TRUE(actual.value != nullptr);
|
|
|
|
|
paddle::Matrix* e = expect.value.get();
|
|
|
|
|
paddle::Matrix* a = actual.value.get();
|
|
|
|
|
EXPECT_EQ(e->getWidth(), a->getWidth());
|
|
|
|
|
EXPECT_EQ(e->getHeight(), a->getHeight());
|
|
|
|
|
if (dynamic_cast<paddle::CpuSparseMatrix*>(e)) {
|
|
|
|
|
paddle::CpuSparseMatrix* se = dynamic_cast<paddle::CpuSparseMatrix*>(e);
|
|
|
|
|
paddle::CpuSparseMatrix* sa = dynamic_cast<paddle::CpuSparseMatrix*>(a);
|
|
|
|
|
EXPECT_EQ(se->getFormat(), sa->getFormat());
|
|
|
|
|
EXPECT_EQ(se->getElementCnt(), sa->getElementCnt());
|
|
|
|
|
size_t rowSize = se->getFormat() == paddle::SPARSE_CSC
|
|
|
|
|
? se->getElementCnt()
|
|
|
|
|
: se->getHeight() + 1;
|
|
|
|
|
size_t colSize = se->getFormat() == paddle::SPARSE_CSC
|
|
|
|
|
? se->getWidth() + 1
|
|
|
|
|
: se->getElementCnt();
|
|
|
|
|
for (size_t i = 0; i < rowSize; ++i) {
|
|
|
|
|
EXPECT_EQ(se->getRows()[i], sa->getRows()[i]);
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < colSize; ++i) {
|
|
|
|
|
EXPECT_EQ(se->getCols()[i], sa->getCols()[i]);
|
|
|
|
|
}
|
|
|
|
|
if (se->getValueType() == paddle::FLOAT_VALUE) {
|
|
|
|
|
EXPECT_EQ(paddle::FLOAT_VALUE, sa->getValueType());
|
|
|
|
|
for (size_t i = 0; i < se->getElementCnt(); ++i) {
|
|
|
|
|
EXPECT_EQ(se->getValue()[i], sa->getValue()[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (dynamic_cast<paddle::CpuMatrix*>(e)) {
|
|
|
|
|
EXPECT_EQ(e->getElementCnt(), a->getElementCnt());
|
|
|
|
|
for (size_t i = 0; i < e->getElementCnt(); ++i) {
|
|
|
|
|
EXPECT_EQ(e->getData()[i], a->getData()[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (expect.ids) {
|
|
|
|
|
EXPECT_TRUE(actual.ids != nullptr);
|
|
|
|
|
paddle::VectorT<int>* e = expect.ids.get();
|
|
|
|
|
paddle::VectorT<int>* a = actual.ids.get();
|
|
|
|
|
EXPECT_EQ(e->getSize(), a->getSize());
|
|
|
|
|
for (size_t i = 0; i < e->getSize(); ++i) {
|
|
|
|
|
EXPECT_EQ(e->getData()[i], a->getData()[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (expect.strs) {
|
|
|
|
|
EXPECT_TRUE(actual.strs != nullptr);
|
|
|
|
|
std::vector<std::string>* e = expect.strs.get();
|
|
|
|
|
std::vector<std::string>* a = actual.strs.get();
|
|
|
|
|
EXPECT_EQ(e->size(), a->size());
|
|
|
|
|
for (size_t i = 0; i < e->size(); ++i) {
|
|
|
|
|
EXPECT_EQ((*e)[i], (*a)[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void checkValue(std::vector<paddle::Argument>& arguments,
|
|
|
|
|
picojson::array& arr) {
|
|
|
|
|
// CHECK SLOT 0, Sparse Value.
|
|
|
|
|