@ -120,6 +120,7 @@ class DeviceWorker {
}
virtual ~ DeviceWorker ( ) { }
virtual void Initialize ( const TrainerDesc & desc ) = 0 ;
virtual void InitRandomDumpConfig ( const TrainerDesc & desc ) ;
virtual void SetDeviceIndex ( int tid ) = 0 ;
virtual void TrainFiles ( ) = 0 ;
virtual void PrintFetchVars ( ) = 0 ;
@ -129,8 +130,21 @@ class DeviceWorker {
virtual void BindingDataFeedMemory ( ) = 0 ;
virtual void SetRootScope ( Scope * root_scope ) ;
virtual void SetDataFeed ( DataFeed * data_feed ) ;
// Default implementation with an empty body: the flag passed in is ignored.
virtual void SetNeedDump(bool need_dump_field) {
  // nothing to do
}
virtual void SetChannelWriter ( ChannelObject < std : : string > * queue ) { }
virtual void SetNeedDumpField ( bool need_dump_field ) {
need_dump_field_ = need_dump_field ;
}
virtual void SetNeedDumpParam ( bool need_dump_param ) {
need_dump_param_ = need_dump_param ;
}
virtual void SetDumpFieldVector ( const std : : vector < std : : string > & dump_fields ) {
dump_fields_ = & dump_fields ;
}
virtual void SetDumpParamVector ( const std : : vector < std : : string > & dump_param ) {
dump_param_ = & dump_param ;
}
virtual void SetChannelWriter ( ChannelObject < std : : string > * queue ) {
writer_ . Reset ( queue ) ;
}
virtual void SetPlace ( const paddle : : platform : : Place & place ) {
place_ = place ;
}
@ -140,6 +154,9 @@ class DeviceWorker {
// Returns the pointer currently held in thread_scope_.
virtual Scope* GetThreadScope() {
  Scope* const current_scope = thread_scope_;
  return current_scope;
}
protected :
virtual void DumpParam ( const Scope & scope , const int batch_id ) ;
virtual void DumpField ( const Scope & scope , int dump_mode ,
int dump_interval = 10000 ) ;
Scope * root_scope_ = nullptr ;
Scope * thread_scope_ ;
paddle : : platform : : Place place_ ;
@ -148,6 +165,16 @@ class DeviceWorker {
FetchConfig fetch_config_ ;
bool use_cvm_ ;
bool no_cvm_ ;
// dump params or grads for debug
// The flags and pointers below get in-class initializers so that a worker on
// which the matching setter was never called does not read indeterminate
// values.
bool need_dump_param_ = false;
bool need_dump_field_ = false;
// Non-owning: SetDumpParamVector / SetDumpFieldVector store the address of the
// vector they are given; nullptr means no vector has been set.
const std::vector<std::string>* dump_param_ = nullptr;
const std::vector<std::string>* dump_fields_ = nullptr;
int dump_mode_ = 0;
int dump_interval_ = 10000;
ChannelWriter<std::string> writer_;
} ;
class CPUWorkerBase : public DeviceWorker {
@ -176,8 +203,6 @@ class HogwildWorker : public CPUWorkerBase {
virtual void Initialize ( const TrainerDesc & desc ) ;
virtual void TrainFiles ( ) ;
virtual void TrainFilesWithProfiler ( ) ;
virtual void SetNeedDump ( bool need_dump_field ) ;
virtual void SetChannelWriter ( ChannelObject < std : : string > * queue ) ;
virtual void PrintFetchVars ( ) ;
virtual void CreateDeviceResource ( const ProgramDesc & main_prog ) ;
virtual void BindingDataFeedMemory ( ) ;
@ -187,7 +212,6 @@ class HogwildWorker : public CPUWorkerBase {
protected :
void CreateThreadOperators ( const ProgramDesc & program ) ;
void CreateThreadScope ( const ProgramDesc & program ) ;
virtual void DumpParam ( const int batch_id ) ;
std : : vector < std : : string > op_names_ ;
std : : vector < OperatorBase * > ops_ ;
@ -196,12 +220,6 @@ class HogwildWorker : public CPUWorkerBase {
HogwildWorkerParameter param_ ;
std : : vector < std : : string > skip_ops_ ;
std : : map < std : : string , int > stat_var_name_map_ ;
// dump params or grads for debug
bool need_dump_param_ ;
bool need_dump_field_ ;
std : : vector < std : : string > dump_param_ ;
std : : vector < std : : string > dump_fields_ ;
ChannelWriter < std : : string > writer_ ;
} ;
class DownpourWorker : public HogwildWorker {
@ -211,8 +229,6 @@ class DownpourWorker : public HogwildWorker {
virtual void Initialize ( const TrainerDesc & desc ) ;
virtual void TrainFiles ( ) ;
virtual void TrainFilesWithProfiler ( ) ;
virtual void SetNeedDump ( bool need_dump_field ) ;
virtual void SetChannelWriter ( ChannelObject < std : : string > * queue ) ;
protected :
std : : shared_ptr < paddle : : framework : : FleetWrapper > fleet_ptr_ ;
@ -224,7 +240,6 @@ class DownpourWorker : public HogwildWorker {
void CopySparseTable ( ) ;
void CopyDenseTable ( ) ;
void CopyDenseVars ( ) ;
virtual void DumpParam ( const int batch_id ) ;
DownpourWorkerParameter param_ ;
// copy table