add service (#29560)
	
		
	
				
					
				
			* add service, remove ut on mac * fix heter_profiler & add heter stop method * fix code stylerevert-31562-mean
							parent
							
								
									c0163837a5
								
							
						
					
					
						commit
						0034273b7e
					
				
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,246 @@
 | 
				
			|||||||
 | 
					/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <atomic>
 | 
				
			||||||
 | 
					#include <ctime>
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <random>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <ThreadPool.h>
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/communicator_common.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/service.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/archive.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/io/fs.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/io/shell.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/program_desc.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/scope.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/tensor.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/variable_helper.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using framework::LoDTensor;
 | 
				
			||||||
 | 
					using framework::Scope;
 | 
				
			||||||
 | 
					using framework::SelectedRows;
 | 
				
			||||||
 | 
					using framework::Variable;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using RpcCtxMap = std::unordered_map<std::string, CommContext>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FleetWrapper {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  virtual ~FleetWrapper() {}
 | 
				
			||||||
 | 
					  FleetWrapper() {
 | 
				
			||||||
 | 
					    scale_sparse_gradient_with_batch_size_ = true;
 | 
				
			||||||
 | 
					    // trainer sleep some time for pserver core dump
 | 
				
			||||||
 | 
					    sleep_seconds_before_fail_exit_ = 300;
 | 
				
			||||||
 | 
					    // pserver request server timeout ms
 | 
				
			||||||
 | 
					    client2client_request_timeout_ms_ = 500000;
 | 
				
			||||||
 | 
					    // pserver connect server timeout_ms
 | 
				
			||||||
 | 
					    client2client_connect_timeout_ms_ = 10000;
 | 
				
			||||||
 | 
					    // pserver request max retry
 | 
				
			||||||
 | 
					    client2client_max_retry_ = 3;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // set client to client communication config
 | 
				
			||||||
 | 
					  void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
 | 
				
			||||||
 | 
					                              int max_retry);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Pull sparse variables from server in sync mode
 | 
				
			||||||
 | 
					  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
 | 
				
			||||||
 | 
					  // Param<out>: fea_values
 | 
				
			||||||
 | 
					  void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					                          const std::vector<std::string>& var_names,
 | 
				
			||||||
 | 
					                          std::vector<uint64_t>* fea_keys,
 | 
				
			||||||
 | 
					                          std::vector<std::vector<float>>* fea_values,
 | 
				
			||||||
 | 
					                          int fea_dim,
 | 
				
			||||||
 | 
					                          const std::vector<std::string>& var_emb_names);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Pull sparse variables from server in async mode
 | 
				
			||||||
 | 
					  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim
 | 
				
			||||||
 | 
					  // Param<out>: fea_values std::future
 | 
				
			||||||
 | 
					  std::future<int32_t> PullSparseVarsAsync(
 | 
				
			||||||
 | 
					      const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					      const std::vector<std::string>& var_names,
 | 
				
			||||||
 | 
					      std::vector<uint64_t>* fea_keys,
 | 
				
			||||||
 | 
					      std::vector<std::vector<float>>* fea_values, int fea_dim);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Pull sparse variables from server in sync mode
 | 
				
			||||||
 | 
					  // pull immediately to tensors
 | 
				
			||||||
 | 
					  void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
 | 
				
			||||||
 | 
					                              uint64_t padding_id, platform::Place place,
 | 
				
			||||||
 | 
					                              std::vector<const LoDTensor*>* inputs,  // NOLINT
 | 
				
			||||||
 | 
					                              std::vector<LoDTensor*>* outputs);      // NOLINT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // pull dense variables from server in sync mod
 | 
				
			||||||
 | 
					  // Param<in>: scope, table_id, var_names
 | 
				
			||||||
 | 
					  // Param<out>: void
 | 
				
			||||||
 | 
					  void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					                         const std::vector<std::string>& var_names);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // pull dense variables from server in async mod
 | 
				
			||||||
 | 
					  // Param<in>: scope, table_id, var_names
 | 
				
			||||||
 | 
					  // Param<out>: pull_dense_status
 | 
				
			||||||
 | 
					  void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					                          const std::vector<std::string>& var_names,
 | 
				
			||||||
 | 
					                          std::vector<std::future<int32_t>>* pull_dense_status,
 | 
				
			||||||
 | 
					                          bool in_cpu);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // push dense parameters(not gradients) to server in sync mode
 | 
				
			||||||
 | 
					  void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					                          const std::vector<std::string>& var_names);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					                          const std::vector<std::string>& var_names,
 | 
				
			||||||
 | 
					                          std::vector<std::future<int32_t>>* push_sparse_status,
 | 
				
			||||||
 | 
					                          float scale_datanorm, int batch_size);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // push dense variables to server in sync mode
 | 
				
			||||||
 | 
					  void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					                         const std::vector<std::string>& var_names);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void PushSparseVarsAsync(
 | 
				
			||||||
 | 
					      const Scope& scope, const uint64_t table_id, const std::string& grad,
 | 
				
			||||||
 | 
					      std::vector<std::future<int32_t>>* push_sparse_status);
 | 
				
			||||||
 | 
					  // This is specially designed for click/show stats in server
 | 
				
			||||||
 | 
					  // Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
 | 
				
			||||||
 | 
					  //            sparse_grad_names, batch_size, use_cvm, dump_slot
 | 
				
			||||||
 | 
					  // Param<out>: push_values, push_sparse_status
 | 
				
			||||||
 | 
					  void PushSparseVarsWithLabelAsync(
 | 
				
			||||||
 | 
					      const Scope& scope, const uint64_t table_id,
 | 
				
			||||||
 | 
					      const std::vector<uint64_t>& fea_keys,
 | 
				
			||||||
 | 
					      const std::vector<float>& fea_labels,
 | 
				
			||||||
 | 
					      const std::vector<std::string>& sparse_key_names,
 | 
				
			||||||
 | 
					      const std::vector<std::string>& sparse_grad_names, const int emb_dim,
 | 
				
			||||||
 | 
					      std::vector<std::vector<float>>* push_values,
 | 
				
			||||||
 | 
					      std::vector<std::future<int32_t>>* push_sparse_status,
 | 
				
			||||||
 | 
					      const int batch_size, const bool use_cvm, const bool dump_slot,
 | 
				
			||||||
 | 
					      std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Push sparse variables to server in async mode
 | 
				
			||||||
 | 
					  void PushSparseFromTensorWithLabelAsync(
 | 
				
			||||||
 | 
					      const Scope& scope, const uint64_t table_id, int fea_dim,
 | 
				
			||||||
 | 
					      uint64_t padding_id, bool scale_sparse, const std::string& accesor,
 | 
				
			||||||
 | 
					      const std::string& click_name, platform::Place place,
 | 
				
			||||||
 | 
					      const std::vector<std::string>& input_names,
 | 
				
			||||||
 | 
					      std::vector<const LoDTensor*>* inputs,    // NOLINT
 | 
				
			||||||
 | 
					      std::vector<const LoDTensor*>* outputs);  // NOLINT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Push sparse variables to server in Async mode
 | 
				
			||||||
 | 
					  // Param<In>: scope, table_id, fea_keys, sparse_grad_names
 | 
				
			||||||
 | 
					  // Param<Out>: push_values, push_sparse_status
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // init server
 | 
				
			||||||
 | 
					  void LoadSparseOnServer(const std::string& path, const std::string& meta,
 | 
				
			||||||
 | 
					                          uint32_t table_id);
 | 
				
			||||||
 | 
					  // init server
 | 
				
			||||||
 | 
					  // void InitServer(const std::string& dist_desc,
 | 
				
			||||||
 | 
					  //                 const std::vector<uint64_t>& host_sign_list, int index);
 | 
				
			||||||
 | 
					  void InitServer(const std::string& dist_desc,
 | 
				
			||||||
 | 
					                  const std::vector<std::string>& host_sign_list, int index);
 | 
				
			||||||
 | 
					  // init trainer
 | 
				
			||||||
 | 
					  void InitWorker(const std::string& dist_desc,
 | 
				
			||||||
 | 
					                  const std::vector<std::string>& host_sign_list, Scope* scope,
 | 
				
			||||||
 | 
					                  const RpcCtxMap& send_ctx,
 | 
				
			||||||
 | 
					                  const std::unordered_map<uint64_t, std::vector<std::string>>&
 | 
				
			||||||
 | 
					                      dense_varnames,
 | 
				
			||||||
 | 
					                  const std::map<std::string, std::string>& envs, int node_num,
 | 
				
			||||||
 | 
					                  int index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // stop server
 | 
				
			||||||
 | 
					  void StopServer();
 | 
				
			||||||
 | 
					  // finalize worker to make worker can be stop
 | 
				
			||||||
 | 
					  void FinalizeWorker();
 | 
				
			||||||
 | 
					  // run server with ip port
 | 
				
			||||||
 | 
					  uint64_t RunServer(const std::string& ip, uint32_t port);
 | 
				
			||||||
 | 
					  // get client info
 | 
				
			||||||
 | 
					  std::vector<uint64_t> GetClientsInfo();
 | 
				
			||||||
 | 
					  // create client to client connection
 | 
				
			||||||
 | 
					  void CreateClient2ClientConnection();
 | 
				
			||||||
 | 
					  // flush all push requests
 | 
				
			||||||
 | 
					  void ClientFlush();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // barrier with barrier table
 | 
				
			||||||
 | 
					  void BarrierWithTable(uint32_t barrier_type);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void PrintTableStat(const uint64_t table_id);
 | 
				
			||||||
 | 
					  // mode = 0, load all feature
 | 
				
			||||||
 | 
					  // mode = 1, load delta feature, which means load diff
 | 
				
			||||||
 | 
					  void LoadModel(const std::string& path, const int mode);
 | 
				
			||||||
 | 
					  // mode = 0, load all feature
 | 
				
			||||||
 | 
					  // mode = 1, load delta feature, which means load diff
 | 
				
			||||||
 | 
					  void LoadModelOneTable(const uint64_t table_id, const std::string& path,
 | 
				
			||||||
 | 
					                         const int mode);
 | 
				
			||||||
 | 
					  // mode = 0, save all feature
 | 
				
			||||||
 | 
					  // mode = 1, save delta feature, which means save diff
 | 
				
			||||||
 | 
					  void SaveModel(const std::string& path, const int mode);
 | 
				
			||||||
 | 
					  // mode = 0, save all feature
 | 
				
			||||||
 | 
					  // mode = 1, save delta feature, which means save diff
 | 
				
			||||||
 | 
					  void SaveModelOneTable(const uint64_t table_id, const std::string& path,
 | 
				
			||||||
 | 
					                         const int mode);
 | 
				
			||||||
 | 
					  // clear all models, release their memory
 | 
				
			||||||
 | 
					  void ClearModel();
 | 
				
			||||||
 | 
					  // clear one table
 | 
				
			||||||
 | 
					  void ClearOneTable(const uint64_t table_id);
 | 
				
			||||||
 | 
					  // shrink sparse table
 | 
				
			||||||
 | 
					  void ShrinkSparseTable(int table_id);
 | 
				
			||||||
 | 
					  // shrink dense table
 | 
				
			||||||
 | 
					  void ShrinkDenseTable(int table_id, Scope* scope,
 | 
				
			||||||
 | 
					                        std::vector<std::string> var_list, float decay,
 | 
				
			||||||
 | 
					                        int emb_dim);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
 | 
				
			||||||
 | 
					  // register client to client communication
 | 
				
			||||||
 | 
					  int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
 | 
				
			||||||
 | 
					  // send client to client message
 | 
				
			||||||
 | 
					  std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
 | 
				
			||||||
 | 
					                                             const std::string& msg);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // FleetWrapper singleton
 | 
				
			||||||
 | 
					  static std::shared_ptr<FleetWrapper> GetInstance() {
 | 
				
			||||||
 | 
					    if (NULL == s_instance_) {
 | 
				
			||||||
 | 
					      s_instance_.reset(new paddle::distributed::FleetWrapper());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return s_instance_;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // this performs better than rand_r, especially large data
 | 
				
			||||||
 | 
					  std::default_random_engine& LocalRandomEngine();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  static std::shared_ptr<FleetWrapper> s_instance_;
 | 
				
			||||||
 | 
					  size_t GetAbsoluteSum(size_t start, size_t end, size_t level,
 | 
				
			||||||
 | 
					                        const framework::LoD& lod);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  static bool is_initialized_;
 | 
				
			||||||
 | 
					  std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions;
 | 
				
			||||||
 | 
					  bool scale_sparse_gradient_with_batch_size_;
 | 
				
			||||||
 | 
					  int32_t sleep_seconds_before_fail_exit_;
 | 
				
			||||||
 | 
					  int client2client_request_timeout_ms_;
 | 
				
			||||||
 | 
					  int client2client_connect_timeout_ms_;
 | 
				
			||||||
 | 
					  int client2client_max_retry_;
 | 
				
			||||||
 | 
					  DISABLE_COPY_AND_ASSIGN(FleetWrapper);
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // end namespace distributed
 | 
				
			||||||
 | 
					}  // end namespace paddle
 | 
				
			||||||
@ -0,0 +1,40 @@
 | 
				
			|||||||
 | 
					set(BRPC_SRCS ps_client.cc server.cc)
 | 
				
			||||||
 | 
					set_source_files_properties(${BRPC_SRCS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					brpc_library(sendrecv_rpc SRCS
 | 
				
			||||||
 | 
					        ${BRPC_SRCS}
 | 
				
			||||||
 | 
					        PROTO sendrecv.proto
 | 
				
			||||||
 | 
					        DEPS ${BRPC_DEPS} )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS})
 | 
				
			||||||
 | 
					cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
 | 
				
			||||||
 | 
					cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
 | 
				
			||||||
 | 
					cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
 | 
				
			||||||
 | 
					cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
 | 
				
			||||||
 | 
					cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
 | 
				
			||||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,212 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "brpc/channel.h"
 | 
				
			||||||
 | 
					#include "brpc/controller.h"
 | 
				
			||||||
 | 
					#include "brpc/server.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/ps_client.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DownpourPsClientService : public PsService {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  DownpourPsClientService() {}
 | 
				
			||||||
 | 
					  virtual ~DownpourPsClientService() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual int32_t configure(PSClient *client, size_t rank_id) {
 | 
				
			||||||
 | 
					    _client = client;
 | 
				
			||||||
 | 
					    _rank = rank_id;
 | 
				
			||||||
 | 
					    return 0;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  virtual void service(::google::protobuf::RpcController *controller,
 | 
				
			||||||
 | 
					                       const ::paddle::PsRequestMessage *request,
 | 
				
			||||||
 | 
					                       ::paddle::PsResponseMessage *response,
 | 
				
			||||||
 | 
					                       ::google::protobuf::Closure *done) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  size_t _rank;
 | 
				
			||||||
 | 
					  PSClient *_client;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DownpourBrpcClosure : public PSClientClosure {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
 | 
				
			||||||
 | 
					      : PSClientClosure(callback) {
 | 
				
			||||||
 | 
					    _waiting_num = num;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _cntls.resize(num);
 | 
				
			||||||
 | 
					    _requests.resize(num);
 | 
				
			||||||
 | 
					    _responses.resize(num);
 | 
				
			||||||
 | 
					    for (size_t i = 0; i < num; ++i) {
 | 
				
			||||||
 | 
					      _cntls[i].reset(new brpc::Controller());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  virtual ~DownpourBrpcClosure() {}
 | 
				
			||||||
 | 
					  virtual void Run() override {
 | 
				
			||||||
 | 
					    if (_waiting_num.fetch_sub(1) == 1) {
 | 
				
			||||||
 | 
					      _callback(this);
 | 
				
			||||||
 | 
					      delete this;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  PsRequestMessage *request(size_t i) { return &_requests[i]; }
 | 
				
			||||||
 | 
					  PsResponseMessage *response(size_t i) { return &_responses[i]; }
 | 
				
			||||||
 | 
					  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
 | 
				
			||||||
 | 
					  int check_response(size_t request_idx, int cmd_id);
 | 
				
			||||||
 | 
					  int check_save_response(size_t request_idx, int cmd_id);
 | 
				
			||||||
 | 
					  std::string get_response(size_t request_idx, int cmd_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  std::atomic<int32_t> _waiting_num;
 | 
				
			||||||
 | 
					  std::vector<PsRequestMessage> _requests;
 | 
				
			||||||
 | 
					  std::vector<PsResponseMessage> _responses;
 | 
				
			||||||
 | 
					  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <class T>
 | 
				
			||||||
 | 
					struct array_deleter {
 | 
				
			||||||
 | 
					  void operator()(T *&x) const { delete[] x; }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BrpcPsClient : public PSClient {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  BrpcPsClient() {}
 | 
				
			||||||
 | 
					  virtual ~BrpcPsClient() {
 | 
				
			||||||
 | 
					    // _running = false;
 | 
				
			||||||
 | 
					    // try {
 | 
				
			||||||
 | 
					    // _async_push_dense_thread.join();
 | 
				
			||||||
 | 
					    // _async_push_sparse_thread.join();
 | 
				
			||||||
 | 
					    //} catch (...) {
 | 
				
			||||||
 | 
					    //}
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  virtual int32_t create_client2client_connection(
 | 
				
			||||||
 | 
					      int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> shrink(uint32_t table_id) override;
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> load(const std::string &epoch,
 | 
				
			||||||
 | 
					                                    const std::string &mode) override;
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
 | 
				
			||||||
 | 
					                                    const std::string &mode) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> save(const std::string &epoch,
 | 
				
			||||||
 | 
					                                    const std::string &mode) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
 | 
				
			||||||
 | 
					                                    const std::string &mode) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> clear() override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> clear(uint32_t table_id) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> stop_server() override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> start_profiler() override;
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> stop_profiler() override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual void finalize_worker() override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
 | 
				
			||||||
 | 
					                                          size_t table_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> push_dense_param(const Region *regions,
 | 
				
			||||||
 | 
					                                                size_t region_num,
 | 
				
			||||||
 | 
					                                                size_t table_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> pull_sparse(float **select_values,
 | 
				
			||||||
 | 
					                                           size_t table_id,
 | 
				
			||||||
 | 
					                                           const uint64_t *keys, size_t num);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> print_table_stat(uint32_t table_id);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> pull_geo_param(size_t table_id,
 | 
				
			||||||
 | 
					                                              std::vector<float> *values,
 | 
				
			||||||
 | 
					                                              std::vector<uint64_t> *keys,
 | 
				
			||||||
 | 
					                                              int pserver_idx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> flush();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> send_client2client_msg(
 | 
				
			||||||
 | 
					      int msg_type, int to_client_id, const std::string &msg) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  virtual int32_t initialize() override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
 | 
				
			||||||
 | 
					                                      uint32_t shard_num) {
 | 
				
			||||||
 | 
					    return dense_dim_total / shard_num + 1;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
 | 
				
			||||||
 | 
					                                const std::vector<std::string> ¶m);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
 | 
				
			||||||
 | 
					                                     const std::vector<std::string> ¶m);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  inline brpc::Channel *get_sparse_channel(size_t server_id) {
 | 
				
			||||||
 | 
					    return _server_channels[server_id][0].get();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  inline brpc::Channel *get_dense_channel(size_t server_id) {
 | 
				
			||||||
 | 
					    return _server_channels[server_id][1].get();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  inline brpc::Channel *get_cmd_channel(size_t server_id) {
 | 
				
			||||||
 | 
					    return _server_channels[server_id][2].get();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool _running = false;
 | 
				
			||||||
 | 
					  bool _flushing = false;
 | 
				
			||||||
 | 
					  std::atomic<uint32_t> _async_call_num;  //异步请求计数
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<std::shared_ptr<brpc::Channel>>
 | 
				
			||||||
 | 
					      _client_channels;  // client2client
 | 
				
			||||||
 | 
					  std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
 | 
				
			||||||
 | 
					      _server_channels;  // client2server
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> push_dense_raw_gradient(
 | 
				
			||||||
 | 
					      int table_id, float *total_send_data, size_t total_send_data_size,
 | 
				
			||||||
 | 
					      void *done) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> push_sparse_raw_gradient(
 | 
				
			||||||
 | 
					      size_t table_id, const uint64_t *keys, const float **update_values,
 | 
				
			||||||
 | 
					      size_t num, void *done) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> push_sparse_raw_gradient_partial(
 | 
				
			||||||
 | 
					      size_t table_id, const uint64_t *keys, const float **update_values,
 | 
				
			||||||
 | 
					      uint32_t num, void *done, int pserver_idx) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual std::future<int32_t> push_sparse_param(size_t table_id,
 | 
				
			||||||
 | 
					                                                 const uint64_t *keys,
 | 
				
			||||||
 | 
					                                                 const float **update_values,
 | 
				
			||||||
 | 
					                                                 size_t num,
 | 
				
			||||||
 | 
					                                                 void *done) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual size_t get_server_nums() { return _server_channels.size(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  int32_t start_client_service();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  float _mae = 0;
 | 
				
			||||||
 | 
					  float _mse = 0;
 | 
				
			||||||
 | 
					  uint16_t _push_times = 0;
 | 
				
			||||||
 | 
					  brpc::Server _server;
 | 
				
			||||||
 | 
					  DownpourPsClientService _service;
 | 
				
			||||||
 | 
					  std::atomic_uint grad_num_{0};
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,153 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "brpc/channel.h"
 | 
				
			||||||
 | 
					#include "brpc/controller.h"
 | 
				
			||||||
 | 
					#include "brpc/server.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/server.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BrpcPsServer : public PSServer {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  BrpcPsServer() {}
 | 
				
			||||||
 | 
					  virtual ~BrpcPsServer() {}
 | 
				
			||||||
 | 
					  virtual uint64_t start(const std::string &ip, uint32_t port);
 | 
				
			||||||
 | 
					  virtual int32_t stop() {
 | 
				
			||||||
 | 
					    std::unique_lock<std::mutex> lock(mutex_);
 | 
				
			||||||
 | 
					    stoped_ = true;
 | 
				
			||||||
 | 
					    cv_.notify_all();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _server.Stop(1000);
 | 
				
			||||||
 | 
					    _server.Join();
 | 
				
			||||||
 | 
					    return 0;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  virtual int32_t port();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  virtual int32_t initialize();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mutable std::mutex mutex_;
 | 
				
			||||||
 | 
					  std::condition_variable cv_;
 | 
				
			||||||
 | 
					  bool stoped_ = false;
 | 
				
			||||||
 | 
					  brpc::Server _server;
 | 
				
			||||||
 | 
					  std::shared_ptr<PsBaseService> _service;
 | 
				
			||||||
 | 
					  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PsService;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typedef int32_t (PsService::*serviceHandlerFunc)(
 | 
				
			||||||
 | 
					    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
 | 
				
			||||||
 | 
					    brpc::Controller *cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PsService : public PsBaseService {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  virtual int32_t initialize() override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual void service(::google::protobuf::RpcController *controller,
 | 
				
			||||||
 | 
					                       const ::paddle::PsRequestMessage *request,
 | 
				
			||||||
 | 
					                       ::paddle::PsResponseMessage *response,
 | 
				
			||||||
 | 
					                       ::google::protobuf::Closure *done) override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  int32_t initialize_shard_info();
 | 
				
			||||||
 | 
					  int32_t pull_dense(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                     PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t push_dense(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                     PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t push_dense_param(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                           PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                            PsResponseMessage &response,
 | 
				
			||||||
 | 
					                            brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t pull_sparse(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                      PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                         PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t barrier(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                  PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t push_sparse(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                      PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t load_one_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                         PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t load_all_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                         PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t save_one_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                         PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t save_all_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                         PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t shrink_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                       PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t clear_one_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                          PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t clear_all_table(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                          PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t stop_server(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                      PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t start_profiler(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                         PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					  int32_t stop_profiler(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                        PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int32_t print_table_stat(Table *table, const PsRequestMessage &request,
 | 
				
			||||||
 | 
					                           PsResponseMessage &response, brpc::Controller *cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool _is_initialize_shard_info;
 | 
				
			||||||
 | 
					  std::mutex _initialize_shard_mutex;
 | 
				
			||||||
 | 
					  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
 | 
				
			||||||
 | 
					  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
 | 
				
			||||||
 | 
					  std::vector<float> _ori_values;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DownpourPServerBrpcClosure : public PServerClosure {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
 | 
				
			||||||
 | 
					      : PServerClosure(callback) {
 | 
				
			||||||
 | 
					    _waiting_num = num;
 | 
				
			||||||
 | 
					    _cntls.resize(num);
 | 
				
			||||||
 | 
					    _requests.resize(num);
 | 
				
			||||||
 | 
					    _responses.resize(num);
 | 
				
			||||||
 | 
					    for (size_t i = 0; i < num; ++i) {
 | 
				
			||||||
 | 
					      _cntls[i].reset(new brpc::Controller());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  virtual ~DownpourPServerBrpcClosure() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual void Run() override {
 | 
				
			||||||
 | 
					    if (_waiting_num.fetch_sub(1) == 1) {
 | 
				
			||||||
 | 
					      _callback(this);
 | 
				
			||||||
 | 
					      delete this;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  PsRequestMessage *request(size_t i) { return &_requests[i]; }
 | 
				
			||||||
 | 
					  PsResponseMessage *response(size_t i) { return &_responses[i]; }
 | 
				
			||||||
 | 
					  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
 | 
				
			||||||
 | 
					  int check_response(size_t request_idx, int cmd_id) { return 1; }
 | 
				
			||||||
 | 
					  int check_save_response(size_t request_idx, int cmd_id) { return 1; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  std::atomic<int32_t> _waiting_num;
 | 
				
			||||||
 | 
					  std::vector<PsRequestMessage> _requests;
 | 
				
			||||||
 | 
					  std::vector<PsResponseMessage> _responses;
 | 
				
			||||||
 | 
					  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <iostream>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "brpc/channel.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/sendrecv.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/data_type.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/lod_tensor.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/scope.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/selected_rows.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/tensor_util.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/var_type.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/port.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace grpc {
 | 
				
			||||||
 | 
					class ByteBuffer;
 | 
				
			||||||
 | 
					}  // namespace grpc
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace framework {
 | 
				
			||||||
 | 
					class Scope;
 | 
				
			||||||
 | 
					class Variable;
 | 
				
			||||||
 | 
					}  // namespace framework
 | 
				
			||||||
 | 
					namespace platform {
 | 
				
			||||||
 | 
					class DeviceContext;
 | 
				
			||||||
 | 
					}  // namespace platform
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using MultiVarMsg = ::paddle::MultiVariableMessage;
 | 
				
			||||||
 | 
					using VarMsg = ::paddle::VariableMessage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void SerializeToMultiVarMsgAndIOBuf(
 | 
				
			||||||
 | 
					    const std::string& message_name,
 | 
				
			||||||
 | 
					    const std::vector<std::string>& send_var_name_val,
 | 
				
			||||||
 | 
					    const std::vector<std::string>& recv_var_name_val,
 | 
				
			||||||
 | 
					    const platform::DeviceContext& ctx, const framework::Scope* scope,
 | 
				
			||||||
 | 
					    MultiVarMsg* var_msg, butil::IOBuf* iobuf);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void SerializeLodTensor(framework::Variable* var,
 | 
				
			||||||
 | 
					                        const platform::DeviceContext& ctx, VarMsg* var_msg,
 | 
				
			||||||
 | 
					                        butil::IOBuf* iobuf);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void SerializeSelectedRows(framework::Variable* var,
 | 
				
			||||||
 | 
					                           const platform::DeviceContext& ctx, VarMsg* request,
 | 
				
			||||||
 | 
					                           butil::IOBuf* iobuf);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Deserialize for Server
 | 
				
			||||||
 | 
					void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
 | 
				
			||||||
 | 
					                                        const butil::IOBuf* iobuf,
 | 
				
			||||||
 | 
					                                        const platform::DeviceContext& ctx,
 | 
				
			||||||
 | 
					                                        framework::Scope* scope);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Deserialize for Client
 | 
				
			||||||
 | 
					void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
 | 
				
			||||||
 | 
					                                        const butil::IOBuf* iobuf,
 | 
				
			||||||
 | 
					                                        const platform::DeviceContext& ctx,
 | 
				
			||||||
 | 
					                                        const framework::Scope* scope);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
 | 
				
			||||||
 | 
					                          butil::IOBufBytesIterator& iobuf,
 | 
				
			||||||
 | 
					                          const platform::DeviceContext& ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
 | 
				
			||||||
 | 
					                             butil::IOBufBytesIterator& iobuf,
 | 
				
			||||||
 | 
					                             const platform::DeviceContext& ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,19 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/env.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,168 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/heter_client.h"
 | 
				
			||||||
 | 
					#include <algorithm>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/channel.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/data_feed.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/device_worker.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/io/fs.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/scope.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/profiler.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/timer.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DECLARE_int32(rpc_deadline);
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
 | 
				
			||||||
 | 
					bool HeterClient::is_initialized_ = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterClient::MainThread() {
 | 
				
			||||||
 | 
					  while (running_) {
 | 
				
			||||||
 | 
					    RpcProfilerControl();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterClient::Stop() {
 | 
				
			||||||
 | 
					  running_ = false;
 | 
				
			||||||
 | 
					  if (!is_initialized_) {
 | 
				
			||||||
 | 
					    VLOG(0) << "HeterClient is not inited, do nothing";
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    if (main_thread_) {
 | 
				
			||||||
 | 
					      auto status = StopHeterWorker();
 | 
				
			||||||
 | 
					      status.wait();
 | 
				
			||||||
 | 
					      main_thread_->join();
 | 
				
			||||||
 | 
					      main_thread_.reset(nullptr);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    VLOG(1) << "HeterClient Stop Done";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterClient::RpcProfilerControl() {
 | 
				
			||||||
 | 
					  if (trainer_id_ == 0) {
 | 
				
			||||||
 | 
					    if (!do_server_profiler_ && platform::IsProfileEnabled()) {
 | 
				
			||||||
 | 
					      // send profiler start flag
 | 
				
			||||||
 | 
					      do_server_profiler_ = true;
 | 
				
			||||||
 | 
					      auto start_status = StartProfiler();
 | 
				
			||||||
 | 
					      start_status.wait();
 | 
				
			||||||
 | 
					    } else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
 | 
				
			||||||
 | 
					      // send profiler end flag
 | 
				
			||||||
 | 
					      auto stop_status = StopProfiler();
 | 
				
			||||||
 | 
					      stop_status.wait();
 | 
				
			||||||
 | 
					      do_server_profiler_ = false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterClient::CreateClient2XpuConnection() {
 | 
				
			||||||
 | 
					  brpc::ChannelOptions options;
 | 
				
			||||||
 | 
					  options.protocol = "baidu_std";
 | 
				
			||||||
 | 
					  options.connection_type = "single";
 | 
				
			||||||
 | 
					  options.timeout_ms = pserver_timeout_ms;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  xpu_channels_.resize(xpu_list_.size());
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < xpu_list_.size(); ++i) {
 | 
				
			||||||
 | 
					    xpu_channels_[i].reset(new brpc::Channel());
 | 
				
			||||||
 | 
					    if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
 | 
				
			||||||
 | 
					      VLOG(0) << "HeterServer channel init fail";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterClient::SendAndRecvAsync(
 | 
				
			||||||
 | 
					    const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
 | 
				
			||||||
 | 
					    const framework::Scope& scope, const std::string& message_name,
 | 
				
			||||||
 | 
					    const std::vector<std::string>& send_var_name,
 | 
				
			||||||
 | 
					    const std::vector<std::string>& recv_var_name) {
 | 
				
			||||||
 | 
					  platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
 | 
				
			||||||
 | 
					  const platform::DeviceContext* p_ctx = &ctx;
 | 
				
			||||||
 | 
					  const framework::Scope* p_scope = &scope;
 | 
				
			||||||
 | 
					  const std::string message_name_val = message_name;
 | 
				
			||||||
 | 
					  const std::vector<std::string> send_var_name_val = send_var_name;
 | 
				
			||||||
 | 
					  const std::vector<std::string> recv_var_name_val = recv_var_name;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
 | 
				
			||||||
 | 
					          << message_name_val;
 | 
				
			||||||
 | 
					  // Todo: get correct channel
 | 
				
			||||||
 | 
					  int num = trainer_id_ % xpu_channels_.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  brpc::Controller cntl;
 | 
				
			||||||
 | 
					  cntl.set_timeout_ms(pserver_timeout_ms);
 | 
				
			||||||
 | 
					  distributed::MultiVarMsg request, response;
 | 
				
			||||||
 | 
					  auto& request_io_buffer = cntl.request_attachment();
 | 
				
			||||||
 | 
					  ::paddle::PsService_Stub stub(xpu_channels_[num].get());
 | 
				
			||||||
 | 
					  distributed::SerializeToMultiVarMsgAndIOBuf(
 | 
				
			||||||
 | 
					      message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
 | 
				
			||||||
 | 
					      &request, &request_io_buffer);
 | 
				
			||||||
 | 
					  stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
 | 
				
			||||||
 | 
					  PADDLE_ENFORCE_NE(
 | 
				
			||||||
 | 
					      cntl.Failed(), true,
 | 
				
			||||||
 | 
					      platform::errors::Unimplemented(
 | 
				
			||||||
 | 
					          "HeterClient::SendAndRecv meets brpc error, error message is %s",
 | 
				
			||||||
 | 
					          cntl.ErrorText()));
 | 
				
			||||||
 | 
					  VLOG(4) << "call heter_worker success";
 | 
				
			||||||
 | 
					  auto& response_io_buffer = cntl.response_attachment();
 | 
				
			||||||
 | 
					  distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
 | 
				
			||||||
 | 
					                                                  ctx, p_scope);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::future<int32_t> HeterClient::SendCmd(
 | 
				
			||||||
 | 
					    uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
 | 
				
			||||||
 | 
					  size_t request_call_num = xpu_channels_.size();
 | 
				
			||||||
 | 
					  paddle::distributed::DownpourBrpcClosure* closure =
 | 
				
			||||||
 | 
					      new paddle::distributed::DownpourBrpcClosure(
 | 
				
			||||||
 | 
					          request_call_num, [request_call_num, cmd_id](void* done) {
 | 
				
			||||||
 | 
					            int ret = 0;
 | 
				
			||||||
 | 
					            auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
 | 
				
			||||||
 | 
					            for (size_t i = 0; i < request_call_num; ++i) {
 | 
				
			||||||
 | 
					              if (closure->check_response(i, cmd_id) != 0) {
 | 
				
			||||||
 | 
					                ret = -1;
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
 | 
					              }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            closure->set_promise_value(ret);
 | 
				
			||||||
 | 
					          });
 | 
				
			||||||
 | 
					  auto promise = std::make_shared<std::promise<int32_t>>();
 | 
				
			||||||
 | 
					  closure->add_promise(promise);
 | 
				
			||||||
 | 
					  std::future<int> fut = promise->get_future();
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < request_call_num; ++i) {
 | 
				
			||||||
 | 
					    closure->request(i)->set_cmd_id(cmd_id);
 | 
				
			||||||
 | 
					    closure->request(i)->set_table_id(table_id);
 | 
				
			||||||
 | 
					    closure->request(i)->set_client_id(trainer_id_);
 | 
				
			||||||
 | 
					    for (const auto& param : params) {
 | 
				
			||||||
 | 
					      closure->request(i)->add_params(param);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    ::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
 | 
				
			||||||
 | 
					    closure->cntl(i)->set_timeout_ms(
 | 
				
			||||||
 | 
					        pserver_timeout_ms);  // cmd msg don't limit timeout for save/load
 | 
				
			||||||
 | 
					    rpc_stub.service(closure->cntl(i), closure->request(i),
 | 
				
			||||||
 | 
					                     closure->response(i), closure);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return fut;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::future<int32_t> HeterClient::StartProfiler() {
 | 
				
			||||||
 | 
					  return SendCmd(-1, PS_START_PROFILER, {});
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::future<int32_t> HeterClient::StopProfiler() {
 | 
				
			||||||
 | 
					  return SendCmd(-1, PS_STOP_PROFILER, {});
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // end namespace distributed
 | 
				
			||||||
 | 
					}  // end namespace paddle
 | 
				
			||||||
@ -0,0 +1,127 @@
 | 
				
			|||||||
 | 
					/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					#include <atomic>
 | 
				
			||||||
 | 
					#include <ctime>
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <random>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include "brpc/channel.h"
 | 
				
			||||||
 | 
					#include "brpc/controller.h"
 | 
				
			||||||
 | 
					#include "brpc/server.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/brpc_ps_client.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/brpc_utils.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/sendrecv.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/scope.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/tensor.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/variable_helper.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using MultiVarMsg = ::paddle::MultiVariableMessage;
 | 
				
			||||||
 | 
					using VarMsg = ::paddle::VariableMessage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typedef std::function<void(void*)> HeterRpcCallbackFunc;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class OnHeterRpcDone : public google::protobuf::Closure {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
 | 
				
			||||||
 | 
					  virtual ~OnHeterRpcDone() {}
 | 
				
			||||||
 | 
					  void Run() {
 | 
				
			||||||
 | 
					    std::unique_ptr<OnHeterRpcDone> self_guard(this);
 | 
				
			||||||
 | 
					    handler_(this);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  HeterRpcCallbackFunc handler_;
 | 
				
			||||||
 | 
					  MultiVariableMessage response;
 | 
				
			||||||
 | 
					  brpc::Controller cntl;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeterClient {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  virtual ~HeterClient() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  HeterClient() {
 | 
				
			||||||
 | 
					    running_ = true;
 | 
				
			||||||
 | 
					    main_thread_.reset(
 | 
				
			||||||
 | 
					        new std::thread(std::bind(&HeterClient::MainThread, this)));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void CreateClient2XpuConnection();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SendAndRecvAsync(const std::vector<std::string>& ep,
 | 
				
			||||||
 | 
					                        const platform::DeviceContext& ctx,
 | 
				
			||||||
 | 
					                        const framework::Scope& scope,
 | 
				
			||||||
 | 
					                        const std::string& message_name,
 | 
				
			||||||
 | 
					                        const std::vector<std::string>& send_var_name,
 | 
				
			||||||
 | 
					                        const std::vector<std::string>& recv_var_name);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // HeterClient singleton
 | 
				
			||||||
 | 
					  static std::shared_ptr<HeterClient> GetInstance(
 | 
				
			||||||
 | 
					      const std::vector<std::string>& endpoint, const int& trainer_id) {
 | 
				
			||||||
 | 
					    if (NULL == s_instance_) {
 | 
				
			||||||
 | 
					      is_initialized_ = true;
 | 
				
			||||||
 | 
					      s_instance_.reset(new paddle::distributed::HeterClient());
 | 
				
			||||||
 | 
					      std::vector<std::string> xpu_list = {endpoint};
 | 
				
			||||||
 | 
					      s_instance_->SetXpuList(endpoint);
 | 
				
			||||||
 | 
					      s_instance_->SetTrainerID(trainer_id);
 | 
				
			||||||
 | 
					      s_instance_->CreateClient2XpuConnection();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return s_instance_;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void Stop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void MainThread();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void RpcProfilerControl();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::future<int32_t> SendCmd(uint32_t table_id, int cmd_id,
 | 
				
			||||||
 | 
					                               const std::vector<std::string>& params);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::future<int32_t> StartProfiler();
 | 
				
			||||||
 | 
					  std::future<int32_t> StopProfiler();
 | 
				
			||||||
 | 
					  std::future<int32_t> StopHeterWorker();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<std::string>& GetXpuList() { return xpu_list_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetXpuList(const std::vector<std::string>& xpu_list) {
 | 
				
			||||||
 | 
					    xpu_list_ = xpu_list;
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  static std::shared_ptr<HeterClient> s_instance_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  static bool is_initialized_;
 | 
				
			||||||
 | 
					  std::unique_ptr<std::thread> main_thread_{nullptr};
 | 
				
			||||||
 | 
					  std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
 | 
				
			||||||
 | 
					  DISABLE_COPY_AND_ASSIGN(HeterClient);
 | 
				
			||||||
 | 
					  std::vector<std::string> xpu_list_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool running_ = false;
 | 
				
			||||||
 | 
					  int trainer_id_;
 | 
				
			||||||
 | 
					  bool do_server_profiler_ = false;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // end namespace distributed
 | 
				
			||||||
 | 
					}  // end namespace paddle
 | 
				
			||||||
@ -0,0 +1,91 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/heter_server.h"
 | 
				
			||||||
 | 
					#include <algorithm>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/fleet/heter_wrapper.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/scope.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/timer.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::shared_ptr<HeterServer> HeterServer::s_instance_ = NULL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterServer::RegisterServiceHandler(std::string message_name,
 | 
				
			||||||
 | 
					                                         HeterServiceHandler func) {
 | 
				
			||||||
 | 
					  service_.RegisterServiceHandler(message_name, func);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterServer::StartHeterService() {
 | 
				
			||||||
 | 
					  server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
 | 
				
			||||||
 | 
					  brpc::ServerOptions options;
 | 
				
			||||||
 | 
					  if (server_.Start(endpoint_.c_str(), &options) != 0) {
 | 
				
			||||||
 | 
					    VLOG(0) << "heter server start fail";
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    VLOG(0) << "heter server start success! listen on " << endpoint_;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    std::lock_guard<std::mutex> lock(this->mutex_ready_);
 | 
				
			||||||
 | 
					    ready_ = 1;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  condition_ready_.notify_all();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  server_.Join();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterServer::SetEndPoint(std::string& endpoint) {
 | 
				
			||||||
 | 
					  endpoint_ = endpoint;
 | 
				
			||||||
 | 
					  service_.SetEndpoint(endpoint);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void HeterServer::WaitServerReady() {
 | 
				
			||||||
 | 
					  std::unique_lock<std::mutex> lock(this->mutex_ready_);
 | 
				
			||||||
 | 
					  condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int32_t HeterService::stop_profiler(const PsRequestMessage& request,
 | 
				
			||||||
 | 
					                                    PsResponseMessage& response,
 | 
				
			||||||
 | 
					                                    brpc::Controller* cntl) {
 | 
				
			||||||
 | 
					  platform::DisableProfiler(
 | 
				
			||||||
 | 
					      platform::EventSortingKey::kDefault,
 | 
				
			||||||
 | 
					      string::Sprintf("heter_worker_%s_profile", endpoint_));
 | 
				
			||||||
 | 
					  return 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int32_t HeterService::start_profiler(const PsRequestMessage& request,
 | 
				
			||||||
 | 
					                                     PsResponseMessage& response,
 | 
				
			||||||
 | 
					                                     brpc::Controller* cntl) {
 | 
				
			||||||
 | 
					  platform::EnableProfiler(platform::ProfilerState::kAll);
 | 
				
			||||||
 | 
					  return 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int32_t HeterService::stop_heter_worker(const PsRequestMessage& request,
 | 
				
			||||||
 | 
					                                        PsResponseMessage& response,
 | 
				
			||||||
 | 
					                                        brpc::Controller* cntl) {
 | 
				
			||||||
 | 
					  auto client_id = request.client_id();
 | 
				
			||||||
 | 
					  stop_cpu_worker_set_.insert(client_id);
 | 
				
			||||||
 | 
					  if (stop_cpu_worker_set_.size() == fan_in_) {
 | 
				
			||||||
 | 
					    is_exit_ = true;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // end namespace distributed
 | 
				
			||||||
 | 
					}  // end namespace paddle
 | 
				
			||||||
@ -0,0 +1,243 @@
 | 
				
			|||||||
 | 
					/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					#include <atomic>
 | 
				
			||||||
 | 
					#include <ctime>
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <random>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include "brpc/channel.h"
 | 
				
			||||||
 | 
					#include "brpc/controller.h"
 | 
				
			||||||
 | 
					#include "brpc/server.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/brpc_utils.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/sendrecv.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/executor.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/program_desc.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/scope.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/tensor.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/variable_helper.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
 | 
				
			||||||
 | 
					#include "paddle/fluid/platform/profiler.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using MultiVarMsg = ::paddle::MultiVariableMessage;
 | 
				
			||||||
 | 
					using VarMsg = ::paddle::VariableMessage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeterService;
 | 
				
			||||||
 | 
					typedef int32_t (HeterService::*serviceHandlerFunc)(
 | 
				
			||||||
 | 
					    const PsRequestMessage& request, PsResponseMessage& response,
 | 
				
			||||||
 | 
					    brpc::Controller* cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typedef std::function<void(void*)> HeterRpcCallbackFunc;
 | 
				
			||||||
 | 
					typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
 | 
				
			||||||
 | 
					    HeterServiceHandler;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeterService : public ::paddle::PsService {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  HeterService() {
 | 
				
			||||||
 | 
					    _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
 | 
				
			||||||
 | 
					    _service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler;
 | 
				
			||||||
 | 
					    _service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual ~HeterService() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual void service(::google::protobuf::RpcController* controller,
 | 
				
			||||||
 | 
					                       const ::paddle::PsRequestMessage* request,
 | 
				
			||||||
 | 
					                       ::paddle::PsResponseMessage* response,
 | 
				
			||||||
 | 
					                       ::google::protobuf::Closure* done) {
 | 
				
			||||||
 | 
					    brpc::ClosureGuard done_guard(done);
 | 
				
			||||||
 | 
					    std::string log_label("ReceiveCmd-");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response->set_err_code(0);
 | 
				
			||||||
 | 
					    response->set_err_msg("");
 | 
				
			||||||
 | 
					    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
 | 
				
			||||||
 | 
					    auto itr = _service_handler_map.find(request->cmd_id());
 | 
				
			||||||
 | 
					    if (itr == _service_handler_map.end()) {
 | 
				
			||||||
 | 
					      std::string err_msg(
 | 
				
			||||||
 | 
					          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
 | 
				
			||||||
 | 
					      err_msg.append(std::to_string(request->cmd_id()));
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    serviceHandlerFunc handler_func = itr->second;
 | 
				
			||||||
 | 
					    int service_ret = (this->*handler_func)(*request, *response, cntl);
 | 
				
			||||||
 | 
					    if (service_ret != 0) {
 | 
				
			||||||
 | 
					      response->set_err_code(service_ret);
 | 
				
			||||||
 | 
					      response->set_err_msg("server internal error");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SendAndRecvVariable(::google::protobuf::RpcController* controller,
 | 
				
			||||||
 | 
					                           const MultiVarMsg* request, MultiVarMsg* response,
 | 
				
			||||||
 | 
					                           ::google::protobuf::Closure* done) {
 | 
				
			||||||
 | 
					    brpc::ClosureGuard done_guard(done);
 | 
				
			||||||
 | 
					    std::string message_name = request->message_name();
 | 
				
			||||||
 | 
					    auto itr = handler_map_.find(message_name);
 | 
				
			||||||
 | 
					    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
 | 
				
			||||||
 | 
					    PADDLE_ENFORCE_NE(
 | 
				
			||||||
 | 
					        itr, handler_map_.end(),
 | 
				
			||||||
 | 
					        platform::errors::InvalidArgument(
 | 
				
			||||||
 | 
					            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
 | 
				
			||||||
 | 
					            "which is not in HeterService::handler_map_",
 | 
				
			||||||
 | 
					            message_name));
 | 
				
			||||||
 | 
					    itr->second(request, response, cntl);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void RegisterServiceHandler(std::string message_name,
 | 
				
			||||||
 | 
					                              HeterServiceHandler func) {
 | 
				
			||||||
 | 
					    handler_map_[message_name] = func;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
 | 
				
			||||||
 | 
					  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
 | 
				
			||||||
 | 
					  bool IsExit() { return is_exit_; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  int32_t stop_profiler(const PsRequestMessage& request,
 | 
				
			||||||
 | 
					                        PsResponseMessage& response, brpc::Controller* cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int32_t start_profiler(const PsRequestMessage& request,
 | 
				
			||||||
 | 
					                         PsResponseMessage& response, brpc::Controller* cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  int32_t stop_heter_worker(const PsRequestMessage& request,
 | 
				
			||||||
 | 
					                            PsResponseMessage& response,
 | 
				
			||||||
 | 
					                            brpc::Controller* cntl);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  std::string endpoint_;
 | 
				
			||||||
 | 
					  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
 | 
				
			||||||
 | 
					  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
 | 
				
			||||||
 | 
					  std::unordered_set<int> stop_cpu_worker_set_;
 | 
				
			||||||
 | 
					  int fan_in_;
 | 
				
			||||||
 | 
					  bool is_exit_ = false;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeterServer {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  virtual ~HeterServer() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void Stop() {
 | 
				
			||||||
 | 
					    server_.Stop(1000);
 | 
				
			||||||
 | 
					    server_.Join();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool IsExit() { return service_.IsExit(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  HeterServer() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void RegisterServiceHandler(std::string message_name,
 | 
				
			||||||
 | 
					                              HeterServiceHandler func);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void StartHeterService();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetEndPoint(std::string& endpoint);
 | 
				
			||||||
 | 
					  void SetFanin(int& fan_in);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // HeterWrapper singleton
 | 
				
			||||||
 | 
					  static std::shared_ptr<HeterServer> GetInstance() {
 | 
				
			||||||
 | 
					    if (NULL == s_instance_) {
 | 
				
			||||||
 | 
					      s_instance_.reset(new HeterServer());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return s_instance_;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void WaitServerReady();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  static std::shared_ptr<HeterServer> s_instance_;
 | 
				
			||||||
 | 
					  std::string endpoint_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  brpc::Server server_;
 | 
				
			||||||
 | 
					  HeterService service_;
 | 
				
			||||||
 | 
					  DISABLE_COPY_AND_ASSIGN(HeterServer);
 | 
				
			||||||
 | 
					  std::mutex mutex_ready_;
 | 
				
			||||||
 | 
					  std::condition_variable condition_ready_;
 | 
				
			||||||
 | 
					  int ready_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeterRequestHandler {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  HeterRequestHandler()
 | 
				
			||||||
 | 
					      : dev_ctx_(nullptr),
 | 
				
			||||||
 | 
					        executor_(nullptr),
 | 
				
			||||||
 | 
					        scope_(nullptr),
 | 
				
			||||||
 | 
					        program_(nullptr) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual ~HeterRequestHandler() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetScope(framework::Scope* scope) { scope_ = scope; }
 | 
				
			||||||
 | 
					  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
 | 
				
			||||||
 | 
					  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
 | 
				
			||||||
 | 
					  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  void SetGradToPreparedCtx(
 | 
				
			||||||
 | 
					      std::unordered_map<
 | 
				
			||||||
 | 
					          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
 | 
				
			||||||
 | 
					    message_to_prepared_ctx_ = g;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
 | 
				
			||||||
 | 
					                     brpc::Controller* cntl) = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  const platform::DeviceContext* dev_ctx_;
 | 
				
			||||||
 | 
					  framework::Executor* executor_;
 | 
				
			||||||
 | 
					  framework::Scope* scope_;
 | 
				
			||||||
 | 
					  framework::ProgramDesc* program_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::unordered_map<std::string,
 | 
				
			||||||
 | 
					                     std::shared_ptr<framework::ExecutorPrepareContext>>*
 | 
				
			||||||
 | 
					      message_to_prepared_ctx_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RequestSendAndRecvHandler final : public HeterRequestHandler {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  RequestSendAndRecvHandler() {}
 | 
				
			||||||
 | 
					  virtual ~RequestSendAndRecvHandler() {}
 | 
				
			||||||
 | 
					  int Handle(const MultiVarMsg* request, MultiVarMsg* response,
 | 
				
			||||||
 | 
					             brpc::Controller* cntl) override {
 | 
				
			||||||
 | 
					    platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
 | 
				
			||||||
 | 
					    auto& local_scope = scope_->NewScope();
 | 
				
			||||||
 | 
					    auto message_name = request->message_name();
 | 
				
			||||||
 | 
					    auto& request_io_buffer = cntl->request_attachment();
 | 
				
			||||||
 | 
					    distributed::DeserializeFromMultiVarMsgAndIOBuf(
 | 
				
			||||||
 | 
					        *request, &request_io_buffer, *dev_ctx_, &local_scope);
 | 
				
			||||||
 | 
					    executor_->RunPreparedContext(
 | 
				
			||||||
 | 
					        (*message_to_prepared_ctx_)[message_name].get(), &local_scope, false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto response_var_nums = request->recv_var_names_size();
 | 
				
			||||||
 | 
					    std::vector<std::string> response_var_names(response_var_nums),
 | 
				
			||||||
 | 
					        empty_var_names{};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
 | 
				
			||||||
 | 
					      response_var_names[var_idx] = request->recv_var_names(var_idx);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    auto& response_io_buffer = cntl->response_attachment();
 | 
				
			||||||
 | 
					    distributed::SerializeToMultiVarMsgAndIOBuf(
 | 
				
			||||||
 | 
					        message_name, response_var_names, empty_var_names, *dev_ctx_,
 | 
				
			||||||
 | 
					        &local_scope, response, &response_io_buffer);
 | 
				
			||||||
 | 
					    scope_->DeleteScope(&local_scope);
 | 
				
			||||||
 | 
					    return 0;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // end namespace distributed
 | 
				
			||||||
 | 
					}  // end namespace paddle
 | 
				
			||||||
@ -0,0 +1,89 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/ps_client.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "brpc/server.h"
 | 
				
			||||||
 | 
					#include "glog/logging.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/brpc_ps_client.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/table/table.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					REGISTER_CLASS(PSClient, BrpcPsClient);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int32_t PSClient::configure(
 | 
				
			||||||
 | 
					    const PSParameter &config,
 | 
				
			||||||
 | 
					    const std::map<uint64_t, std::vector<paddle::distributed::Region>> ®ions,
 | 
				
			||||||
 | 
					    PSEnvironment &env, size_t client_id) {
 | 
				
			||||||
 | 
					  _env = &env;
 | 
				
			||||||
 | 
					  _config = config;
 | 
				
			||||||
 | 
					  _dense_pull_regions = regions;
 | 
				
			||||||
 | 
					  _client_id = client_id;
 | 
				
			||||||
 | 
					  _config.mutable_worker_param()
 | 
				
			||||||
 | 
					      ->mutable_downpour_worker_param()
 | 
				
			||||||
 | 
					      ->mutable_downpour_table_param()
 | 
				
			||||||
 | 
					      ->CopyFrom(_config.server_param()
 | 
				
			||||||
 | 
					                     .downpour_server_param()
 | 
				
			||||||
 | 
					                     .downpour_table_param());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const auto &work_param = _config.worker_param().downpour_worker_param();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
 | 
				
			||||||
 | 
					    auto *accessor = CREATE_CLASS(
 | 
				
			||||||
 | 
					        ValueAccessor,
 | 
				
			||||||
 | 
					        work_param.downpour_table_param(i).accessor().accessor_class());
 | 
				
			||||||
 | 
					    accessor->configure(work_param.downpour_table_param(i).accessor());
 | 
				
			||||||
 | 
					    accessor->initialize();
 | 
				
			||||||
 | 
					    _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
 | 
				
			||||||
 | 
					        accessor);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return initialize();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PSClient *PSClientFactory::create(const PSParameter &ps_config) {
 | 
				
			||||||
 | 
					  const auto &config = ps_config.server_param();
 | 
				
			||||||
 | 
					  if (!config.has_downpour_server_param()) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!config.downpour_server_param().has_service_param()) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!config.downpour_server_param().service_param().has_client_class()) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "miss client_class in "
 | 
				
			||||||
 | 
					                  "ServerParameter.downpour_server_param.service_param";
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const auto &service_param = config.downpour_server_param().service_param();
 | 
				
			||||||
 | 
					  PSClient *client = CREATE_CLASS(PSClient, service_param.client_class());
 | 
				
			||||||
 | 
					  if (client == NULL) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "client is not registered, server_name:"
 | 
				
			||||||
 | 
					               << service_param.client_class();
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TableManager::instance().initialize();
 | 
				
			||||||
 | 
					  LOG(INFO) << "Create PSClient[" << service_param.client_class()
 | 
				
			||||||
 | 
					            << "] success";
 | 
				
			||||||
 | 
					  return client;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
@ -0,0 +1,113 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					syntax = "proto2";
 | 
				
			||||||
 | 
					package paddle;
 | 
				
			||||||
 | 
					option cc_generic_services = true;
 | 
				
			||||||
 | 
					option cc_enable_arenas = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					enum PsCmdID {
 | 
				
			||||||
 | 
					  PS_PULL_DENSE_TABLE = 0;
 | 
				
			||||||
 | 
					  PS_PUSH_DENSE_TABLE = 1;
 | 
				
			||||||
 | 
					  PS_PULL_SPARSE_TABLE = 2;
 | 
				
			||||||
 | 
					  PS_PUSH_SPARSE_TABLE = 3;
 | 
				
			||||||
 | 
					  PS_SHRINK_TABLE = 4;
 | 
				
			||||||
 | 
					  PS_SAVE_ONE_TABLE = 5;
 | 
				
			||||||
 | 
					  PS_SAVE_ALL_TABLE = 6;
 | 
				
			||||||
 | 
					  PS_LOAD_ONE_TABLE = 7;
 | 
				
			||||||
 | 
					  PS_LOAD_ALL_TABLE = 8;
 | 
				
			||||||
 | 
					  PS_CLEAR_ONE_TABLE = 9;
 | 
				
			||||||
 | 
					  PS_CLEAR_ALL_TABLE = 10;
 | 
				
			||||||
 | 
					  PS_PUSH_DENSE_PARAM = 11;
 | 
				
			||||||
 | 
					  PS_STOP_SERVER = 12;
 | 
				
			||||||
 | 
					  PS_SAVE_ONE_CACHE_TABLE = 13;
 | 
				
			||||||
 | 
					  PS_GET_CACHE_THRESHOLD = 14;
 | 
				
			||||||
 | 
					  PS_CACHE_SHUFFLE = 15;
 | 
				
			||||||
 | 
					  PS_COPY_TABLE = 16;
 | 
				
			||||||
 | 
					  PS_COPY_TABLE_BY_FEASIGN = 17;
 | 
				
			||||||
 | 
					  PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
 | 
				
			||||||
 | 
					  PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
 | 
				
			||||||
 | 
					  PS_PRINT_TABLE_STAT = 20;
 | 
				
			||||||
 | 
					  PS_SAVE_ONE_TABLE_PREFIX = 21;
 | 
				
			||||||
 | 
					  PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
 | 
				
			||||||
 | 
					  PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
 | 
				
			||||||
 | 
					  PS_PULL_GEO_PARAM = 24;
 | 
				
			||||||
 | 
					  PS_BARRIER = 25;
 | 
				
			||||||
 | 
					  PS_PUSH_SPARSE_PARAM = 26;
 | 
				
			||||||
 | 
					  PS_START_PROFILER = 27;
 | 
				
			||||||
 | 
					  PS_STOP_PROFILER = 28;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message PsRequestMessage {
 | 
				
			||||||
 | 
					  required uint32 cmd_id = 1;
 | 
				
			||||||
 | 
					  optional uint32 table_id = 2;
 | 
				
			||||||
 | 
					  repeated bytes params = 3;
 | 
				
			||||||
 | 
					  optional int32 client_id = 4;
 | 
				
			||||||
 | 
					  optional bytes data = 5;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message PsResponseMessage {
 | 
				
			||||||
 | 
					  required int32 err_code = 1 [ default = 0 ];
 | 
				
			||||||
 | 
					  required string err_msg = 2 [ default = "" ];
 | 
				
			||||||
 | 
					  optional bytes data = 3;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					enum VarType {
 | 
				
			||||||
 | 
					  LOD_TENSOR = 0;
 | 
				
			||||||
 | 
					  SELECTED_ROWS = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message VariableMessage {
 | 
				
			||||||
 | 
					  enum Type {
 | 
				
			||||||
 | 
					    // Pod Types
 | 
				
			||||||
 | 
					    BOOL = 0;
 | 
				
			||||||
 | 
					    INT16 = 1;
 | 
				
			||||||
 | 
					    INT32 = 2;
 | 
				
			||||||
 | 
					    INT64 = 3;
 | 
				
			||||||
 | 
					    FP16 = 4;
 | 
				
			||||||
 | 
					    FP32 = 5;
 | 
				
			||||||
 | 
					    FP64 = 6;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  message LodData { repeated int64 lod_data = 1; }
 | 
				
			||||||
 | 
					  optional string varname = 1;
 | 
				
			||||||
 | 
					  // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
 | 
				
			||||||
 | 
					  optional VarType type = 2;
 | 
				
			||||||
 | 
					  // bool persistable is not needed for sending.
 | 
				
			||||||
 | 
					  // tensor info:
 | 
				
			||||||
 | 
					  optional Type data_type = 3;
 | 
				
			||||||
 | 
					  repeated int64 dims = 4;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // lod details:
 | 
				
			||||||
 | 
					  optional int64 lod_level = 5;
 | 
				
			||||||
 | 
					  repeated LodData lod = 6;
 | 
				
			||||||
 | 
					  // selected_rows height, aka. original dim0
 | 
				
			||||||
 | 
					  optional int64 slr_height = 7;
 | 
				
			||||||
 | 
					  // tensor data
 | 
				
			||||||
 | 
					  optional bytes data = 8;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// for SendAndRecv RPC method
 | 
				
			||||||
 | 
					message MultiVariableMessage {
 | 
				
			||||||
 | 
					  // message flags
 | 
				
			||||||
 | 
					  required string message_name = 1;
 | 
				
			||||||
 | 
					  repeated string send_var_names = 2;
 | 
				
			||||||
 | 
					  repeated string recv_var_names = 3;
 | 
				
			||||||
 | 
					  repeated VariableMessage var_messages = 4;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					service PsService {
 | 
				
			||||||
 | 
					  rpc service(PsRequestMessage) returns (PsResponseMessage);
 | 
				
			||||||
 | 
					  rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
@ -0,0 +1,87 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/server.h"
 | 
				
			||||||
 | 
					#include "glog/logging.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/brpc_ps_server.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/table/table.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					REGISTER_CLASS(PSServer, BrpcPsServer);
 | 
				
			||||||
 | 
					REGISTER_CLASS(PsBaseService, PsService);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PSServer *PSServerFactory::create(const PSParameter &ps_config) {
 | 
				
			||||||
 | 
					  const auto &config = ps_config.server_param();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!config.has_downpour_server_param()) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!config.downpour_server_param().has_service_param()) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!config.downpour_server_param().service_param().has_server_class()) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "miss server_class in "
 | 
				
			||||||
 | 
					                  "ServerParameter.downpour_server_param.service_param";
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const auto &service_param = config.downpour_server_param().service_param();
 | 
				
			||||||
 | 
					  PSServer *server = CREATE_CLASS(PSServer, service_param.server_class());
 | 
				
			||||||
 | 
					  if (server == NULL) {
 | 
				
			||||||
 | 
					    LOG(ERROR) << "server is not registered, server_name:"
 | 
				
			||||||
 | 
					               << service_param.server_class();
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  TableManager::instance().initialize();
 | 
				
			||||||
 | 
					  return server;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env,
 | 
				
			||||||
 | 
					                            size_t server_rank) {
 | 
				
			||||||
 | 
					  _config = config.server_param();
 | 
				
			||||||
 | 
					  _rank = server_rank;
 | 
				
			||||||
 | 
					  _environment = &env;
 | 
				
			||||||
 | 
					  _shuffled_ins =
 | 
				
			||||||
 | 
					      paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
 | 
				
			||||||
 | 
					  const auto &downpour_param = _config.downpour_server_param();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  uint32_t barrier_table = UINT32_MAX;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
 | 
				
			||||||
 | 
					    auto *table = CREATE_CLASS(
 | 
				
			||||||
 | 
					        Table, downpour_param.downpour_table_param(i).table_class());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (downpour_param.downpour_table_param(i).table_class() ==
 | 
				
			||||||
 | 
					        "BarrierTable") {
 | 
				
			||||||
 | 
					      barrier_table = downpour_param.downpour_table_param(i).table_id();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    table->initialize(downpour_param.downpour_table_param(i),
 | 
				
			||||||
 | 
					                      config.fs_client_param());
 | 
				
			||||||
 | 
					    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (barrier_table != UINT32_MAX) {
 | 
				
			||||||
 | 
					    _table_map[barrier_table]->set_table_map(&_table_map);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return initialize();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
@ -0,0 +1,150 @@
 | 
				
			|||||||
 | 
					// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <future>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <unordered_map>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					#include "butil/endpoint.h"
 | 
				
			||||||
 | 
					#include "google/protobuf/service.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/common/registerer.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/ps.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/env.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/sendrecv.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/framework/channel.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Table;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PSServer {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  PSServer() {}
 | 
				
			||||||
 | 
					  virtual ~PSServer() {}
 | 
				
			||||||
 | 
					  PSServer(PSServer &&) = delete;
 | 
				
			||||||
 | 
					  PSServer(const PSServer &) = delete;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual int32_t configure(const PSParameter &config, PSEnvironment &env,
 | 
				
			||||||
 | 
					                            size_t server_rank) final;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // return server_ip
 | 
				
			||||||
 | 
					  virtual std::string ip() { return butil::my_ip_cstr(); }
 | 
				
			||||||
 | 
					  // return server_port
 | 
				
			||||||
 | 
					  virtual int32_t port() = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
 | 
				
			||||||
 | 
					  virtual int32_t stop() = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  inline size_t rank() const { return _rank; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  inline PSEnvironment *environment() { return _environment; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  inline const ServerParameter *config() const { return &_config; }
 | 
				
			||||||
 | 
					  inline Table *table(size_t table_id) {
 | 
				
			||||||
 | 
					    auto itr = _table_map.find(table_id);
 | 
				
			||||||
 | 
					    if (itr != _table_map.end()) {
 | 
				
			||||||
 | 
					      return itr->second.get();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return NULL;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *table() {
 | 
				
			||||||
 | 
					    return &_table_map;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
 | 
				
			||||||
 | 
					  virtual int registe_pserver2pserver_msg_handler(int msg_type,
 | 
				
			||||||
 | 
					                                                  MsgHandlerFunc handler) {
 | 
				
			||||||
 | 
					    _msg_handler_map[msg_type] = handler;
 | 
				
			||||||
 | 
					    return 0;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  virtual int32_t initialize() = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  size_t _rank;
 | 
				
			||||||
 | 
					  ServerParameter _config;
 | 
				
			||||||
 | 
					  PSEnvironment *_environment;
 | 
				
			||||||
 | 
					  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
 | 
				
			||||||
 | 
					  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					REGISTER_REGISTERER(PSServer);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					typedef std::function<void(void *)> PServerCallBack;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PServerClosure : public google::protobuf::Closure {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  PServerClosure(PServerCallBack callback) : _callback(callback) {}
 | 
				
			||||||
 | 
					  virtual ~PServerClosure() {}
 | 
				
			||||||
 | 
					  virtual void set_promise_value(int value) {
 | 
				
			||||||
 | 
					    for (auto &promise : _promises) {
 | 
				
			||||||
 | 
					      promise->set_value(value);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
 | 
				
			||||||
 | 
					    _promises.push_back(promise);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  PServerCallBack _callback;
 | 
				
			||||||
 | 
					  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PsBaseService : public PsService {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
 | 
				
			||||||
 | 
					  virtual ~PsBaseService() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual int32_t configure(PSServer *server) {
 | 
				
			||||||
 | 
					    _server = server;
 | 
				
			||||||
 | 
					    _rank = _server->rank();
 | 
				
			||||||
 | 
					    _config = _server->config();
 | 
				
			||||||
 | 
					    return 0;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  virtual void service(::google::protobuf::RpcController *controller,
 | 
				
			||||||
 | 
					                       const ::paddle::PsRequestMessage *request,
 | 
				
			||||||
 | 
					                       ::paddle::PsResponseMessage *response,
 | 
				
			||||||
 | 
					                       ::google::protobuf::Closure *done) override = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual void set_response_code(PsResponseMessage &response, int err_code,
 | 
				
			||||||
 | 
					                                 const char *err_msg) {
 | 
				
			||||||
 | 
					    response.set_err_msg(err_msg);
 | 
				
			||||||
 | 
					    response.set_err_code(err_code);
 | 
				
			||||||
 | 
					    LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual int32_t initialize() = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  size_t _rank;
 | 
				
			||||||
 | 
					  PSServer *_server;
 | 
				
			||||||
 | 
					  const ServerParameter *_config;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					REGISTER_REGISTERER(PsBaseService);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PSServerFactory {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  static PSServer *create(const PSParameter &config);
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
@ -0,0 +1,129 @@
 | 
				
			|||||||
 | 
					/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					   you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					   You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					   distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					   See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					   limitations under the License. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/service.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <fcntl.h>
 | 
				
			||||||
 | 
					#include <google/protobuf/io/zero_copy_stream_impl.h>
 | 
				
			||||||
 | 
					#include <google/protobuf/text_format.h>
 | 
				
			||||||
 | 
					#include <iostream>
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/communicator.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/string/string_helper.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace std;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					paddle::distributed::PSParameter load_from_prototxt(
 | 
				
			||||||
 | 
					    const std::string& filename) {
 | 
				
			||||||
 | 
					  paddle::distributed::PSParameter param;
 | 
				
			||||||
 | 
					  int file_descriptor = open(filename.c_str(), O_RDONLY);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (file_descriptor == -1) {
 | 
				
			||||||
 | 
					    VLOG(3) << "FATAL: fail to parse " << filename;
 | 
				
			||||||
 | 
					    exit(-1);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  google::protobuf::io::FileInputStream fileInput(file_descriptor);
 | 
				
			||||||
 | 
					  if (!google::protobuf::TextFormat::Parse(&fileInput, ¶m)) {
 | 
				
			||||||
 | 
					    VLOG(3) << "FATAL: fail to parse " << filename;
 | 
				
			||||||
 | 
					    exit(-1);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  close(file_descriptor);
 | 
				
			||||||
 | 
					  return param;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void PSCore::init_gflag(const std::string& gflags) {
 | 
				
			||||||
 | 
					  LOG(INFO) << "Init With Gflags:" << gflags;
 | 
				
			||||||
 | 
					  std::vector<std::string> flags = paddle::string::split_string(gflags);
 | 
				
			||||||
 | 
					  if (flags.size() < 1) {
 | 
				
			||||||
 | 
					    flags.push_back("-max_body_size=314217728");
 | 
				
			||||||
 | 
					    flags.push_back("-bthread_concurrency=40");
 | 
				
			||||||
 | 
					    flags.push_back("-socket_max_unwritten_bytes=2048000000");
 | 
				
			||||||
 | 
					    flags.push_back("-max_connection_pool_size=1950");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  auto it = flags.begin();
 | 
				
			||||||
 | 
					  flags.insert(it, "exe default");
 | 
				
			||||||
 | 
					  char* flags_ptr[flags.size()];
 | 
				
			||||||
 | 
					  for (size_t i = 0; i < flags.size(); ++i) {
 | 
				
			||||||
 | 
					    flags_ptr[i] = (char*)(flags[i].c_str());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  int params_cnt = flags.size();
 | 
				
			||||||
 | 
					  char** params_ptr = &(flags_ptr[0]);
 | 
				
			||||||
 | 
					  ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int PSCore::init_server(const std::string& dist_desc,
 | 
				
			||||||
 | 
					                        const std::vector<std::string>* host_sign_list,
 | 
				
			||||||
 | 
					                        int node_num, int index) {
 | 
				
			||||||
 | 
					  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
 | 
				
			||||||
 | 
					  init_gflag(_ps_param.init_gflags());
 | 
				
			||||||
 | 
					  _ps_env = paddle::distributed::PaddlePSEnvironment();
 | 
				
			||||||
 | 
					  _ps_env.set_ps_servers(host_sign_list, node_num);
 | 
				
			||||||
 | 
					  int ret = 0;
 | 
				
			||||||
 | 
					  _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
 | 
				
			||||||
 | 
					      paddle::distributed::PSServerFactory::create(_ps_param));
 | 
				
			||||||
 | 
					  ret = _server_ptr->configure(_ps_param, _ps_env, index);
 | 
				
			||||||
 | 
					  CHECK(ret == 0) << "failed to configure server";
 | 
				
			||||||
 | 
					  return ret;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int PSCore::init_worker(
 | 
				
			||||||
 | 
					    const std::string& dist_desc,
 | 
				
			||||||
 | 
					    const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
 | 
				
			||||||
 | 
					    const std::vector<std::string>* host_sign_list, int node_num, int index) {
 | 
				
			||||||
 | 
					  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
 | 
				
			||||||
 | 
					  init_gflag(_ps_param.init_gflags());
 | 
				
			||||||
 | 
					  _ps_env = paddle::distributed::PaddlePSEnvironment();
 | 
				
			||||||
 | 
					  _ps_env.set_ps_servers(host_sign_list, node_num);
 | 
				
			||||||
 | 
					  int ret = 0;
 | 
				
			||||||
 | 
					  VLOG(1) << "PSCore::init_worker";
 | 
				
			||||||
 | 
					  auto* communicator = Communicator::GetInstance();
 | 
				
			||||||
 | 
					  ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env,
 | 
				
			||||||
 | 
					                                               index);
 | 
				
			||||||
 | 
					  communicator->Start();
 | 
				
			||||||
 | 
					  return ret;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::vector<uint64_t> PSCore::get_client_info() {
 | 
				
			||||||
 | 
					  return _ps_env.get_client_info();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int PSCore::create_client2client_connection(int pserver_timeout_ms,
 | 
				
			||||||
 | 
					                                            int pserver_connect_timeout_ms,
 | 
				
			||||||
 | 
					                                            int max_retry) {
 | 
				
			||||||
 | 
					  int ret = _worker_ptr->create_client2client_connection(
 | 
				
			||||||
 | 
					      pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
 | 
				
			||||||
 | 
					  return ret;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					uint64_t PSCore::run_server(const std::string& ip, uint32_t port) {
 | 
				
			||||||
 | 
					  return _server_ptr->start(ip, port);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int PSCore::finalize_worker() {
 | 
				
			||||||
 | 
					  _worker_ptr->finalize_worker();
 | 
				
			||||||
 | 
					  return 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int PSCore::stop_server() {
 | 
				
			||||||
 | 
					  auto stop_status = _worker_ptr->stop_server();
 | 
				
			||||||
 | 
					  stop_status.wait();
 | 
				
			||||||
 | 
					  return 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; }
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <glog/logging.h>
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/ps.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/ps_client.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/sendrecv.pb.h"
 | 
				
			||||||
 | 
					#include "paddle/fluid/distributed/service/server.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace paddle {
 | 
				
			||||||
 | 
					namespace distributed {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PSCore {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  explicit PSCore() {}
 | 
				
			||||||
 | 
					  virtual ~PSCore() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  virtual int init_server(const std::string& dist_desc,
 | 
				
			||||||
 | 
					                          const std::vector<std::string>* host_sign_list,
 | 
				
			||||||
 | 
					                          int node_num, int index);
 | 
				
			||||||
 | 
					  virtual int init_worker(
 | 
				
			||||||
 | 
					      const std::string& dist_desc,
 | 
				
			||||||
 | 
					      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
 | 
				
			||||||
 | 
					          regions,
 | 
				
			||||||
 | 
					      const std::vector<std::string>* host_sign_list, int node_num, int index);
 | 
				
			||||||
 | 
					  virtual uint64_t run_server(const std::string& ip, uint32_t port);
 | 
				
			||||||
 | 
					  virtual int stop_server();
 | 
				
			||||||
 | 
					  virtual int finalize_worker();
 | 
				
			||||||
 | 
					  virtual std::vector<uint64_t> get_client_info();
 | 
				
			||||||
 | 
					  virtual int create_client2client_connection(int pserver_timeout_ms,
 | 
				
			||||||
 | 
					                                              int pserver_connect_timeout_ms,
 | 
				
			||||||
 | 
					                                              int max_retry);
 | 
				
			||||||
 | 
					  std::shared_ptr<paddle::distributed::PSServer>
 | 
				
			||||||
 | 
					      _server_ptr;  // pointer to server
 | 
				
			||||||
 | 
					  std::shared_ptr<paddle::distributed::PSClient>
 | 
				
			||||||
 | 
					      _worker_ptr;  // pointer to worker
 | 
				
			||||||
 | 
					  virtual paddle::distributed::PSParameter* get_param();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  void init_gflag(const std::string& gflags);
 | 
				
			||||||
 | 
					  paddle::distributed::PSParameter _ps_param;
 | 
				
			||||||
 | 
					  paddle::distributed::PaddlePSEnvironment _ps_env;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace distributed
 | 
				
			||||||
 | 
					}  // namespace paddle
 | 
				
			||||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue