@@ -24,6 +24,27 @@ limitations under the License. */

 namespace paddle {

+namespace unittest {
+
+static std::unique_ptr<std::function<void(size_t /* poolActualSize */)>>
+    OnPoolFilled;
+
+namespace pydp2 {
+
+void setOnPoolFilledHook(const std::function<void(size_t)>& callback) {
+  OnPoolFilled.reset(new std::function<void(size_t)>());
+  *OnPoolFilled = callback;
+}
+
+void clearOnPoolFilledHook() {
+  OnPoolFilled.reset();
+}
+
+}  // namespace pydp2
+}  // namespace unittest
+
 /**
  * Slot type
  */
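The hook added above is essentially an optional callback slot: an empty unique_ptr means no test is listening, and production code only pays a null check. A minimal standalone sketch of the same pattern, with illustrative names (gOnFilled, fillPool) that are not part of the provider:

// Illustrative sketch, not part of this patch.
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>

static std::unique_ptr<std::function<void(size_t)>> gOnFilled;

void setHook(const std::function<void(size_t)>& cb) {
  gOnFilled.reset(new std::function<void(size_t)>(cb));
}

void clearHook() { gOnFilled.reset(); }

void fillPool(size_t n) {
  // ... buffer n samples ...
  if (gOnFilled) (*gOnFilled)(n);  // notify the unit test, if one registered
}

int main() {
  setHook([](size_t n) { std::cout << "pool filled: " << n << "\n"; });
  fillPool(42);  // hook fires
  clearHook();
  fillPool(7);   // silent
  return 0;
}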
@@ -179,6 +200,7 @@ public:
   * Ctor
   */
  PyDataProvider2(const DataConfig& config,
+                 const ModelConfig& modelConfig,
                  bool useGpu)
      :DataProvider(config, useGpu), callingContextCreated_(2) {
    auto& args = config.load_data_args();
@@ -192,6 +214,12 @@ public:

     py::DictHelper kwargsDict(kwargs);
     kwargsDict.setBool("is_train", !config.for_test());
+    std::vector<std::string> inputs;
+    inputs.reserve(modelConfig.input_layer_names().size());
+    std::copy(modelConfig.input_layer_names().begin(),
+              modelConfig.input_layer_names().end(),
+              std::back_inserter(inputs));
+    kwargsDict.setStringList("input_order", inputs);

     // kwargs is the keyword arguments used to create the Python data object.
     this->createPyDataObj(config.load_data_module(),
@@ -199,7 +227,7 @@ public:
                           config.files(),
                           std::move(kwargs));
     DBG << "Instance " << instance_.get() << " loaded.";
-    this->readPyFields();
+    this->readPyFields(config.for_test());
     DBG << "Py Field Done";
   }

@@ -253,14 +281,28 @@ private:
     CHECK_PY(instance_) << "Cannot Create instance";
   }

-  void readPyFields() {
+  void readPyFields(bool testing) {
     py::ObjectHelper self(this->instance_);
-    this->skipShuffle_ = !self.getBoolAttr("should_shuffle");
+    bool ok;
+
+    this->skipShuffle_ = !self.getBoolAttr("should_shuffle",
+                                           &ok /* isBoolType */);
+    if (!ok) {
+      this->skipShuffle_ = testing;  // shuffle when training, skip shuffle
+                                     // when testing.
+    }
     DBG << "Provider Skip Shuffle " << this->skipShuffle_;

+    this->poolSize_ = self.getIntAttr<size_t>("pool_size", &ok);
+    if (!ok) {
+      this->poolSize_ = -1UL;
+    }
+    this->minPoolSize_ = self.getIntAttr<size_t>("min_pool_size", &ok);
+    if (!ok) {
+      this->minPoolSize_ = -1UL;
+    }
+    this->minPoolSize_ = std::min(this->poolSize_, this->minPoolSize_);
+
     this->canOverBatchSize_ = self.getBoolAttr("can_over_batch_size");

     calcBatchSize_.reset(self.getAttr("calc_batch_size"));
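In the new readPyFields, pool_size and min_pool_size are optional Python attributes: anything the provider object leaves undefined falls back to -1UL (that is, SIZE_MAX, effectively no limit), and min_pool_size is then clamped so it can never exceed pool_size. A toy, self-contained illustration of that fallback-and-clamp logic; getIntAttr here is only a stand-in, and the values are made up:

// Illustrative sketch, not part of this patch.
#include <algorithm>
#include <cstddef>
#include <iostream>

// Stand-in for the real attribute reader: *ok reports whether the Python
// object actually defined the attribute.
static size_t getIntAttr(bool defined, size_t value, bool* ok) {
  *ok = defined;
  return defined ? value : 0;
}

int main() {
  bool ok;
  size_t poolSize = getIntAttr(true, 2000, &ok);   // pool_size = 2000
  if (!ok) poolSize = -1UL;
  size_t minPoolSize = getIntAttr(false, 0, &ok);  // min_pool_size undefined
  if (!ok) minPoolSize = -1UL;                     // unlimited ...
  minPoolSize = std::min(poolSize, minPoolSize);   // ... but capped by pool_size
  std::cout << "pool_size=" << poolSize
            << " min_pool_size=" << minPoolSize << "\n";
  return 0;
}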
@@ -307,7 +349,6 @@ private:
   }

   void loadThread() {
-    callingContexts_.reserve(fileLists_.size());
     DBG << "Creating context";
     for (auto& filename : fileLists_) {
       PyGuard g;
@@ -332,7 +373,14 @@ private:
         bool atEnd;
         data = py::iterNext(callingContexts_[cid], &atEnd);
         if (atEnd || data == nullptr) {
-          callingContexts_.erase(callingContexts_.begin() + cid);
+          if (cid != 0) {
+            std::swap(callingContexts_[cid], callingContexts_[0]);
+            cid = 0;
+          }
+          {
+            PyGuard g;
+            callingContexts_.pop_front();
+          }
           this->pullCV_.notify_all();
           continue;
         }
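Because callingContexts_ becomes a std::deque (see the member change further down), an exhausted generator is now retired by swapping it to the front and calling pop_front(), which is O(1), instead of erasing from the middle of a vector and shifting every later element. A standalone sketch of that retire-by-swap step; the unique_ptr<int> elements merely stand in for the per-file Python contexts:

// Illustrative sketch, not part of this patch.
#include <cstddef>
#include <deque>
#include <iostream>
#include <memory>
#include <utility>

int main() {
  std::deque<std::unique_ptr<int>> contexts;
  for (int i = 0; i < 4; ++i) contexts.emplace_back(new int(i));

  size_t cid = 2;  // the context that just reported end-of-data
  if (cid != 0) {
    std::swap(contexts[cid], contexts[0]);  // move the finished one to the front
    cid = 0;
  }
  contexts.pop_front();  // O(1) removal; order of the survivors may change

  for (auto& c : contexts) std::cout << *c << ' ';  // prints the 3 survivors
  std::cout << '\n';
  return 0;
}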
@@ -354,11 +402,7 @@ private:
       if (this->loadThread_) {  // wait until poolActualSize < poolSize
         std::unique_lock<std::mutex> l(mtx_);
         pushCV_.wait(l, [this, additionalBatchSize] {
-          if (this->canOverBatchSize_) {
-            return this->poolActualSize_ < poolSize_;
-          } else {
-            return this->poolActualSize_ + additionalBatchSize < poolSize_;
-          }
+          return this->poolActualSize_ < poolSize_;
         });
       }

@@ -402,7 +446,7 @@ private:
 private:
   std::unique_ptr<std::thread> loadThread_;
   std::atomic<bool> exit_;
-  std::vector<PyObjectPtr> callingContexts_;
+  std::deque<PyObjectPtr> callingContexts_;
   std::deque<PyObjectPtr> dataPool_;
   size_t poolActualSize_;
   std::condition_variable pushCV_;
@@ -413,6 +457,7 @@ private:

   PyObjectPtr instance_;
   size_t poolSize_;
+  size_t minPoolSize_;
   bool canOverBatchSize_;
   PyObjectPtr calcBatchSize_;
   PyObjectPtr generator_;
@@ -478,8 +523,13 @@ public:
       // data pool ready.
       std::unique_lock<std::mutex> l(mtx_);
       pullCV_.wait(l, [this, &size] {
-        return this->poolActualSize_ >= size || callingContexts_.empty();
+        return this->poolActualSize_ >= std::max(size, this->minPoolSize_)
+            || callingContexts_.empty();
       });
+
+      if (unittest::OnPoolFilled) {
+        (*unittest::OnPoolFilled)(this->poolActualSize_);
+      }
     }
     std::deque<PyObjectPtr> data;
     size_t bsize = 0;
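The reader side now blocks until the background thread has buffered at least max(size, min_pool_size) samples, or every generator is exhausted; that is what makes min_pool_size a lower bound on how full the shuffle pool must be before a batch is cut. A self-contained sketch of that consumer-side wait with a dummy producer; all names and numbers here are illustrative:

// Illustrative sketch, not part of this patch.
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>

int main() {
  std::mutex mtx;
  std::condition_variable pullCV;
  size_t poolActual = 0;
  bool producerDone = false;
  const size_t minPoolSize = 64, batchSize = 32;

  std::thread producer([&] {
    for (int i = 0; i < 100; ++i) {
      { std::lock_guard<std::mutex> g(mtx); ++poolActual; }
      pullCV.notify_all();
    }
    { std::lock_guard<std::mutex> g(mtx); producerDone = true; }
    pullCV.notify_all();
  });

  {
    std::unique_lock<std::mutex> l(mtx);
    // Wake only once enough samples are buffered, or the producer is done.
    pullCV.wait(l, [&] {
      return poolActual >= std::max(batchSize, minPoolSize) || producerDone;
    });
    std::cout << "woke with " << poolActual << " samples buffered\n";
  }
  producer.join();
  return 0;
}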
@@ -495,7 +545,8 @@ public:
     std::deque<PyObjectPtr>& pool = *poolPtr;

     while (bsize < size && !pool.empty()) {
-      {  // move data from pool to data
+      {
+        // move data from pool to data
         std::lock_guard<std::mutex> guard(mtx_);
         if (skipShuffle_) {
           size_t i = 0;
@@ -505,14 +556,13 @@ public:
         } else {  // when shuffling, use a swap so only one pool element moves.
           size_t i = ThreadLocalRand::rand() % pool.size();
           CHECK(pool[i] != nullptr);
-          if (i != pool.size() - 1) {
-            std::swap(pool[i], pool.back());
+          if (i != 0) {
+            std::swap(pool[i], pool.front());
           }
-          data.emplace_back(std::move(pool.back()));
-          pool.pop_back();
+          data.emplace_back(std::move(pool.front()));
+          pool.pop_front();
         }
       }
       {

         if (calcBatchSize_) {  // custom calc batch size.
           PyGuard guard;
           Py_INCREF(data.back().get());
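In the shuffled branch, a pseudo-random element is drawn by swapping it with the front of the deque and popping from the front, so exactly one element moves no matter where the pick landed (the previous code played the same trick against the back). A minimal sketch of that sampling step; std::rand() stands in for ThreadLocalRand::rand():

// Illustrative sketch, not part of this patch.
#include <cstddef>
#include <cstdlib>
#include <deque>
#include <iostream>
#include <utility>

int main() {
  std::deque<int> pool = {10, 20, 30, 40, 50};
  std::deque<int> batch;

  while (!pool.empty()) {
    size_t i = std::rand() % pool.size();      // pick a random survivor
    if (i != 0) std::swap(pool[i], pool.front());
    batch.push_back(pool.front());
    pool.pop_front();                          // O(1) on a deque
  }
  for (int v : batch) std::cout << v << ' ';   // a shuffled order of the pool
  std::cout << '\n';
  return 0;
}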
@@ -521,8 +571,17 @@ public:
           calcBatchSize.getArgs().set(0, data.back());
           PyObjectPtr customBatchSize(calcBatchSize());
           bool ok;
-          bsize += py::castInt<size_t>(customBatchSize.get(), &ok);
+          size_t tmp = py::castInt<size_t>(customBatchSize.get(), &ok);
           CHECK(ok) << "calc_batch_size must return int";
+
+          if (bsize + tmp > size && !canOverBatchSize_) {
+            // Put data back.
+            pool.push_front(std::move(data.back()));
+            data.pop_back();
+            break;
+          } else {
+            bsize += tmp;
+          }
         } else {
           bsize += 1;
         }
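When a user-supplied calc_batch_size is present, the cost of a sample is only known after it has been pulled from the pool, so the new code checks whether adding it would overshoot the requested batch size and, unless can_over_batch_size is set, pushes it back to the front of the pool for the next batch. A self-contained sketch of that accounting; the per-sample costs are hard-coded purely for illustration:

// Illustrative sketch, not part of this patch.
#include <cstddef>
#include <deque>
#include <iostream>

int main() {
  // Per-sample "sizes" as a calc_batch_size callback might report them.
  std::deque<size_t> pool = {3, 5, 2, 7, 1};
  std::deque<size_t> batch;
  const size_t size = 8;                // requested batch size
  const bool canOverBatchSize = false;
  size_t bsize = 0;

  while (bsize < size && !pool.empty()) {
    size_t cost = pool.front();
    pool.pop_front();
    if (bsize + cost > size && !canOverBatchSize) {
      pool.push_front(cost);            // put it back; it starts the next batch
      break;
    }
    batch.push_back(cost);
    bsize += cost;
  }
  std::cout << "batch carries " << bsize << " of the requested " << size << "\n";
  return 0;
}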
@@ -598,7 +657,6 @@ public:
     } else {
       *batch = cpuBatch;
     }
-
     return bsize;
   }
 };
@@ -606,7 +664,8 @@ public:

 std::unordered_set<uintptr_t> PyDataProvider2::gModuleClsPtrs_;
 PyObjectPtr PyDataProvider2::zeroTuple_(PyTuple_New(0));

-REGISTER_DATA_PROVIDER(py2, PyDataProvider2);
+REGISTER_DATA_PROVIDER_EX(py2, PyDataProvider2);

 /**
  * Scanner for dense slot.