support multi dataset && add init model && fix bug

revert-16555-model_data_cryption_link_all_lib
xujiaqi01 6 years ago committed by dongdaxiang
parent 3c65cc1bbd
commit a5b1a0e12b

@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
_pull_dense_thread->stop();
// todo ?
//_pull_dense_thread->stop();
}
#endif
VLOG(3) << "start to run from files in async_executor";

File diff suppressed because it is too large

@ -21,6 +21,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include <sstream>
#include <future>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
@ -52,7 +53,10 @@ namespace framework {
// }
class DataFeed {
public:
DataFeed() {}
DataFeed() {
mutex_for_pick_file_ = nullptr;
file_idx_ = nullptr;
}
virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) {
@ -89,6 +93,12 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) { }
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) { }
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFileListIndex(size_t* file_index) {
file_idx_ = file_index;
}
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
@ -100,7 +110,9 @@ class DataFeed {
}
// This function will do nothing at default
virtual void FillMemoryDataToChannel() { }
// This function will do nothing at default
virtual void FillChannelToMemoryData() { }
// This function will do nothing at default
virtual void PutInsToChannel(const std::string& ins_str) { }
protected:
@ -116,9 +128,9 @@ class DataFeed {
// safe).
virtual bool PickOneFile(std::string* filename);
static std::vector<std::string> filelist_;
static size_t file_idx_;
static std::mutex mutex_for_pick_file_;
std::vector<std::string> filelist_;
size_t* file_idx_;
std::mutex* mutex_for_pick_file_;
// the alias of used slots, and its order is determined by
// data_feed_desc(proto object)
@ -141,7 +153,7 @@ class DataFeed {
int batch_size_;
bool finish_init_;
static bool finish_set_filelist_;
bool finish_set_filelist_;
bool finish_start_;
std::string pipe_command_;
};
@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const T& ins, std::string* str) = 0;
virtual void DeserializeIns(T* ins, const std::string& str) = 0;
virtual void SerializeIns(const std::vector<T*>& ins, std::string* str) = 0;
virtual void DeserializeIns(std::vector<T>* ins, const std::string& str) = 0;
virtual std::pair<int64_t, int64_t> GetMemoryDataInterval();
int thread_id_;
int thread_num_;
@ -284,13 +297,13 @@ class MultiSlotType {
std::string DebugString() {
std::stringstream ss;
ss << "type: " << type_ << "\n";
ss << "offset:\n";
ss << "\ntype: " << type_ << "\n";
ss << "offset: ";
ss << "[";
for (const size_t& i : offset_) {
ss << offset_[i] << ",";
}
ss << "]\ndata:\n[";
ss << "]\ndata: [";
if (type_[0] == 'f') {
for (const float& i : float_feasign_) {
ss << i << ",";
@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
virtual void SerializeIns(const std::vector<std::vector<MultiSlotType>*>& ins,
std::string* str);
virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
virtual void DeserializeIns(std::vector<std::vector<MultiSlotType>>* ins,
const std::string& str);
};

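The change above drops the static filelist_/file_idx_/mutex_for_pick_file_ members in favor of a per-reader file list plus pointers to a cursor and mutex owned by the Dataset, so several datasets can coexist and all readers of one dataset consume a single shared file list. A minimal standalone sketch of that pattern (hypothetical names, not Paddle code):

    #include <iostream>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <vector>

    // Stand-in for a DataFeed-style reader: it does not own the file cursor,
    // it only holds pointers to state owned by its dataset.
    struct Reader {
      std::vector<std::string> filelist;
      size_t* file_idx = nullptr;        // shared cursor, owned by the dataset
      std::mutex* pick_mutex = nullptr;  // shared lock, owned by the dataset

      // Analogue of DataFeed::PickOneFile(): thread-safe pick of the next file.
      bool PickOneFile(std::string* filename) {
        std::lock_guard<std::mutex> lock(*pick_mutex);
        if (*file_idx >= filelist.size()) return false;
        *filename = filelist[(*file_idx)++];
        return true;
      }
    };

    int main() {
      std::vector<std::string> files = {"part-000", "part-001", "part-002"};
      size_t file_idx = 0;    // owned by the "dataset"
      std::mutex pick_mutex;  // owned by the "dataset"

      std::vector<Reader> readers(2);
      for (auto& r : readers) {
        r.filelist = files;
        r.file_idx = &file_idx;
        r.pick_mutex = &pick_mutex;
      }

      auto consume = [](Reader* r, int id) {
        std::string f;
        while (r->PickOneFile(&f)) {
          std::cout << "reader " << id << " got " << f << std::endl;
        }
      };
      std::thread t0(consume, &readers[0], 0);
      std::thread t1(consume, &readers[1], 1);
      t0.join();
      t1.join();
      return 0;
    }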
@ -18,6 +18,8 @@
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace framework {
@ -25,12 +27,15 @@ namespace framework {
template <typename T>
DatasetImpl<T>::DatasetImpl() {
thread_num_ = 1;
trainer_num_ = 1;
file_idx_ = 0;
}
template <typename T>
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
file_idx_ = 0;
/*
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
@ -45,19 +50,34 @@ void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
// not user friendly
template <typename T>
void DatasetImpl<T>::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
VLOG(3) << "SetThreadNum thread_num=" << thread_num;
//int file_cnt = filelist_.size();
/*
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(3) << "DataSet thread num = " << thread_num
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}
}*/
thread_num_ = thread_num;
}
template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
// should inform reader of trainer_num directly
for (auto reader : readers_) {
reader->SetTrainerNum(trainer_num);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
std::string cmd = std::string("hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
paddle::framework::hdfs_set_command(cmd);
}
template <typename T>
@ -75,6 +95,8 @@ DatasetImpl<T>::GetReaders() {
template <typename T>
void DatasetImpl<T>::LoadIntoMemory() {
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
@ -86,12 +108,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
for (std::thread& t : load_threads) {
t.join();
}
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end";
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() end"
<< ", memory data size=" << memory_data_.size()
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::LocalShuffle() {
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
platform::Timer timeline;
timeline.Start();
if (readers_.size() == 0) {
CreateReaders();
}
@ -107,23 +134,27 @@ void DatasetImpl<T>::LocalShuffle() {
t.join();
}
std::vector<T>().swap(memory_data_);
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end";
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::LocalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
platform::Timer timeline;
timeline.Start();
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "RegisterClientToClientMsgHandler";
fleet_ptr->RegisterClientToClientMsgHandler(
0, [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
if (readers_.size() == 0) {
CreateReaders();
}
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
@ -133,15 +164,32 @@ void DatasetImpl<T>::GlobalShuffle() {
for (std::thread& t : global_shuffle_threads) {
t.join();
}
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end";
std::vector<T>().swap(memory_data_);
timeline.Pause();
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() end, cost time="
<< timeline.ElapsedSec() << " seconds";
}
template <typename T>
void DatasetImpl<T>::CreateReaders() {
VLOG(3) << "Calling CreateReaders()";
CHECK(thread_num_ > 0) << "thread_num should be larger than 0";
int file_cnt = filelist_.size();
int memory_data_size = memory_data_.size();
if (memory_data_size != 0 && thread_num_ > memory_data_size) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", memory data size = " << memory_data_size
<< ". Changing Dataset thread num = " << memory_data_size;
thread_num_ = memory_data_size;
} else if (file_cnt != 0 && thread_num_ > file_cnt) {
VLOG(3) << "Dataset thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ". Changing Dataset thread num = " << file_cnt;
thread_num_ = file_cnt;
}
VLOG(3) << "thread_num in Readers: " << thread_num_;
VLOG(3) << "readers size: " << readers_.size();
VLOG(3) << "Filelist size in readers: " << filelist_.size();
if (readers_.size() != 0) {
return;
}
@ -154,9 +202,10 @@ void DatasetImpl<T>::CreateReaders() {
readers_.back()->SetThreadId(i);
readers_.back()->SetThreadNum(thread_num_);
readers_.back()->SetTrainerNum(trainer_num_);
readers_.back()->SetFileListMutex(&mutex_for_pick_file_);
readers_.back()->SetFileListIndex(&file_idx_);
readers_.back()->SetFileList(filelist_);
}
VLOG(3) << "Filelist size in readers: " << filelist_.size();
readers_[0]->SetFileList(filelist_);
}
template <typename T>
@ -184,9 +233,12 @@ void DatasetImpl<T>::DestroyReaders() {
template <typename T>
int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) {
// todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t index = 0;
VLOG(3) << "ReceiveFromClient msg_type=" << msg_type
<< ", client_id=" << client_id << ", msg length="
<< msg.length();
auto fleet_ptr = FleetWrapper::GetInstance();
int64_t index = fleet_ptr->LocalRandomEngine()() % thread_num_;
VLOG(3) << "random index=" << index;
readers_[index]->PutInsToChannel(msg);
return 0;
}

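GlobalShuffle() above only wires up the receive side: each incoming client-to-client message lands in ReceiveFromClient(), which routes it to a random local reader through PutInsToChannel(). The send side lives in the readers' global-shuffle threads; below is a hedged sketch of what one send might look like, using only interfaces visible in this commit (CollectLocalBatch is a made-up placeholder, and the calls assume PADDLE_WITH_PSLIB):

    #include <string>
    #include <vector>
    #include "paddle/fluid/framework/data_feed.h"
    #include "paddle/fluid/framework/fleet/fleet_wrapper.h"

    // Made-up placeholder: gather pointers to a few instances from this
    // trainer's memory_data_ (not a real Paddle function).
    std::vector<std::vector<paddle::framework::MultiSlotType>*> CollectLocalBatch();

    // Hypothetical send path of one global-shuffle step.
    void SendOneShuffledBatch(int trainer_num) {
      auto fleet_ptr = paddle::framework::FleetWrapper::GetInstance();
      auto batch = CollectLocalBatch();
      std::string msg;
      fleet_ptr->Serialize(batch, &msg);            // batched serialization
      int dest = fleet_ptr->LocalRandomEngine()() % trainer_num;
      fleet_ptr->SendClientToClientMsg(0 /*msg_type*/, dest, msg).wait();
      // On the receiving trainer, DatasetImpl<T>::ReceiveFromClient() forwards
      // msg to a random reader via PutInsToChannel(), which deserializes it.
    }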
@ -33,6 +33,8 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
virtual void SetThreadNum(int thread_num) = 0;
virtual void SetTrainerNum(int trainer_num) = 0;
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
virtual const std::vector<std::string>& GetFileList() = 0;
virtual int GetThreadNum() = 0;
@ -60,6 +62,8 @@ class DatasetImpl : public Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
@ -85,8 +89,10 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_update_memory_data_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_;
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
std::mutex mutex_for_pick_file_;
};
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {

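For reference, a hedged C++ sketch of driving the extended Dataset interface end to end (file names and the HDFS ugi format are placeholders; the thread count may be lowered by CreateReaders() as shown above):

    #include <string>
    #include "paddle/fluid/framework/data_set.h"

    void PrepareDataset(const std::string& data_feed_desc_str, int trainer_num) {
      paddle::framework::MultiSlotDataset dataset;
      dataset.SetFileList({"part-000", "part-001"});
      dataset.SetThreadNum(4);             // may be capped by file/instance count
      dataset.SetTrainerNum(trainer_num);  // forwarded to every reader
      dataset.SetHdfsConfig("hdfs://nameservice", "user,passwd");
      dataset.SetDataFeedDesc(data_feed_desc_str);
      dataset.LoadIntoMemory();            // readers fill memory_data_
      dataset.GlobalShuffle();             // exchange instances across trainers
    }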
@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
workers_.resize(thread_num_);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
thread_num_ = readers.size();
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());

@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync(
#endif
}
void FleetWrapper::PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id,
const std::vector<std::string>& var_names) {
#ifdef PADDLE_WITH_PSLIB
paddle::framework::Scope scope;
auto& block = program.Block(0);
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope.Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = scope.Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
auto place = platform::CPUPlace();
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
std::vector<int64_t> dim;
for (auto& var : block.AllVars()) {
if (var->Name() == t) {
dim = var->GetShape();
break;
}
}
int cnt = 1;
for (auto& i: dim) {
cnt *= i;
}
DDim d(std::vector<int64_t>{cnt}.data(), 1);
float* g = tensor->mutable_data<float>(d, place);
CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2;
int rown = tensor->dims()[0];
init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(LocalRandomEngine()) * init_range;
}
paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
auto push_status = pslib_ptr_->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
CHECK(status == 0) << "push dense param failed, status["
<< status << "]";
}
#endif
}
void FleetWrapper::PushDenseVarsSync(
Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names) {}
@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
continue;
}
LOG(WARNING) << "going to memcpy";
CHECK(fea_idx < (*push_values).size());
CHECK(fea_idx < fea_labels.size());
memcpy((*push_values)[fea_idx].data() + offset, g,
sizeof(float) * emb_dim);
LOG(WARNING) << "show";
@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
int FleetWrapper::RegisterClientToClientMsgHandler(
int msg_type, MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
return pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
return 0;
}
int FleetWrapper::SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg) {
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
return pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
#endif
return 0;
return std::future<int32_t>();
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
}
template <typename T>
void FleetWrapper::Serialize(const T& t, std::string* str) {
void FleetWrapper::Serialize(const std::vector<T*>& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
ar << t;
for (size_t i = 0; i < t.size(); ++i) {
ar << *(t[i]);
}
*str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize does nothing when no pslib";
@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
}
template <typename T>
void FleetWrapper::Deserialize(T* t, const std::string& str) {
void FleetWrapper::Deserialize(std::vector<T>* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
if (str.length() == 0) {
return;
}
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
*t = ar.get<T>();
if (ar.cursor() == ar.finish()) {
return;
}
while (ar.cursor() < ar.finish()) {
t->push_back(ar.get<T>());
}
CHECK(ar.cursor() == ar.finish());
VLOG(3) << "Deserialize size " << t->size();
#else
VLOG(0) << "FleetWrapper::Deserialize does nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string*);
template void FleetWrapper::Deserialize(std::vector<MultiSlotType>*,
const std::string&);
const std::vector<std::vector<MultiSlotType>*>&, std::string*);
template void FleetWrapper::Deserialize<std::vector<MultiSlotType>>(
std::vector<std::vector<MultiSlotType>>*, const std::string&);
} // end namespace framework
} // end namespace paddle

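The Serialize/Deserialize changes above switch from one instance per message to a batch per message: Serialize takes pointers so callers can group records without copying, and Deserialize keeps reading the archive until the buffer is exhausted. A hedged round-trip sketch (only meaningful with PADDLE_WITH_PSLIB; otherwise both calls are no-ops):

    #include <string>
    #include <vector>
    #include "paddle/fluid/framework/data_feed.h"
    #include "paddle/fluid/framework/fleet/fleet_wrapper.h"

    using Record = std::vector<paddle::framework::MultiSlotType>;

    void RoundTrip(std::vector<Record>* local_records) {
      auto fleet_ptr = paddle::framework::FleetWrapper::GetInstance();

      // Batch up pointers to the records we want to ship.
      std::vector<Record*> batch;
      for (auto& r : *local_records) batch.push_back(&r);
      std::string msg;
      fleet_ptr->Serialize(batch, &msg);

      // Deserialize appends one Record per archived entry until the
      // archive cursor reaches the end of the buffer.
      std::vector<Record> decoded;
      fleet_ptr->Deserialize(&decoded, msg);
      // With pslib enabled, decoded.size() == local_records->size().
    }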
@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
@ -71,6 +72,10 @@ class FleetWrapper {
const std::vector<std::string>& var_names,
std::vector<::std::future<int32_t>>* pull_dense_status);
void PushDenseParamSync(
const ProgramDesc& program, const uint64_t table_id,
const std::vector<std::string>& var_names);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
@ -119,16 +124,15 @@ class FleetWrapper {
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
int SendClientToClientMsg(int msg_type,
int to_client_id,
const std::string& msg);
std::future<int32_t> SendClientToClientMsg(int msg_type,
int to_client_id,
const std::string& msg);
std::default_random_engine& LocalRandomEngine();
template <typename T>
void Serialize(const T& t, std::string* str);
void Serialize(const std::vector<T*>& t, std::string* str);
template <typename T>
void Deserialize(T* t, const std::string& str);
void Deserialize(std::vector<T>* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper());

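SendClientToClientMsg now hands back the pslib future instead of swallowing it, so callers can block until delivery completes; RegisterClientToClientMsgHandler keeps the MsgHandlerFunc signature shown above. A hedged pairing sketch (only meaningful with PADDLE_WITH_PSLIB; without pslib the returned future is empty and should not be waited on):

    #include <future>
    #include <string>
    #include "paddle/fluid/framework/fleet/fleet_wrapper.h"

    void WireMessaging(int to_client_id, const std::string& payload) {
      auto fleet_ptr = paddle::framework::FleetWrapper::GetInstance();

      // Receiver: messages of msg_type 0 are dispatched to this callback.
      fleet_ptr->RegisterClientToClientMsgHandler(
          0, [](int msg_type, int client_id, const std::string& msg) -> int {
            return 0;  // e.g. hand msg to a local reader via PutInsToChannel()
          });

      // Sender: wait on the returned future until pslib reports delivery.
      std::future<int32_t> done =
          fleet_ptr->SendClientToClientMsg(0, to_client_id, payload);
      done.wait();
    }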
@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
// get filelist from trainer_desc here
workers_.resize(thread_num_);
VLOG(3) << "worker thread num: " << thread_num_;
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
// change thread num to readers num
thread_num_ = readers.size();
VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());

@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) {
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
//.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer)

@ -50,6 +50,7 @@ void BindDataset(py::module* m) {
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)

@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) {
.def("init_server", &framework::FleetWrapper::InitServer)
.def("run_server", &framework::FleetWrapper::RunServer)
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("init_model", &framework::FleetWrapper::PushDenseParamSync)
.def("stop_server", &framework::FleetWrapper::StopServer)
.def("gather_servers", &framework::FleetWrapper::GatherServers);
} // end FleetWrapper

@ -86,6 +86,9 @@ class DatasetBase(object):
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
def set_hdfs_config(self, fs_name, fs_ugi):
self.dataset.set_hdfs_config(fs_name, fs_ugi)
def _prepare_to_run(self):
self.dataset.set_data_feed_desc(self.desc())
@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase):
def local_shuffle(self):
self.dataset.local_shuffle()
def global_shuffle(self):
from .distributed import ps_instance
instance = ps_instance.PaddlePSInstance(1, 2)
self.dataset.set_trainer_num(instance.get_worker_num())
def global_shuffle(self, fleet=None):
trainer_num = 1
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
trainer_num = fleet.worker_num()
self.dataset.set_trainer_num(trainer_num)
self.dataset.global_shuffle()
if fleet is not None:
fleet.fleet_instance.role_maker_.barrier_worker()
class QueueDataset(DatasetBase):
@ -130,5 +137,5 @@ class QueueDataset(DatasetBase):
def local_shuffle(self):
pass
def global_shuffle(self):
def global_shuffle(self, fleet=None):
pass

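Usage-wise, the new fleet-aware global_shuffle keeps single-node runs at trainer_num=1 and, in distributed runs, barriers the workers around the shuffle. A hedged sketch of the distributed path (the plain setters such as set_filelist/load_into_memory are assumed to be the usual thin wrappers over the C++ Dataset bindings, and the file paths are placeholders):

    # assumes fleet.init() and DistributedOptimizer.minimize() already ran
    import paddle.fluid.incubate.fleet.parameter_server as fleet
    from paddle.fluid.dataset import InMemoryDataset   # module path assumed

    dataset = InMemoryDataset()
    dataset.set_hdfs_config("hdfs://nameservice", "user,passwd")
    dataset.set_filelist(["part-000", "part-001"])
    dataset.load_into_memory()       # each reader pulls files into memory_data_
    dataset.global_shuffle(fleet)    # barrier, set trainer num, shuffle, barrier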
@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if self._check_role_generation():
if self.is_worker():
return self.get_size()
return self.get_size() / 2  # workers take half of the MPI ranks
return 0
def server_num(self):
@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if self._check_role_generation():
if self.is_server():
return self.get_size()
return self.get_size() / 2  # servers take the other half of the MPI ranks
return 0
def worker_index(self):

@ -43,7 +43,7 @@ class Fleet(object):
save_pserver_model(): save model parameters in pserver, called from a server node
Example:
.. code-block:: python
import paddle.fluid.incubate.fleet.parameter_server as fleet
from my_model import bow_net
@ -58,7 +58,7 @@ class Fleet(object):
fleet.init_worker() # init worker should be called before training
# do other things like training
elif fleet.is_server():
fleet.init_pserver()
fleet.init_pserver()
fleet.stop()
"""
@ -75,7 +75,7 @@ class Fleet(object):
"""
init(): which should be called only once in user's python scripts. init() will initialize
FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
current node's role, e.g. worker, server, etc.
current node's role, e.g. worker, server, etc.
"""
if not self.is_initialized_:
self.role_maker_ = MPISymetricRoleMaker()
@ -122,7 +122,7 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def init_worker(self):
def init_worker(self, program):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
@ -143,6 +143,19 @@ class Fleet(object):
self.role_maker_.get_rank())
self.role_maker_.barrier_all()
self.role_maker_.barrier_worker()
if self.role_maker_.is_first_worker():
tables = self._dist_desc.trainer_param.dense_table._values
for i in range(0, len(tables)):
table = tables[i]
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self._fleet_ptr.init_model(program.desc,
int(table.table_id),
var_name_list)
self.role_maker_.barrier_worker()
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)

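init_worker() now receives the program so the first worker can push its randomly initialized dense tables to the servers through the new init_model binding (FleetWrapper::PushDenseParamSync), with worker barriers keeping the other workers from pulling parameters too early. A hedged sketch of the resulting trainer-side call order, following the class docstring above (adam and loss are placeholders, and the exact program/optimizer wiring is simplified):

    import paddle.fluid as fluid
    import paddle.fluid.incubate.fleet.parameter_server as fleet

    fleet.init()                                  # sets up RoleMaker + FleetWrapper
    optimizer = fleet.DistributedOptimizer(adam)  # wraps a regular optimizer
    optimizer.minimize(loss)                      # must run before init_worker

    if fleet.is_worker():
        # the first worker pushes its dense tables to the servers (init_model)
        fleet.init_worker(fluid.default_main_program())
        # ... training loop ...
    elif fleet.is_server():
        fleet.init_pserver()
    fleet.stop()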