@ -67,11 +67,12 @@ class FleetWrapper {
client2client_max_retry_ = 3 ;
client2client_max_retry_ = 3 ;
}
}
// set client to client communication config
void SetClient2ClientConfig ( int request_timeout_ms , int connect_timeout_ms ,
void SetClient2ClientConfig ( int request_timeout_ms , int connect_timeout_ms ,
int max_retry ) ;
int max_retry ) ;
// Pull sparse variables from server in S ync mode
// Pull sparse variables from server in s ync mode
// Param<in>: scope, table_id, var_names, fea_keys
// Param<in>: scope, table_id, var_names, fea_keys , fea_dim
// Param<out>: fea_values
// Param<out>: fea_values
void PullSparseVarsSync ( const Scope & scope , const uint64_t table_id ,
void PullSparseVarsSync ( const Scope & scope , const uint64_t table_id ,
const std : : vector < std : : string > & var_names ,
const std : : vector < std : : string > & var_names ,
@ -80,19 +81,24 @@ class FleetWrapper {
int fea_dim ,
int fea_dim ,
const std : : vector < std : : string > & var_emb_names ) ;
const std : : vector < std : : string > & var_emb_names ) ;
// pull dense variables from server in sync mod
void PullDenseVarsSync ( const Scope & scope , const uint64_t table_id ,
void PullDenseVarsSync ( const Scope & scope , const uint64_t table_id ,
const std : : vector < std : : string > & var_names ) ;
const std : : vector < std : : string > & var_names ) ;
// pull dense variables from server in async mod
// Param<in>: scope, table_id, var_names
// Param<out>: pull_dense_status
void PullDenseVarsAsync (
void PullDenseVarsAsync (
const Scope & scope , const uint64_t table_id ,
const Scope & scope , const uint64_t table_id ,
const std : : vector < std : : string > & var_names ,
const std : : vector < std : : string > & var_names ,
std : : vector < : : std : : future < int32_t > > * pull_dense_status ) ;
std : : vector < : : std : : future < int32_t > > * pull_dense_status ) ;
// push dense parameters(not gradients) to server in sync mode
void PushDenseParamSync ( const Scope & scope , const uint64_t table_id ,
void PushDenseParamSync ( const Scope & scope , const uint64_t table_id ,
const std : : vector < std : : string > & var_names ) ;
const std : : vector < std : : string > & var_names ) ;
// Push dense variables to server in async mode
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<in>: scope, table_id, var_names, scale_datanorm, batch_size
// Param<out>: push_sparse_status
// Param<out>: push_sparse_status
void PushDenseVarsAsync (
void PushDenseVarsAsync (
const Scope & scope , const uint64_t table_id ,
const Scope & scope , const uint64_t table_id ,
@ -100,13 +106,14 @@ class FleetWrapper {
std : : vector < : : std : : future < int32_t > > * push_sparse_status ,
std : : vector < : : std : : future < int32_t > > * push_sparse_status ,
float scale_datanorm , int batch_size ) ;
float scale_datanorm , int batch_size ) ;
// push dense variables to server in sync mode
void PushDenseVarsSync ( Scope * scope , const uint64_t table_id ,
void PushDenseVarsSync ( Scope * scope , const uint64_t table_id ,
const std : : vector < std : : string > & var_names ) ;
const std : : vector < std : : string > & var_names ) ;
// Push sparse variables with labels to server in A sync mode
// Push sparse variables with labels to server in a sync mode
// This is specially designed for click/show stats in server
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, var_grad _names,
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key _names,
// fea_keys, fea_labels, sparse_grad_names
// sparse_grad_names, batch_size, use_cvm, dump_slot
// Param<out>: push_values, push_sparse_status
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync (
void PushSparseVarsWithLabelAsync (
const Scope & scope , const uint64_t table_id ,
const Scope & scope , const uint64_t table_id ,
@ -132,12 +139,17 @@ class FleetWrapper {
std : : vector < : : std : : future < int32_t > > * push_sparse_status ) ;
std : : vector < : : std : : future < int32_t > > * push_sparse_status ) ;
*/
*/
// init server
void InitServer ( const std : : string & dist_desc , int index ) ;
void InitServer ( const std : : string & dist_desc , int index ) ;
// init trainer
void InitWorker ( const std : : string & dist_desc ,
void InitWorker ( const std : : string & dist_desc ,
const std : : vector < uint64_t > & host_sign_list , int node_num ,
const std : : vector < uint64_t > & host_sign_list , int node_num ,
int index ) ;
int index ) ;
// stop server
void StopServer ( ) ;
void StopServer ( ) ;
// run server
uint64_t RunServer ( ) ;
uint64_t RunServer ( ) ;
// gather server ip
void GatherServers ( const std : : vector < uint64_t > & host_sign_list , int node_num ) ;
void GatherServers ( const std : : vector < uint64_t > & host_sign_list , int node_num ) ;
// gather client ip
// gather client ip
void GatherClients ( const std : : vector < uint64_t > & host_sign_list ) ;
void GatherClients ( const std : : vector < uint64_t > & host_sign_list ) ;
@ -145,7 +157,6 @@ class FleetWrapper {
std : : vector < uint64_t > GetClientsInfo ( ) ;
std : : vector < uint64_t > GetClientsInfo ( ) ;
// create client to client connection
// create client to client connection
void CreateClient2ClientConnection ( ) ;
void CreateClient2ClientConnection ( ) ;
// flush all push requests
// flush all push requests
void ClientFlush ( ) ;
void ClientFlush ( ) ;
// load from paddle model
// load from paddle model
@ -164,37 +175,42 @@ class FleetWrapper {
// mode = 0, save all feature
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
// mode = 1, save delta feature, which means save diff
void SaveModel ( const std : : string & path , const int mode ) ;
void SaveModel ( const std : : string & path , const int mode ) ;
// get save cache threshold
double GetCacheThreshold ( ) ;
double GetCacheThreshold ( ) ;
// shuffle cache model between servers
void CacheShuffle ( int table_id , const std : : string & path , const int mode ,
void CacheShuffle ( int table_id , const std : : string & path , const int mode ,
const double cache_threshold ) ;
const double cache_threshold ) ;
// save cache model
// cache model can speed up online predict
int32_t SaveCache ( int table_id , const std : : string & path , const int mode ) ;
int32_t SaveCache ( int table_id , const std : : string & path , const int mode ) ;
// copy feasign key/value from src_table_id to dest_table_id
int32_t CopyTable ( const uint64_t src_table_id , const uint64_t dest_table_id ) ;
// copy feasign key/value from src_table_id to dest_table_id
int32_t CopyTableByFeasign ( const uint64_t src_table_id ,
const uint64_t dest_table_id ,
const std : : vector < uint64_t > & feasign_list ) ;
// clear all models, release their memory
void ClearModel ( ) ;
void ClearModel ( ) ;
// shrink sparse table
void ShrinkSparseTable ( int table_id ) ;
void ShrinkSparseTable ( int table_id ) ;
// shrink dense table
void ShrinkDenseTable ( int table_id , Scope * scope ,
void ShrinkDenseTable ( int table_id , Scope * scope ,
std : : vector < std : : string > var_list , float decay ,
std : : vector < std : : string > var_list , float decay ,
int emb_dim ) ;
int emb_dim ) ;
// register client to client communication
typedef std : : function < int32_t ( int , int , const std : : string & ) > MsgHandlerFunc ;
typedef std : : function < int32_t ( int , int , const std : : string & ) > MsgHandlerFunc ;
// register client to client communication
int RegisterClientToClientMsgHandler ( int msg_type , MsgHandlerFunc handler ) ;
int RegisterClientToClientMsgHandler ( int msg_type , MsgHandlerFunc handler ) ;
// send client to client message
// send client to client message
std : : future < int32_t > SendClientToClientMsg ( int msg_type , int to_client_id ,
std : : future < int32_t > SendClientToClientMsg ( int msg_type , int to_client_id ,
const std : : string & msg ) ;
const std : : string & msg ) ;
// FleetWrapper singleton
template < typename T >
void Serialize ( const std : : vector < T * > & t , std : : string * str ) ;
template < typename T >
void Deserialize ( std : : vector < T > * t , const std : : string & str ) ;
static std : : shared_ptr < FleetWrapper > GetInstance ( ) {
static std : : shared_ptr < FleetWrapper > GetInstance ( ) {
if ( NULL = = s_instance_ ) {
if ( NULL = = s_instance_ ) {
s_instance_ . reset ( new paddle : : framework : : FleetWrapper ( ) ) ;
s_instance_ . reset ( new paddle : : framework : : FleetWrapper ( ) ) ;
}
}
return s_instance_ ;
return s_instance_ ;
}
}
// this performs better than rand_r, especially large data
// this performs better than rand_r, especially large data
std : : default_random_engine & LocalRandomEngine ( ) ;
std : : default_random_engine & LocalRandomEngine ( ) ;