|
|
|
@ -96,7 +96,7 @@ class DeviceWorker {
|
|
|
|
|
virtual void Initialize(const TrainerDesc& desc) = 0;
|
|
|
|
|
virtual void SetDeviceIndex(int tid) = 0;
|
|
|
|
|
virtual void TrainFiles() = 0;
|
|
|
|
|
virtual void PrintFetchVars(int batch_cnt) = 0;
|
|
|
|
|
virtual void PrintFetchVars() = 0;
|
|
|
|
|
virtual void TrainFilesWithProfiler() = 0;
|
|
|
|
|
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
|
|
|
|
|
// will make this zero copy in the future
|
|
|
|
@ -111,6 +111,8 @@ class DeviceWorker {
|
|
|
|
|
Scope* root_scope_;
|
|
|
|
|
paddle::platform::Place place_;
|
|
|
|
|
std::shared_ptr<DataFeed> device_reader_;
|
|
|
|
|
int64_t batch_num_;
|
|
|
|
|
FetchConfig fetch_config_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CPUWorkerBase : public DeviceWorker {
|
|
|
|
@ -120,7 +122,7 @@ class CPUWorkerBase : public DeviceWorker {
|
|
|
|
|
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
|
|
|
|
|
virtual void TrainFiles() = 0;
|
|
|
|
|
virtual void TrainFilesWithProfiler() {}
|
|
|
|
|
virtual void PrintFetchVars(int batch_cnt) {}
|
|
|
|
|
virtual void PrintFetchVars() {}
|
|
|
|
|
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -134,7 +136,7 @@ class HogwildWorker : public CPUWorkerBase {
|
|
|
|
|
virtual void Initialize(const TrainerDesc& desc);
|
|
|
|
|
virtual void TrainFiles();
|
|
|
|
|
virtual void TrainFilesWithProfiler();
|
|
|
|
|
virtual void PrintFetchVars(int batch_cnt);
|
|
|
|
|
virtual void PrintFetchVars();
|
|
|
|
|
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
|
|
|
|
|
virtual void BindingDataFeedMemory();
|
|
|
|
|
|
|
|
|
@ -144,9 +146,6 @@ class HogwildWorker : public CPUWorkerBase {
|
|
|
|
|
std::vector<std::string> op_names_;
|
|
|
|
|
std::vector<OperatorBase*> ops_;
|
|
|
|
|
Scope* thread_scope_;
|
|
|
|
|
std::vector<std::string> fetch_var_names_;
|
|
|
|
|
std::vector<std::vector<float>> fetch_values_;
|
|
|
|
|
int batch_cnt_per_print_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class DownpourWorker : public HogwildWorker {
|
|
|
|
|