refine executor_thread_worker.h and executor_thread_worker.cc code style

revert-15207-remove_op_handle_lock_and_fix_var
dongdaxiang 7 years ago
parent c4cb414291
commit c59cdf3a24

File diff suppressed because it is too large Load Diff

@ -35,21 +35,22 @@ const static uint32_t MAX_FEASIGN_NUM = 1000 * 100 * 100;
void CreateTensor(Variable* var, proto::VarType::Type var_type); void CreateTensor(Variable* var, proto::VarType::Type var_type);
struct AsyncWorkerParamConfig { struct AsyncWorkerParamConfig {
int slot_dim; int slot_dim;
int fea_dim; int fea_dim;
int32_t tmp_push_dense_wait_times; int32_t tmp_push_dense_wait_times;
int32_t tmp_push_sparse_wait_times; int32_t tmp_push_sparse_wait_times;
std::vector<std::string> skip_op; std::vector<std::string> skip_op;
std::map<uint64_t, std::vector<std::string>> dense_variable_name; std::map<uint64_t, std::vector<std::string>> dense_variable_name;
std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name; std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
std::vector<int> dense_table_id; std::vector<int> dense_table_id;
std::vector<uint32_t> dense_table_size; // fea_dim for each dense table // fea_dim for each dense table
std::vector<int> sparse_table_id; std::vector<uint32_t> dense_table_size;
std::map<uint64_t, std::vector<std::string>> slot_input_vec; //6048slot 6050slot //name std::vector<int> sparse_table_id;
std::map<uint64_t, std::vector<std::string>> gradient_var; //6048slot_embed std::map<uint64_t, std::vector<std::string>> slot_input_vec;
std::map<std::string, uint64_t> slot_alias_to_table; //TODO done std::map<uint64_t, std::vector<std::string>> gradient_var;
std::map<std::string, uint64_t> slot_alias_to_table;
}; };
struct DensePullThreadParam { struct DensePullThreadParam {
@ -62,8 +63,8 @@ struct DensePullThreadParam {
}; };
class DensePullThread { class DensePullThread {
public: public:
DensePullThread(DensePullThreadParam& param) : explicit DensePullThread(const DensePullThreadParam& param) :
_running(false) { _running(false) {
_ps_client = param.ps_client; _ps_client = param.ps_client;
_threshold = param.threshold; _threshold = param.threshold;
@ -96,11 +97,11 @@ public:
void pull_dense2(uint64_t table_id); void pull_dense2(uint64_t table_id);
void wait_all(); void wait_all();
private: private:
void run(); void run();
bool check_update_param(uint64_t table_id); bool check_update_param(uint64_t table_id);
private: private:
std::shared_ptr<paddle::ps::PSClient> _ps_client; std::shared_ptr<paddle::ps::PSClient> _ps_client;
int _thread_num; int _thread_num;
int _threshold; int _threshold;
@ -153,9 +154,13 @@ class ExecutorThreadWorker {
virtual void TrainFiles(); virtual void TrainFiles();
// set fetch variable names from python interface assigned by users // set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names); void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); virtual void SetPSlibPtr(
virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {}; std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
virtual void SetParamConfig(AsyncWorkerParamConfig* param_config) {}; virtual void SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {}
virtual void SetParamConfig(
AsyncWorkerParamConfig * param_config) {}
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
void CreateThreadOperators(const framework::ProgramDesc& program); void CreateThreadOperators(const framework::ProgramDesc& program);
@ -178,32 +183,37 @@ class ExecutorThreadWorker {
Scope* root_scope_; Scope* root_scope_;
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* thread_scope_; Scope* thread_scope_;
//private:
std::vector<std::string> fetch_var_names_; std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_; std::vector<std::vector<float>> fetch_values_;
bool debug_; bool debug_;
}; };
class AsyncExecutorThreadWorker: public ExecutorThreadWorker { class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
public: public:
AsyncExecutorThreadWorker(){}; AsyncExecutorThreadWorker() {}
virtual ~AsyncExecutorThreadWorker() {} virtual ~AsyncExecutorThreadWorker() {}
void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt); void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt);
void SetParamConfig(AsyncWorkerParamConfig* param_config); void SetParamConfig(AsyncWorkerParamConfig* param_config);
void TrainFiles(); void TrainFiles();
void TrainOneNetwork(); void TrainOneNetwork();
void PrepareParams(); void PrepareParams();
void UpdateParams(); void UpdateParams();
void PullSparse(int table_id); void PullSparse(int table_id);
void FillSparse(int table_id); void FillSparse(int table_id);
void PushSparse(int table_id); void PushSparse(int table_id);
void PushDense(int table_id); void PushDense(int table_id);
void check_pull_push_memory(std::vector<uint64_t>& features, std::vector<float*>& push_g, int dim); void check_pull_push_memory(
void check_pull_push_memory(std::vector<uint64_t>& features, std::vector<std::vector<float>>& push_g, int dim); const std::vector<uint64_t>& features,
std::vector<float*>& push_g,
int dim);
void check_pull_push_memory(const std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g,
int dim);
void collect_feasign_info(int table_id); void collect_feasign_info(int table_id);
private:
private:
struct FeasignInfo { struct FeasignInfo {
uint32_t slot; uint32_t slot;
uint32_t ins; uint32_t ins;

Loading…
Cancel
Save