Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-async-ssa-graph-executor-communicator
test=developrevert-16555-model_data_cryption_link_all_lib
commit
adf272bcec
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,150 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_feed.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Dataset is a abstract class, which defines user interfaces
|
||||
// Example Usage:
|
||||
// Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset")
|
||||
// dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"})
|
||||
// dataset->SetThreadNum(1)
|
||||
// dataset->CreateReaders();
|
||||
// dataset->SetDataFeedDesc(your_data_feed_desc);
|
||||
// dataset->LoadIntoMemory();
|
||||
// dataset->SetTrainerNum(2);
|
||||
// dataset->GlobalShuffle();
|
||||
class Dataset {
|
||||
public:
|
||||
Dataset() {}
|
||||
virtual ~Dataset() {}
|
||||
// set file list
|
||||
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
|
||||
// set readers' num
|
||||
virtual void SetThreadNum(int thread_num) = 0;
|
||||
// set workers' num
|
||||
virtual void SetTrainerNum(int trainer_num) = 0;
|
||||
// set fs name and ugi
|
||||
virtual void SetHdfsConfig(const std::string& fs_name,
|
||||
const std::string& fs_ugi) = 0;
|
||||
// set data fedd desc, which contains:
|
||||
// data feed name, batch size, slots
|
||||
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
|
||||
// get file list
|
||||
virtual const std::vector<std::string>& GetFileList() = 0;
|
||||
// get thread num
|
||||
virtual int GetThreadNum() = 0;
|
||||
// get worker num
|
||||
virtual int GetTrainerNum() = 0;
|
||||
// get hdfs config
|
||||
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
|
||||
// get data fedd desc
|
||||
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
|
||||
// get readers, the reader num depend both on thread num
|
||||
// and filelist size
|
||||
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
||||
GetReaders() = 0;
|
||||
// register message handler between workers
|
||||
virtual void RegisterClientToClientMsgHandler() = 0;
|
||||
// load all data into memory
|
||||
virtual void LoadIntoMemory() = 0;
|
||||
// release all memory data
|
||||
virtual void ReleaseMemory() = 0;
|
||||
// local shuffle data
|
||||
virtual void LocalShuffle() = 0;
|
||||
// global shuffle data
|
||||
virtual void GlobalShuffle() = 0;
|
||||
// create readers
|
||||
virtual void CreateReaders() = 0;
|
||||
// destroy readers
|
||||
virtual void DestroyReaders() = 0;
|
||||
|
||||
protected:
|
||||
virtual int ReceiveFromClient(int msg_type, int client_id,
|
||||
const std::string& msg) = 0;
|
||||
};
|
||||
|
||||
// DatasetImpl is the implementation of Dataset,
|
||||
// it holds memory data if user calls load_into_memory
|
||||
template <typename T>
|
||||
class DatasetImpl : public Dataset {
|
||||
public:
|
||||
DatasetImpl();
|
||||
virtual ~DatasetImpl() {}
|
||||
|
||||
virtual void SetFileList(const std::vector<std::string>& filelist);
|
||||
virtual void SetThreadNum(int thread_num);
|
||||
virtual void SetTrainerNum(int trainer_num);
|
||||
virtual void SetHdfsConfig(const std::string& fs_name,
|
||||
const std::string& fs_ugi);
|
||||
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
|
||||
|
||||
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
|
||||
virtual int GetThreadNum() { return thread_num_; }
|
||||
virtual int GetTrainerNum() { return trainer_num_; }
|
||||
virtual std::pair<std::string, std::string> GetHdfsConfig() {
|
||||
return std::make_pair(fs_name_, fs_ugi_);
|
||||
}
|
||||
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
|
||||
return data_feed_desc_;
|
||||
}
|
||||
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
||||
GetReaders();
|
||||
|
||||
virtual void RegisterClientToClientMsgHandler();
|
||||
virtual void LoadIntoMemory();
|
||||
virtual void ReleaseMemory();
|
||||
virtual void LocalShuffle();
|
||||
virtual void GlobalShuffle();
|
||||
virtual void CreateReaders();
|
||||
virtual void DestroyReaders();
|
||||
|
||||
protected:
|
||||
virtual int ReceiveFromClient(int msg_type, int client_id,
|
||||
const std::string& msg);
|
||||
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
|
||||
std::vector<T> memory_data_;
|
||||
std::mutex mutex_for_update_memory_data_;
|
||||
int thread_num_;
|
||||
paddle::framework::DataFeedDesc data_feed_desc_;
|
||||
int trainer_num_;
|
||||
std::vector<std::string> filelist_;
|
||||
size_t file_idx_;
|
||||
std::mutex mutex_for_pick_file_;
|
||||
std::string fs_name_;
|
||||
std::string fs_ugi_;
|
||||
unsigned int rand_seed;
|
||||
};
|
||||
|
||||
// use std::vector<MultiSlotType> as data type
|
||||
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
|
||||
public:
|
||||
MultiSlotDataset() {}
|
||||
virtual ~MultiSlotDataset() {}
|
||||
};
|
||||
|
||||
} // end namespace framework
|
||||
} // end namespace paddle
|
@ -0,0 +1,66 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/dataset_factory.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "paddle/fluid/framework/data_set.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
|
||||
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
|
||||
datasetMap g_dataset_map;
|
||||
|
||||
#define REGISTER_DATASET_CLASS(dataset_class) \
|
||||
namespace { \
|
||||
std::shared_ptr<Dataset> Creator_##dataset_class() { \
|
||||
return std::shared_ptr<Dataset>(new dataset_class); \
|
||||
} \
|
||||
class __Registerer_##dataset_class { \
|
||||
public: \
|
||||
__Registerer_##dataset_class() { \
|
||||
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
|
||||
} \
|
||||
}; \
|
||||
__Registerer_##dataset_class g_registerer_##dataset_class; \
|
||||
} // namespace
|
||||
|
||||
std::string DatasetFactory::DatasetTypeList() {
|
||||
std::string dataset_types;
|
||||
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); ++iter) {
|
||||
if (iter != g_dataset_map.begin()) {
|
||||
dataset_types += ", ";
|
||||
}
|
||||
dataset_types += iter->first;
|
||||
}
|
||||
return dataset_types;
|
||||
}
|
||||
|
||||
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
|
||||
std::string dataset_class) {
|
||||
if (g_dataset_map.count(dataset_class) < 1) {
|
||||
LOG(WARNING) << "Your Dataset " << dataset_class
|
||||
<< "is not supported currently";
|
||||
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
|
||||
exit(-1);
|
||||
}
|
||||
return g_dataset_map[dataset_class]();
|
||||
}
|
||||
|
||||
REGISTER_DATASET_CLASS(MultiSlotDataset);
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,29 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/data_set.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
class DatasetFactory {
|
||||
public:
|
||||
static std::string DatasetTypeList();
|
||||
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,27 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/device_worker.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
void DeviceWorker::SetRootScope(Scope* root_scope) { root_scope_ = root_scope; }
|
||||
|
||||
void DeviceWorker::SetDataFeed(const std::shared_ptr<DataFeed>& data_feed) {
|
||||
device_reader_ = data_feed;
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,198 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_feed.h"
|
||||
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/framework/trainer_desc.pb.h"
|
||||
#include "paddle/fluid/framework/variable_helper.h"
|
||||
#include "paddle/fluid/operators/reader/blocking_queue.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "paddle/fluid/platform/port.h"
|
||||
#include "paddle/fluid/platform/timer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class PullDenseWorker {
|
||||
public:
|
||||
virtual ~PullDenseWorker() {}
|
||||
virtual void Initialize(const TrainerDesc& param);
|
||||
int Start();
|
||||
void Stop();
|
||||
void SetRootScope(Scope* scope) { root_scope_ = scope; }
|
||||
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
|
||||
void ResetThreadVersion(uint64_t table_id);
|
||||
void Wait(std::vector<::std::future<int32_t>>* status_vec);
|
||||
static std::shared_ptr<PullDenseWorker> GetInstance() {
|
||||
if (NULL == s_instance_) {
|
||||
s_instance_.reset(new paddle::framework::PullDenseWorker());
|
||||
}
|
||||
return s_instance_;
|
||||
}
|
||||
|
||||
private:
|
||||
PullDenseWorker() : root_scope_(NULL) {}
|
||||
void Run();
|
||||
bool CheckUpdateParam(uint64_t table_id);
|
||||
|
||||
private:
|
||||
static std::shared_ptr<PullDenseWorker> s_instance_;
|
||||
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
|
||||
PullDenseWorkerParameter param_;
|
||||
DownpourWorkerParameter dwp_param_;
|
||||
Scope* root_scope_;
|
||||
bool running_;
|
||||
|
||||
static std::map<uint64_t, uint64_t> last_versions_;
|
||||
static std::map<uint64_t, uint64_t> current_version_;
|
||||
static std::mutex mutex_for_version_;
|
||||
static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
|
||||
static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
|
||||
|
||||
std::thread t_;
|
||||
int thread_num_;
|
||||
int sleep_time_ms_;
|
||||
int threshold_;
|
||||
|
||||
std::vector<::std::future<int32_t>> pull_dense_status_;
|
||||
uint32_t pull_dense_fail_times_ = 0;
|
||||
std::vector<float> base_norm_param_;
|
||||
std::vector<float> mean_;
|
||||
std::vector<float> scale_;
|
||||
float squared_sum_epsilon_ = 1e-4;
|
||||
std::mutex mutex_for_mean_scale_;
|
||||
float total_batch_num_ = 0;
|
||||
};
|
||||
|
||||
// should incorporate different type of device
|
||||
class DeviceWorker {
|
||||
public:
|
||||
DeviceWorker() {}
|
||||
virtual ~DeviceWorker() {}
|
||||
virtual void Initialize(const TrainerDesc& desc) = 0;
|
||||
virtual void SetDeviceIndex(int tid) = 0;
|
||||
virtual void TrainFiles() = 0;
|
||||
virtual void PrintFetchVars() = 0;
|
||||
virtual void TrainFilesWithProfiler() = 0;
|
||||
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
|
||||
// will make this zero copy in the future
|
||||
virtual void BindingDataFeedMemory() = 0;
|
||||
virtual void SetRootScope(Scope* root_scope);
|
||||
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
|
||||
virtual void SetPlace(const paddle::platform::Place& place) {
|
||||
place_ = place;
|
||||
}
|
||||
|
||||
protected:
|
||||
Scope* root_scope_;
|
||||
paddle::platform::Place place_;
|
||||
std::shared_ptr<DataFeed> device_reader_;
|
||||
int64_t batch_num_;
|
||||
FetchConfig fetch_config_;
|
||||
};
|
||||
|
||||
class CPUWorkerBase : public DeviceWorker {
|
||||
public:
|
||||
CPUWorkerBase() {}
|
||||
virtual ~CPUWorkerBase() {}
|
||||
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
|
||||
virtual void TrainFiles() = 0;
|
||||
virtual void TrainFilesWithProfiler() {}
|
||||
virtual void PrintFetchVars() {}
|
||||
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
|
||||
|
||||
protected:
|
||||
int thread_id_;
|
||||
};
|
||||
|
||||
class HogwildWorker : public CPUWorkerBase {
|
||||
public:
|
||||
HogwildWorker() {}
|
||||
virtual ~HogwildWorker() {}
|
||||
virtual void Initialize(const TrainerDesc& desc);
|
||||
virtual void TrainFiles();
|
||||
virtual void TrainFilesWithProfiler();
|
||||
virtual void PrintFetchVars();
|
||||
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
|
||||
virtual void BindingDataFeedMemory();
|
||||
|
||||
protected:
|
||||
void CreateThreadOperators(const ProgramDesc& program);
|
||||
void CreateThreadScope(const ProgramDesc& program);
|
||||
std::vector<std::string> op_names_;
|
||||
std::vector<OperatorBase*> ops_;
|
||||
Scope* thread_scope_;
|
||||
HogwildWorkerParameter param_;
|
||||
std::vector<std::string> skip_ops_;
|
||||
};
|
||||
|
||||
class DownpourWorker : public HogwildWorker {
|
||||
public:
|
||||
DownpourWorker() {}
|
||||
virtual ~DownpourWorker() {}
|
||||
virtual void Initialize(const TrainerDesc& desc);
|
||||
virtual void TrainFiles();
|
||||
virtual void TrainFilesWithProfiler();
|
||||
|
||||
protected:
|
||||
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
|
||||
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
|
||||
void FillSparseValue(size_t table_id);
|
||||
void PushGradients();
|
||||
void CollectLabelInfo(size_t table_id);
|
||||
|
||||
private:
|
||||
bool need_to_push_dense_;
|
||||
bool need_to_push_sparse_;
|
||||
DownpourWorkerParameter param_;
|
||||
// just save the value in param_ for easy access
|
||||
std::map<uint64_t, std::string> label_var_name_;
|
||||
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
|
||||
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
|
||||
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
|
||||
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
|
||||
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
|
||||
|
||||
// feasign
|
||||
std::map<uint64_t, std::vector<uint64_t>> features_;
|
||||
// feasign stats
|
||||
std::map<uint64_t, std::vector<float>> feature_labels_;
|
||||
// feasign embedding
|
||||
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
|
||||
// feasign embedding gradient
|
||||
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
|
||||
// skipped ops
|
||||
std::vector<std::string> skip_ops_;
|
||||
|
||||
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
|
||||
std::vector<::std::future<int32_t>> push_sparse_status_;
|
||||
std::vector<::std::future<int32_t>> push_dense_status_;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,65 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/device_worker_factory.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
typedef std::shared_ptr<DeviceWorker> (*Createdevice_workerFunction)();
|
||||
typedef std::unordered_map<std::string, Createdevice_workerFunction>
|
||||
device_workerMap;
|
||||
device_workerMap g_device_worker_map;
|
||||
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
|
||||
namespace { \
|
||||
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
|
||||
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
|
||||
} \
|
||||
class __Registerer_##device_worker_class { \
|
||||
public: \
|
||||
__Registerer_##device_worker_class() { \
|
||||
g_device_worker_map[#device_worker_class] = \
|
||||
&Creator_##device_worker_class; \
|
||||
} \
|
||||
}; \
|
||||
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
|
||||
} // namespace
|
||||
|
||||
std::string DeviceWorkerFactory::DeviceWorkerTypeList() {
|
||||
std::string device_worker_types;
|
||||
for (auto iter = g_device_worker_map.begin();
|
||||
iter != g_device_worker_map.end(); ++iter) {
|
||||
if (iter != g_device_worker_map.begin()) {
|
||||
device_worker_types += ", ";
|
||||
}
|
||||
device_worker_types += iter->first;
|
||||
}
|
||||
return device_worker_types;
|
||||
}
|
||||
|
||||
std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
|
||||
std::string device_worker_class) {
|
||||
if (g_device_worker_map.count(device_worker_class) < 1) {
|
||||
exit(-1);
|
||||
}
|
||||
return g_device_worker_map[device_worker_class]();
|
||||
}
|
||||
|
||||
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
|
||||
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,31 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/device_worker.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class DeviceWorkerFactory {
|
||||
public:
|
||||
static std::string DeviceWorkerTypeList();
|
||||
static std::shared_ptr<DeviceWorker> CreateDeviceWorker(
|
||||
std::string device_worker_class);
|
||||
};
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,24 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "paddle/fluid/framework/trainer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
TEST() {
|
||||
// create hogwild device worker
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/data_feed_factory.h"
|
||||
#include "paddle/fluid/framework/data_set.h"
|
||||
#include "paddle/fluid/framework/device_worker_factory.h"
|
||||
#include "paddle/fluid/framework/trainer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
|
||||
Dataset* dataset) {
|
||||
thread_num_ = trainer_desc.thread_num();
|
||||
SetDataset(dataset);
|
||||
|
||||
dataset->CreateReaders();
|
||||
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
|
||||
dataset->GetReaders();
|
||||
|
||||
thread_num_ = readers.size();
|
||||
workers_.resize(thread_num_);
|
||||
|
||||
for (int i = 0; i < thread_num_; ++i) {
|
||||
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
|
||||
trainer_desc.device_worker_name());
|
||||
workers_[i]->SetDeviceIndex(i);
|
||||
workers_[i]->SetDataFeed(readers[i]);
|
||||
workers_[i]->Initialize(trainer_desc);
|
||||
}
|
||||
|
||||
VLOG(3) << "going to initialize pull dense worker";
|
||||
pull_dense_worker_ = PullDenseWorker::GetInstance();
|
||||
pull_dense_worker_->Initialize(trainer_desc);
|
||||
VLOG(3) << "initialize pull dense worker";
|
||||
SetDebug(trainer_desc.debug());
|
||||
}
|
||||
|
||||
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
|
||||
pull_dense_worker_->SetRootScope(root_scope_);
|
||||
pull_dense_worker_->Start();
|
||||
VLOG(3) << "init other env done.";
|
||||
}
|
||||
|
||||
void DistMultiTrainer::Run() {
|
||||
for (int thidx = 0; thidx < thread_num_; ++thidx) {
|
||||
if (!debug_) {
|
||||
threads_.push_back(
|
||||
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
|
||||
} else {
|
||||
threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler,
|
||||
workers_[thidx].get()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DistMultiTrainer::Finalize() {
|
||||
for (auto& th : threads_) {
|
||||
th.join();
|
||||
}
|
||||
pull_dense_worker_->Stop();
|
||||
dataset_ptr_->DestroyReaders();
|
||||
root_scope_->DropKids();
|
||||
}
|
||||
|
||||
} // end namespace framework
|
||||
} // end namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,5 @@
|
||||
if(WITH_PSLIB)
|
||||
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib)
|
||||
else()
|
||||
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
|
||||
endif(WITH_PSLIB)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue