/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
|
|
|
|
#include <cstdint>  // uint32_t / uint64_t / int64_t used below
#include <fstream>
#include <future>   // NOLINT
#include <map>
#include <memory>
#include <mutex>    // NOLINT
#include <string>
#include <thread>   // NOLINT
#include <vector>

#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"
|
|
|
|
namespace paddle {
|
|
namespace framework {
|
|
|
|
class PullDenseWorker {
|
|
public:
|
|
virtual ~PullDenseWorker() {}
|
|
virtual void Initialize(const TrainerDesc& param);
|
|
int Start();
|
|
void Stop();
|
|
void SetRootScope(Scope* scope) { root_scope_ = scope; }
|
|
void IncreaseThreadVersion(int thread_id, uint64_t table_id);
|
|
void ResetThreadVersion(uint64_t table_id);
|
|
void Wait(std::vector<::std::future<int32_t>>* status_vec);
|
|
void PullDense(bool force_update = false);
|
|
static std::shared_ptr<PullDenseWorker> GetInstance() {
|
|
if (NULL == s_instance_) {
|
|
s_instance_.reset(new paddle::framework::PullDenseWorker());
|
|
}
|
|
return s_instance_;
|
|
}
|
|
|
|
private:
|
|
PullDenseWorker() : root_scope_(NULL) {}
|
|
void Run();
|
|
bool CheckUpdateParam(uint64_t table_id);
|
|
|
|
private:
|
|
static std::shared_ptr<PullDenseWorker> s_instance_;
|
|
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
|
|
PullDenseWorkerParameter param_;
|
|
DownpourWorkerParameter dwp_param_;
|
|
Scope* root_scope_;
|
|
bool running_;
|
|
|
|
static std::map<uint64_t, uint64_t> last_versions_;
|
|
static std::map<uint64_t, uint64_t> current_version_;
|
|
static std::mutex mutex_for_version_;
|
|
static std::map<uint64_t, std::vector<uint64_t>> training_versions_;
|
|
static std::map<uint64_t, std::vector<std::string>> dense_value_names_;
|
|
|
|
std::thread t_;
|
|
int thread_num_;
|
|
int sleep_time_ms_;
|
|
int threshold_;
|
|
|
|
std::vector<::std::future<int32_t>> pull_dense_status_;
|
|
uint32_t pull_dense_fail_times_ = 0;
|
|
std::vector<float> base_norm_param_;
|
|
std::vector<float> mean_;
|
|
std::vector<float> scale_;
|
|
float squared_sum_epsilon_ = 1e-4;
|
|
std::mutex mutex_for_mean_scale_;
|
|
float total_batch_num_ = 0;
|
|
};
|
|
|
|
// TODO: should incorporate different types of devices (this hierarchy
// currently only covers CPU workers).
class DeviceWorker {
|
|
public:
|
|
DeviceWorker() { use_cvm_ = false; }
|
|
virtual ~DeviceWorker() {}
|
|
virtual void Initialize(const TrainerDesc& desc) = 0;
|
|
virtual void SetDeviceIndex(int tid) = 0;
|
|
virtual void TrainFiles() = 0;
|
|
virtual void PrintFetchVars() = 0;
|
|
virtual void TrainFilesWithProfiler() = 0;
|
|
virtual void CreateDeviceResource(const ProgramDesc& main_prog) = 0;
|
|
// will make this zero copy in the future
|
|
virtual void BindingDataFeedMemory() = 0;
|
|
virtual void SetRootScope(Scope* root_scope);
|
|
virtual void SetDataFeed(const std::shared_ptr<DataFeed>& data_feed);
|
|
virtual void SetPlace(const paddle::platform::Place& place) {
|
|
place_ = place;
|
|
}
|
|
|
|
protected:
|
|
Scope* root_scope_;
|
|
paddle::platform::Place place_;
|
|
std::shared_ptr<DataFeed> device_reader_;
|
|
int64_t batch_num_;
|
|
FetchConfig fetch_config_;
|
|
bool use_cvm_;
|
|
};
|
|
|
|
class CPUWorkerBase : public DeviceWorker {
|
|
public:
|
|
CPUWorkerBase() {}
|
|
virtual ~CPUWorkerBase() {}
|
|
virtual void SetDeviceIndex(int tid) { thread_id_ = tid; }
|
|
virtual void TrainFiles() = 0;
|
|
virtual void TrainFilesWithProfiler() {}
|
|
virtual void PrintFetchVars() {}
|
|
virtual void CreateDeviceResource(const ProgramDesc& main_prog) {}
|
|
|
|
protected:
|
|
int thread_id_;
|
|
};
|
|
|
|
class HogwildWorker : public CPUWorkerBase {
|
|
public:
|
|
HogwildWorker() {}
|
|
virtual ~HogwildWorker() {}
|
|
virtual void Initialize(const TrainerDesc& desc);
|
|
virtual void TrainFiles();
|
|
virtual void TrainFilesWithProfiler();
|
|
virtual void PrintFetchVars();
|
|
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
|
|
virtual void BindingDataFeedMemory();
|
|
|
|
protected:
|
|
void CreateThreadOperators(const ProgramDesc& program);
|
|
void CreateThreadScope(const ProgramDesc& program);
|
|
std::vector<std::string> op_names_;
|
|
std::vector<OperatorBase*> ops_;
|
|
Scope* thread_scope_;
|
|
HogwildWorkerParameter param_;
|
|
std::vector<std::string> skip_ops_;
|
|
};
|
|
|
|
class DownpourWorker : public HogwildWorker {
|
|
public:
|
|
DownpourWorker() {}
|
|
virtual ~DownpourWorker() {}
|
|
virtual void Initialize(const TrainerDesc& desc);
|
|
virtual void TrainFiles();
|
|
virtual void TrainFilesWithProfiler();
|
|
|
|
protected:
|
|
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
|
|
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
|
|
void FillSparseValue(size_t table_id);
|
|
void PushGradients();
|
|
void CollectLabelInfo(size_t table_id);
|
|
|
|
private:
|
|
bool need_to_push_dense_;
|
|
bool need_to_push_sparse_;
|
|
DownpourWorkerParameter param_;
|
|
// just save the value in param_ for easy access
|
|
std::map<uint64_t, std::string> label_var_name_;
|
|
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
|
|
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
|
|
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
|
|
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
|
|
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
|
|
|
|
// feasign
|
|
std::map<uint64_t, std::vector<uint64_t>> features_;
|
|
// feasign stats
|
|
std::map<uint64_t, std::vector<float>> feature_labels_;
|
|
// feasign embedding
|
|
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
|
|
// feasign embedding gradient
|
|
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
|
|
// skipped ops
|
|
std::vector<std::string> skip_ops_;
|
|
|
|
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
|
|
std::vector<::std::future<int32_t>> push_sparse_status_;
|
|
std::vector<::std::future<int32_t>> push_dense_status_;
|
|
};
|
|
|
|
} // namespace framework
|
|
} // namespace paddle
|