fix code style & fix register bug & add release_memory

test=develop
revert-16555-model_data_cryption_link_all_lib
xjqbest 6 years ago committed by dongdaxiang
parent a0b59773af
commit be74de2c61

@ -83,10 +83,10 @@ class BlockingQueue {
return rc; return rc;
} }
void Pop(T &t) { void Pop(T *t) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); }); cv_.wait(lock, [=] { return !q_.empty(); });
t = std::move(q_.front()); *t = std::move(q_.front());
q_.pop_front(); q_.pop_front();
} }

@ -48,7 +48,7 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
return false; return false;
} }
*/ */
//PADDLE_ENFORCE(files.size(), "You have set an empty filelist."); // PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
filelist_.assign(files.begin(), files.end()); filelist_.assign(files.begin(), files.end());
finish_set_filelist_ = true; finish_set_filelist_ = true;
@ -190,7 +190,8 @@ int InMemoryDataFeed<T>::Next() {
if (in_channel->Size() == 0) { if (in_channel->Size() == 0) {
break; break;
} }
in_channel->Pop(instance); in_channel->Pop(&instance);
AddInstanceToInsVec(&ins_vec, instance, index++); AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance)); out_channel->Push(std::move(instance));
} }
@ -268,17 +269,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
} }
CHECK(channel != nullptr); CHECK(channel != nullptr);
CHECK(pre_channel != nullptr); CHECK(pre_channel != nullptr);
CHECK(pre_channel->Size() == 0); CHECK_EQ(pre_channel->Size(), 0);
local_vec.resize(channel->Size()); local_vec.resize(channel->Size());
for (int64_t i = 0; i < local_vec.size(); ++i) { for (int64_t i = 0; i < local_vec.size(); ++i) {
channel->Pop(local_vec[i]); channel->Pop(&local_vec[i]);
} }
VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_; VLOG(3) << "local_vec size=" << local_vec.size()
<<", thread_id=" << thread_id_;
{ {
std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_); std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_);
VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size() VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size()
<< ", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end()); memory_data_->insert(memory_data_->end(), local_vec.begin(),
local_vec.end());
VLOG(3) << "after insert memory_data_ size=" << memory_data_->size() VLOG(3) << "after insert memory_data_ size=" << memory_data_->size()
<< ", thread_id=" << thread_id_; << ", thread_id=" << thread_id_;
} }
@ -574,7 +577,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
const char* str = reader.get(); const char* str = reader.get();
std::string line = std::string(str); std::string line = std::string(str);
//VLOG(3) << line; // VLOG(3) << line;
char* endptr = const_cast<char*>(str); char* endptr = const_cast<char*>(str);
int pos = 0; int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
@ -750,7 +753,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
const char* str = reader.get(); const char* str = reader.get();
std::string line = std::string(str); std::string line = std::string(str);
//VLOG(3) << line; // VLOG(3) << line;
char* endptr = const_cast<char*>(str); char* endptr = const_cast<char*>(str);
int pos = 0; int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {

@ -21,7 +21,8 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <future> #include <future> // NOLINT
#include <utility>
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"

@ -82,6 +82,18 @@ DatasetImpl<T>::GetReaders() {
return readers_; return readers_;
} }
// if sent message between workers, should first call this function
template <typename T>
void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
VLOG(3) << "RegisterClientToClientMsgHandler done";
}
// load data into memory, Dataset hold this memory, // load data into memory, Dataset hold this memory,
// which will later be fed into readers' channel // which will later be fed into readers' channel
template <typename T> template <typename T>
@ -106,6 +118,14 @@ void DatasetImpl<T>::LoadIntoMemory() {
<< ", cost time=" << timeline.ElapsedSec() << " seconds"; << ", cost time=" << timeline.ElapsedSec() << " seconds";
} }
// release memory data
template <typename T>
void DatasetImpl<T>::ReleaseMemory() {
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
}
// do local shuffle // do local shuffle
template <typename T> template <typename T>
void DatasetImpl<T>::LocalShuffle() { void DatasetImpl<T>::LocalShuffle() {
@ -137,12 +157,6 @@ void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin"; VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
platform::Timer timeline; platform::Timer timeline;
timeline.Start(); timeline.Start();
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
if (readers_.size() == 0) { if (readers_.size() == 0) {
CreateReaders(); CreateReaders();
} }

@ -40,22 +40,43 @@ class Dataset {
public: public:
Dataset() {} Dataset() {}
virtual ~Dataset() {} virtual ~Dataset() {}
// set file list
virtual void SetFileList(const std::vector<std::string>& filelist) = 0; virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num
virtual void SetThreadNum(int thread_num) = 0; virtual void SetThreadNum(int thread_num) = 0;
// set workers' num
virtual void SetTrainerNum(int trainer_num) = 0; virtual void SetTrainerNum(int trainer_num) = 0;
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name, virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0; const std::string& fs_ugi) = 0;
// set data fedd desc, which contains:
// data feed name, batch size, slots
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// get file list
virtual const std::vector<std::string>& GetFileList() = 0; virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
virtual int GetThreadNum() = 0; virtual int GetThreadNum() = 0;
// get worker num
virtual int GetTrainerNum() = 0; virtual int GetTrainerNum() = 0;
// get data fedd desc
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0; virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
// get readers, the reader num depend both on thread num
// and filelist size
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>& virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders() = 0; GetReaders() = 0;
// register message handler between workers
virtual void RegisterClientToClientMsgHandler() = 0;
// load all data into memory
virtual void LoadIntoMemory() = 0; virtual void LoadIntoMemory() = 0;
// release all memory data
virtual void ReleaseMemory() = 0;
// local shuffle data
virtual void LocalShuffle() = 0; virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle() = 0; virtual void GlobalShuffle() = 0;
// create readers
virtual void CreateReaders() = 0; virtual void CreateReaders() = 0;
// destroy readers
virtual void DestroyReaders() = 0; virtual void DestroyReaders() = 0;
protected: protected:
@ -84,10 +105,12 @@ class DatasetImpl : public Dataset {
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() { virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_; return data_feed_desc_;
} }
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>& virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
GetReaders(); GetReaders();
virtual void RegisterClientToClientMsgHandler();
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
virtual void ReleaseMemory();
virtual void LocalShuffle(); virtual void LocalShuffle();
virtual void GlobalShuffle(); virtual void GlobalShuffle();
virtual void CreateReaders(); virtual void CreateReaders();

@ -23,6 +23,7 @@ limitations under the License. */
#endif #endif
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
@ -49,7 +50,6 @@ void BindAsyncExecutor(py::module* m) {
new framework::AsyncExecutor(scope, place)); new framework::AsyncExecutor(scope, place));
})) }))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile) .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
//.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.def("init_server", &framework::AsyncExecutor::InitServer) .def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker) .def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer) .def("start_server", &framework::AsyncExecutor::StartServer)

@ -52,7 +52,10 @@ void BindDataset(py::module* m) {
.def("set_trainer_num", &framework::Dataset::SetTrainerNum) .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig) .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("register_client2client_msg_handler",
&framework::Dataset::RegisterClientToClientMsgHandler)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory) .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("release_memory", &framework::Dataset::ReleaseMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle) .def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle); .def("global_shuffle", &framework::Dataset::GlobalShuffle);
} }

@ -237,7 +237,10 @@ class InMemoryDataset(DatasetBase):
if fleet is not None: if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker() fleet.fleet_instance.role_maker_.barrier_worker()
trainer_num = fleet.worker_num() trainer_num = fleet.worker_num()
self.dataset.register_client2client_msg_handler()
self.dataset.set_trainer_num(trainer_num) self.dataset.set_trainer_num(trainer_num)
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
self.dataset.global_shuffle() self.dataset.global_shuffle()
if fleet is not None: if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker() fleet.fleet_instance.role_maker_.barrier_worker()

Loading…
Cancel
Save