add DataSet and InMemoryDataFeed, support loading data into memory and shuffling data

revert-16555-model_data_cryption_link_all_lib
xjqbest 6 years ago committed by dongdaxiang
parent 08c25995a2
commit 824b84d185

@ -199,6 +199,7 @@ if(WITH_PSLIB)
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
@ -208,6 +209,7 @@ else()
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto

@ -154,5 +154,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return;
}
// TODO: implement RunFromDataset
void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
Dataset* data_set,
const std::string& trainer_desc_str,
const bool debug) {
}
} // end namespace framework
} // end namespace paddle
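The new RunFromDataset entry point above is still an empty stub. As a hedged sketch only — assuming it will mirror RunFromFile and use the two-argument Trainer::Initialize introduced later in this diff — the body might eventually look like this (every call below is an assumption based on the surrounding hunks, not part of this commit):

// Hypothetical sketch (not in this commit) of RunFromDataset:
void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
                                   Dataset* data_set,
                                   const std::string& trainer_desc_str,
                                   const bool debug) {
  TrainerDesc trainer_desc;
  google::protobuf::TextFormat::ParseFromString(trainer_desc_str, &trainer_desc);
  std::shared_ptr<TrainerBase> trainer =
      TrainerFactory::CreateTrainer(trainer_desc.class_name());
  trainer->Initialize(trainer_desc, data_set);  // new two-argument Initialize
  trainer->SetScope(root_scope_);
  trainer->SetDebug(debug);
  trainer->InitTrainerEnv(main_program, place_);
  trainer->InitOtherEnv(main_program);
  trainer->Run();
  trainer->Finalize();
}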

@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {

@ -33,6 +33,14 @@ class BlockingQueue {
cv_.notify_one();
}
void Push(T &&item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(std::move(item));
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{
@ -44,6 +52,17 @@ class BlockingQueue {
cv_.notify_all();
}
template <typename U>
void Extend(U &&items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(std::move(item));
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
@ -64,6 +83,18 @@ class BlockingQueue {
return rc;
}
void Pop(T &t) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
t = std::move(q_.front());
q_.pop_front();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return q_.size();
}
private:
std::mutex mutex_;
std::condition_variable cv_;
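The added rvalue Push/Extend overloads, the blocking Pop, and Size give the queue a conventional producer/consumer interface. A minimal, self-contained usage sketch (illustrative only, not part of the diff):

#include <string>
#include <thread>
#include <vector>
// Uses paddle::framework::BlockingQueue<T> as declared above.
void Example(paddle::framework::BlockingQueue<std::string>* q) {
  std::thread producer([q] {
    std::vector<std::string> batch = {"ins_a", "ins_b"};
    q->Extend(std::move(batch));    // rvalue Extend: moves every element in
    q->Push(std::string("ins_c"));  // rvalue Push: moves a single item
  });
  std::string item;
  for (int i = 0; i < 3; ++i) {
    q->Pop(item);  // blocks on cv_ until the queue is non-empty
  }
  producer.join();
}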

File diff suppressed because it is too large

@ -27,6 +27,8 @@ limitations under the License. */
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
namespace paddle {
namespace framework {
@ -76,6 +78,19 @@ class DataFeed {
// This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name);
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle(int trainer_num) {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
virtual void PutInsToChannel(const std::string& ins_str) {
PADDLE_THROW("This function(PutToChannel) is not implemented.");
}
protected:
// The following three functions are used to check if it is executed in this
// order:
@ -161,6 +176,35 @@ class PrivateQueueDataFeed : public DataFeed {
std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_;
};
template <typename T>
class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual bool Start();
virtual int Next();
virtual void PutInsToChannel(const std::string& ins_str);
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// TODO: support global shuffle
// virtual void GlobalShuffle(int trainer_num);
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const T& ins, std::string& str) = 0;
virtual void DeserializeIns(T& ins, const std::string& str) = 0;
std::vector<T> memory_data_;
// When reading instances, we move them from one channel to the other;
// when reading finishes, we set cur_channel_ = 1 - cur_channel_,
// so if cur_channel_ == 0, all data is in shuffled_ins_, else in shuffled_ins_out_
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
};
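The implementations of these members live in the suppressed data_feed.cc diff above. As a rough sketch of the double-buffering the comment describes — assuming default_batch_size_ from the base DataFeed; names and details may differ from the real file:

// Illustrative only: consume the current channel, refill the other,
// and swap roles once the pass is exhausted.
template <typename T>
int InMemoryDataFeed<T>::Next() {
  auto in_channel  = (cur_channel_ == 0) ? shuffled_ins_ : shuffled_ins_out_;
  auto out_channel = (cur_channel_ == 0) ? shuffled_ins_out_ : shuffled_ins_;
  T instance;
  T ins_vec;
  int index = 0;
  while (index < this->default_batch_size_) {
    if (in_channel->Size() == 0) {
      break;  // current pass exhausted
    }
    in_channel->Pop(instance);
    AddInstanceToInsVec(&ins_vec, instance, index++);
    out_channel->Push(std::move(instance));  // keep the data for the next pass
  }
  if (in_channel->Size() == 0) {
    cur_channel_ = 1 - cur_channel_;  // swap channels for the next epoch
  }
  if (index > 0) {
    PutToFeedVec(ins_vec);
  }
  return index;
}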
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType {
public:
@ -245,5 +289,23 @@ class MultiSlotDataFeed
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
class MultiSlotInMemoryDataFeed
: public InMemoryDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str);
virtual void DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str);
};
} // namespace framework
} // namespace paddle

@ -60,5 +60,6 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
}
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
} // namespace framework
} // namespace paddle

@ -0,0 +1,128 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
namespace paddle {
namespace framework {
Dataset::Dataset() {
thread_num_ = 1;
}
void Dataset::SetFileList(const std::vector<std::string>& filelist) {
filelist_ = filelist;
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
}
}
void Dataset::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}
thread_num_ = thread_num;
}
void Dataset::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
void Dataset::SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc) {
data_feed_desc_ = data_feed_desc;
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>> Dataset::GetReaders() {
return readers_;
}
void Dataset::LoadIntoMemory() {
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(&paddle::framework::DataFeed::LoadIntoMemory,
readers_[i].get()));
}
for (std::thread& t : load_threads) {
t.join();
}
}
void Dataset::LocalShuffle() {
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::LocalShuffle,
readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
}
}
// TODO: implement global shuffle
void Dataset::GlobalShuffle() {
/*
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->registe_client2client_msg_handler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> global_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle,
readers_[i].get(), trainer_num_));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}*/
}
void Dataset::CreateReaders() {
CHECK(thread_num_ > 0) << "thread_num should be > 0";
if (readers_.size() != 0) {
return;
}
for (int64_t i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
}
readers_[0]->SetFileList(filelist_);
}
int Dataset::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) {
// TODO: route to a reader by hash instead of always using reader 0, e.g.
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t index = 0;
readers_[index]->PutInsToChannel(msg);
return 0;
}
} // end namespace framework
} // end namespace paddle
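Taken together, a typical C++ driver for this file might look as follows (illustrative sketch; the descriptor text and helper function name are made up):

#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_set.h"
void PrepareData(const std::string& data_feed_desc_text,
                 const std::vector<std::string>& files) {
  paddle::framework::DataFeedDesc desc;
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_text, &desc);
  paddle::framework::Dataset dataset;
  dataset.SetDataFeedDesc(desc);  // e.g. name: "MultiSlotInMemoryDataFeed"
  dataset.SetFileList(files);     // may clamp thread_num_ to the file count
  dataset.SetThreadNum(10);       // likewise clamped if files are fewer
  dataset.LoadIntoMemory();       // one loader thread per reader
  dataset.LocalShuffle();         // one shuffle thread per reader
}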

@ -0,0 +1,70 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
class Dataset {
public:
Dataset();
virtual ~Dataset() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual const std::vector<std::string>& GetFileList() {
return filelist_;
}
virtual int GetThreadNum() {
return thread_num_;
}
virtual int GetTrainerNum() {
return trainer_num_;
}
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>> GetReaders();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// TODO: support global shuffle
virtual void GlobalShuffle();
virtual void CreateReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
int thread_num_;
std::string fs_name_;
std::string fs_ugi_;
paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_;
int trainer_num_;
};
} // end namespace framework
} // end namespace paddle
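Because every member is declared virtual, the class is open to specialization; a tiny hypothetical subclass (not in this commit) illustrates the extension point:

// Hypothetical: keep load order for deterministic debugging runs.
class NoShuffleDataset : public paddle::framework::Dataset {
 public:
  void LocalShuffle() override {}   // skip per-reader shuffling
  void GlobalShuffle() override {}  // skip cross-trainer exchange
};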

@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
readers_.resize(thread_num_);
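The hunk cuts off before the body that consumes the new parameter. A hedged guess at the wiring, assuming the trainer now takes its per-thread readers from the dataset (the worker setup is elided):

// Hypothetical continuation of DistMultiTrainer::Initialize:
data_set->CreateReaders();  // no-op if the readers already exist
std::vector<std::shared_ptr<DataFeed>> readers = data_set->GetReaders();
for (int i = 0; i < thread_num_; ++i) {
  readers_[i] = readers[i];
  // ... create and configure workers_[i] as before (not shown) ...
}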

@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
@ -115,7 +116,7 @@ class Executor {
const std::string& trainer_desc_str,
const bool debug);
void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset,
void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
const std::string& trainer_desc_str, const bool debug);
public:

@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
@ -40,7 +41,7 @@ class TrainerBase {
// model memory are hosted in root_scope
void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; }
virtual void Initialize(const TrainerDesc& trainer_desc) = 0;
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
@ -59,7 +60,7 @@ class MultiTrainer : public TrainerBase {
public:
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc);
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
@ -77,7 +78,7 @@ class DistMultiTrainer : public MultiTrainer {
public:
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc);
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize();

@ -49,6 +49,7 @@ void BindAsyncExecutor(py::module* m) {
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer)

@ -0,0 +1,61 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <string>
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/framework/data_set.h"
namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::Dataset>(*m, "Dataset")
.def(py::init([]() {
return std::unique_ptr<framework::Dataset>(
new framework::Dataset());
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GLobalShuffle)
}
} // end namespace pybind
} // end namespace paddle

@ -0,0 +1,28 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindDataset(py::module* m);
} // namespace pybind
} // namespace paddle

@ -61,6 +61,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"
#include "paddle/fluid/pybind/data_set_py.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
@ -1359,6 +1360,7 @@ All parameter, weight, gradient are variables in Paddle.
BindGraph(&m);
BindNode(&m);
BindInferenceApi(&m);
BindDataset(&m);
}
} // namespace pybind
} // namespace paddle
