Pipeline Concurrency (#17402)

Add Pipeline Concurrency Train Mode:
- Cpp: pipeline_trainer & section_worker
- Python: PipelineOptimizer
- Add a new data_feed type: PrivateInstantDataFeed
- Add a test demo of the pipeline trainer; the test model is a GNN
- Win32 is not supported yet
dependabot/pip/python/requests-2.20.0
hutuxian 6 years ago committed by GitHub
parent ed118ee306
commit 969e6378b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -542,6 +542,15 @@ paddle.fluid.optimizer.ExponentialMovingAverage.__init__ (ArgSpec(args=['self',
paddle.fluid.optimizer.ExponentialMovingAverage.apply (ArgSpec(args=['self', 'executor', 'need_restore'], varargs=None, keywords=None, defaults=(True,)), ('document', '30f494752ac8921dc5835a63637f453a'))
paddle.fluid.optimizer.ExponentialMovingAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '8c8a1791608b02a1ede53d6dd3a4fcec'))
paddle.fluid.optimizer.ExponentialMovingAverage.update (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'ea10f08af6d7aac3b7974aa976e4085f'))
paddle.fluid.optimizer.PipelineOptimizer.__init__ (ArgSpec(args=['self', 'optimizer', 'cut_list', 'place_list', 'concurrency_list', 'queue_size', 'sync_steps', 'start_cpu_core_id'], varargs=None, keywords=None, defaults=(None, None, None, 30, 1, 0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.PipelineOptimizer.create_vars (ArgSpec(args=['self', 'block', 'main_program'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.PipelineOptimizer.extract_section_ops (ArgSpec(args=['self', 'ops', 'cut_point_name'], varargs=None, keywords=None, defaults=None), ('document', '4a29be77da04b5c30dd7202f44c79b70'))
paddle.fluid.optimizer.PipelineOptimizer.extract_section_opt_ops (ArgSpec(args=['self', 'ops', 'cut_point_name'], varargs=None, keywords=None, defaults=None), ('document', '99e0f641222c1ce4dd0d7194c3b2c653'))
paddle.fluid.optimizer.PipelineOptimizer.find_input_output (ArgSpec(args=['self', 'ops', 'name', 'is_forward'], varargs=None, keywords=None, defaults=(True,)), ('document', '92d77fb262766b352746f09cca81db93'))
paddle.fluid.optimizer.PipelineOptimizer.find_persistable_vars (ArgSpec(args=['self', 'ops', 'whole_parameters'], varargs=None, keywords=None, defaults=None), ('document', '877b7cc290f0647455e5e4409e825923'))
paddle.fluid.optimizer.PipelineOptimizer.find_section_opt (ArgSpec(args=['self', 'ops', 'params'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.PipelineOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.optimizer.PipelineOptimizer.split_program (ArgSpec(args=['self', 'main_program', 'cut_list'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '08a5dd9f6f376ff3d55e0b1d92115cbd'))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))

@ -173,20 +173,20 @@ endif()
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer data_feed_proto)
@ -201,10 +201,10 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
fast_threaded_ssa_graph_executor variable_helper)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc pipeline_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc dataset_factory.cc
downpour_worker.cc pull_dense_worker.cc section_worker.cc
device_worker_factory.cc data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto

@ -85,8 +85,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc);
bool success = data_feed_desc.ParseFromString(data_feed_desc_str);
PADDLE_ENFORCE(success, "Fail to parse DataFeedDesc from string:\n%s",
data_feed_desc_str.c_str());
actual_thread_num_ = thread_num;
int file_cnt = filelist.size();

@ -20,6 +20,9 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX
#include <stdio_ext.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif
#include <utility>
#include "gflags/gflags.h"
@ -87,6 +90,13 @@ void DataFeed::CheckStart() {
PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet.");
}
// Bind each used slot's feed tensor to the variable of the same name in
// `scope`, so subsequent batches are written directly into that scope.
// NOTE(review): FindVar's result is dereferenced unchecked -- a missing
// variable would crash; verify callers guarantee all slots exist in `scope`.
void DataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
feed_vec_[i] = scope.FindVar(use_slots_[i])->GetMutable<LoDTensor>();
}
}
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
@ -1009,5 +1019,205 @@ void MultiSlotInMemoryDataFeed::DeserializeIns(
fleet_ptr->Deserialize(ins, str);
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// Copy the parsed mini-batch held in ins_vec_ into the bound feed_vec_
// tensors. Each used slot becomes a {total_instance, 1} CPU tensor whose
// LoD is the slot's offset vector; dense slots are then reshaped to their
// declared shape (inductive dims must have been resolved during parsing).
template <typename T>
void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec_[i].GetType();
const auto& offset = ins_vec_[i].GetOffset();
// offset.back() is the cumulative LoD total, i.e. the element count.
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec_[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
const auto& feasign = ins_vec_[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
// A dense slot's declared shape must account for every element parsed.
int64_t total_dims = 1;
for (const auto e : use_slots_shape_[i]) {
total_dims *= e;
}
PADDLE_ENFORCE(
total_dims == total_instance,
"The actual data size of slot[%s] doesn't match its declaration",
use_slots_[i].c_str());
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
}
// Produce the next mini-batch and return its batch size, or -1 when the
// file list is exhausted. Tries the current file first; on EOF it releases
// the current file (Postprocess), picks the next one, and parses from it.
template <typename T>
int PrivateInstantDataFeed<T>::Next() {
if (ParseOneMiniBatch()) {
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
// Current file is exhausted; release its resources before moving on.
Postprocess();
std::string filename;
if (!PickOneFile(&filename)) {
return -1;
}
if (!Preprocess(filename)) {
return -1;
}
// A freshly opened file is expected to yield at least one mini-batch.
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data");
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
// Initialize the feed from the proto description: record every slot's name
// and type, remember which slots are used, and collect the declared shapes
// of dense slots. A shape entry of -1 marks an "inductive" dim whose
// concrete value is carried inside the data stream (see ParseOneMiniBatch).
template <typename T>
void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
// Indexed by the all-slot index i, not by the use-slot index.
multi_inductive_shape_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
// -1 means unused; otherwise the slot's position within use_slots_.
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
std::vector<int> local_shape;
if (slot.is_dense()) {
for (size_t j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) == -1) {
// Dim j must be filled in from the data at parse time.
multi_inductive_shape_index_[i].push_back(j);
}
}
}
for (size_t j = 0; j < slot.shape_size(); ++j) {
local_shape.push_back(slot.shape(j));
}
use_slots_shape_.push_back(local_shape);
}
}
feed_vec_.resize(use_slots_.size());
ins_vec_.resize(use_slots_.size());
finish_init_ = true;
}
template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
// Open `filename` read-only and mmap the whole file so mini-batches can be
// parsed in place without extra copies. Resets the parse cursor (offset_)
// to the start of the mapping. Paired with Postprocess() for cleanup.
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());
struct stat sb;
fstat(fd_, &sb);
end_ = static_cast<size_t>(sb.st_size);
buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno));
offset_ = 0;
return true;
}
// Release the resources acquired by Preprocess: unmap the file buffer and
// close the file descriptor. Safe to call even before any Preprocess, since
// every member is checked against its sentinel value first.
bool MultiSlotFileInstantDataFeed::Postprocess() {
if (buffer_ != nullptr) {
munmap(buffer_, end_);
buffer_ = nullptr;
}
if (fd_ != -1) {
close(fd_);
fd_ = -1;
end_ = 0;
offset_ = 0;
}
return true;
}
// Parse up to default_batch_size_ instances from the mmapped file into
// ins_vec_; returns false when the file has been fully consumed. Per slot
// the on-disk layout is: uint16 id-count `num`, then for dense slots
// `inductive_size` uint64 shape values (skipped for every instance, but
// only read into use_slots_shape_ on the first instance of the batch),
// then the remaining `num - inductive_size` float/uint64 feasigns.
bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
if (offset_ == end_) {
return false;
}
batch_size_ = 0;
while (batch_size_ < default_batch_size_ && offset_ < end_) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.");
offset_ += sizeof(uint16_t);
if (idx != -1) {
int inductive_size = multi_inductive_shape_index_[i].size();
if (UNLIKELY(batch_size_ == 0)) {
// First instance of the batch: reserve room for a full batch and
// resolve the inductive dims carried inside the data.
ins_vec_[idx].Init(all_slots_type_[i], default_batch_size_ * num);
ins_vec_[idx].InitOffset(default_batch_size_);
uint64_t* inductive_shape =
reinterpret_cast<uint64_t*>(buffer_ + offset_);
// NOTE(review): use_slots_shape_ is populated per *used* slot and is
// indexed by idx elsewhere (e.g. PutToFeedVec); indexing it with the
// all-slot index i looks wrong whenever an unused slot precedes this
// one -- verify against the full file.
for (int inductive_id = 0; inductive_id < inductive_size;
++inductive_id) {
use_slots_shape_[i][multi_inductive_shape_index_[i][inductive_id]] =
static_cast<int>(*(inductive_shape + inductive_id));
}
}
num -= inductive_size;
offset_ += sizeof(uint64_t) * inductive_size;
if (type == 'f') {
ins_vec_[idx].AppendValues(
reinterpret_cast<float*>(buffer_ + offset_), num);
offset_ += num * sizeof(float);
} else if (type == 'u') {
ins_vec_[idx].AppendValues(
reinterpret_cast<uint64_t*>(buffer_ + offset_), num);
offset_ += num * sizeof(uint64_t);
}
} else {
// Unused slot: skip its payload without materializing it.
// NOTE(review): unlike the used-slot branch, no uint64 inductive-shape
// bytes are skipped here; if an unused dense slot can carry inductive
// dims, the cursor would desynchronize -- confirm with data format.
if (type == 'f') {
offset_ += num * sizeof(float);
} else if (type == 'u') {
offset_ += num * sizeof(uint64_t);
}
}
}
++batch_size_;
// OPTIMIZE: It is better to insert check codes between instances for format
// checking
}
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_");
return true;
}
#endif
} // namespace framework
} // namespace paddle

@ -59,7 +59,7 @@ class DataFeed {
file_idx_ = nullptr;
}
virtual ~DataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) {
PADDLE_THROW("This function(CheckFile) is not implemented.");
}
@ -84,6 +84,9 @@ class DataFeed {
// This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name);
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default
virtual void SetMemoryData(void* memory_data) {}
// This function will do nothing at default
@ -148,6 +151,8 @@ class DataFeed {
std::vector<std::vector<int>> use_slots_shape_;
std::vector<int> inductive_shape_index_;
std::vector<int> total_dims_without_inductive_;
// For the inductive shape passed within data
std::vector<std::vector<int>> multi_inductive_shape_index_;
std::vector<int>
use_slots_index_; // -1: not used; >=0: the index of use_slots_
@ -173,7 +178,6 @@ class PrivateQueueDataFeed : public DataFeed {
public:
PrivateQueueDataFeed() {}
virtual ~PrivateQueueDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
@ -212,7 +216,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetMemoryData(void* memory_data);
@ -263,16 +267,25 @@ class MultiSlotType {
public:
MultiSlotType() {}
~MultiSlotType() {}
void Init(const std::string& type) {
void Init(const std::string& type, size_t reserved_size = 0) {
CheckType(type);
if (type_[0] == 'f') {
float_feasign_.clear();
if (reserved_size) {
float_feasign_.reserve(reserved_size);
}
} else if (type_[0] == 'u') {
uint64_feasign_.clear();
if (reserved_size) {
uint64_feasign_.reserve(reserved_size);
}
}
type_ = type;
}
void InitOffset() {
void InitOffset(size_t max_batch_size = 0) {
if (max_batch_size > 0) {
offset_.reserve(max_batch_size + 1);
}
offset_.resize(1);
// LoDTensor' lod is counted from 0, the size of lod
// is one size larger than the size of data.
@ -288,6 +301,16 @@ class MultiSlotType {
CheckUint64();
uint64_feasign_.push_back(v);
}
// Overwrite the float payload with `size` values from `input`.
// Does not touch offset_; callers manage the LoD offsets separately.
void CopyValues(const float* input, size_t size) {
CheckFloat();
float_feasign_.resize(size);
memcpy(float_feasign_.data(), input, size * sizeof(float));
}
// Overwrite the uint64 payload with `size` values from `input`.
// Does not touch offset_; callers manage the LoD offsets separately.
void CopyValues(const uint64_t* input, size_t size) {
CheckUint64();
uint64_feasign_.resize(size);
memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
}
void AddIns(const MultiSlotType& ins) {
if (ins.GetType()[0] == 'f') { // float
CheckFloat();
@ -301,11 +324,22 @@ class MultiSlotType {
uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
}
}
// Append one instance of `size` uint64 ids and extend the LoD offset.
void AppendValues(const uint64_t* input, size_t size) {
CheckUint64();
offset_.push_back(offset_.back() + size);
uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
}
// Append one instance of `size` float values and extend the LoD offset.
void AppendValues(const float* input, size_t size) {
CheckFloat();
offset_.push_back(offset_.back() + size);
float_feasign_.insert(float_feasign_.end(), input, input + size);
}
const std::vector<float>& GetFloatData() const { return float_feasign_; }
std::vector<float>& MutableFloatData() { return float_feasign_; }
const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
const std::string& GetType() const { return type_; }
size_t GetBatchSize() { return offset_.size() - 1; }
std::string& MutableType() { return type_; }
std::string DebugString() {
@ -355,7 +389,7 @@ class MultiSlotDataFeed
public:
MultiSlotDataFeed() {}
virtual ~MultiSlotDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual bool CheckFile(const char* filename);
// virtual void ReadThread();
@ -374,7 +408,7 @@ class MultiSlotInMemoryDataFeed
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual void Init(const DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
@ -389,5 +423,54 @@ class MultiSlotInMemoryDataFeed
const std::string& str);
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// Base class for feeds that parse instances directly from a private
// per-thread data source instead of going through a shared queue.
// Subclasses implement Preprocess/Postprocess/ParseOneMiniBatch; Next()
// drives them and PutToFeedVec() publishes the batch to the feed tensors.
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
public:
PrivateInstantDataFeed() {}
virtual ~PrivateInstantDataFeed() {}
void Init(const DataFeedDesc& data_feed_desc) override;
// No background thread to launch; parsing happens lazily in Next().
bool Start() override { return true; }
int Next() override;
protected:
// The batched data buffer
std::vector<MultiSlotType> ins_vec_;
// This function is used to preprocess with a given filename, e.g. open it or
// mmap
virtual bool Preprocess(const std::string& filename) = 0;
// This function is used to postprocess system resource such as closing file
// NOTICE: Ensure that it is safe to call before Preprocess
virtual bool Postprocess() = 0;
// The reading and parsing method.
virtual bool ParseOneMiniBatch() = 0;
// This function is used to put ins_vec to feed_vec
virtual void PutToFeedVec();
};
// Instant feed that mmaps a binary multi-slot file and parses mini-batches
// in place (no copy until PutToFeedVec).
class MultiSlotFileInstantDataFeed
: public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotFileInstantDataFeed() {}
virtual ~MultiSlotFileInstantDataFeed() {}
protected:
int fd_{-1};             // descriptor of the file being consumed
char* buffer_{nullptr};  // mmapped file content
size_t end_{0};          // size of the mapping in bytes
size_t offset_{0};       // parse cursor within buffer_
bool Preprocess(const std::string& filename) override;
bool Postprocess() override;
bool ParseOneMiniBatch() override;
};
#endif
} // namespace framework
} // namespace paddle

@ -64,5 +64,8 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
} // namespace framework
} // namespace paddle

@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <fstream>
#include <map>
#include <memory>
@ -35,9 +36,17 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/timer.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
#define SEC_LOG \
VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \
<< "]: "
class PullDenseWorker {
public:
virtual ~PullDenseWorker() {}
@ -196,5 +205,101 @@ class DownpourWorker : public HogwildWorker {
std::vector<::std::future<int32_t>> push_dense_status_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
using ScopeQueue = operators::reader::BlockingQueue<Scope*>;
// Functor invoked once per finished scope to keep the parameters listed in
// sync_param_ consistent across pipeline cards. It holds an NCCL context
// map, so synchronization presumably happens via NCCL collectives -- the
// implementation (operator() / Synchronize) is not visible here; verify in
// the .cc file. The static members hold coordination state shared by all
// instances.
class SyncFunctor {
public:
SyncFunctor(int rank_id, int rank_num, int sync_steps);
virtual ~SyncFunctor() {}
void SetSyncParam(const std::vector<std::string>& sync_param) {
sync_param_ = &sync_param;
}
void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) {
nccl_ctx_map_ = nccl_ctx_map;
}
int operator()(Scope* scope);
// Shared across all SyncFunctor instances (one per pipeline).
static std::vector<Scope*> pipeline_scopes_;
static uint64_t sync_flag_;
protected:
const int rank_id_;
const int rank_num_;
const std::vector<std::string>* sync_param_ = nullptr;
platform::NCCLContextMap* nccl_ctx_map_ = nullptr;
uint64_t sync_signal_;
// Synchronize once every sync_steps_ invocations.
const int sync_steps_;
int counter_;
void Synchronize();
};
// Device worker that executes ONE section of a pipelined program: it
// consumes a Scope from in_scope_queue_, runs the section's ops, and
// produces the scope to out_scope_queue_ for the next section.
class SectionWorker : public DeviceWorker {
public:
SectionWorker() {}
~SectionWorker() override {}
void Initialize(const TrainerDesc& desc) override;
// Data feeding is orchestrated by the pipeline trainer, not per worker.
void BindingDataFeedMemory() override {}
void CreateDeviceResource(const ProgramDesc& main_prog) override{};
void TrainFiles() override;
void TrainFilesWithProfiler() override;
void PrintFetchVars() override {}
const platform::Place& place() const { return place_; }
void SetSectionIndex(int section_id) { section_id_ = section_id; }
// For a section worker, the "device index" is the pipeline id.
void SetDeviceIndex(int tid) override { pipeline_id_ = tid; }
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetVarNames(const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
in_var_names_ = &in_var_names;
out_var_names_ = &out_var_names;
}
void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) {
in_scope_queue_ = in_scope_queue;
out_scope_queue_ = out_scope_queue;
}
void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; }
void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; }
void SetSectionNum(int section_num) { section_num_ = section_num; }
void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; }
void SetNextSectionPlace(const paddle::platform::Place& place) {
next_section_place_ = place;
}
SyncFunctor* sync_func_ = nullptr;
void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }
// Next CPU core to pin a worker thread to; shared by all SectionWorkers
// (see AutoSetCPUAffinity).
static std::atomic<int> cpu_id_;
protected:
void AutoSetCPUAffinity(bool reuse);
int section_id_;
int pipeline_id_;
int section_num_;
int pipeline_num_;
int thread_id_;
// This worker will consume scope from in_scope_queue_
// and produce scope to out_scope_queue_
ScopeQueue* in_scope_queue_ = nullptr;
ScopeQueue* out_scope_queue_ = nullptr;
const std::vector<std::string>* in_var_names_ = nullptr;
const std::vector<std::string>* out_var_names_ = nullptr;
std::mutex* worker_count_mutex_ = nullptr;
int* worker_count_ = nullptr;
// Device of the next section, where the produced scope will be used.
paddle::platform::Place next_section_place_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif
} // namespace framework
} // namespace paddle

@ -61,5 +61,8 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
#endif
} // namespace framework
} // namespace paddle

@ -122,8 +122,9 @@ void Executor::RunFromDataset(const ProgramDesc& main_program, Scope* scope,
const std::string& trainer_desc_str) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
bool success = trainer_desc.ParseFromString(trainer_desc_str);
PADDLE_ENFORCE(success, "Fail to parse TrainerDesc from string:\n%s",
trainer_desc_str.c_str());
VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -91,5 +91,58 @@ class DistMultiTrainer : public MultiTrainer {
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// Trainer that runs a program split into sections, each executed by a pool
// of SectionWorker threads; scopes flow between sections through
// ScopeQueues, and selected parameters are kept consistent across cards
// with nccl all-reduce (see SyncFunctor).
class PipelineTrainer : public TrainerBase {
public:
PipelineTrainer() {}
~PipelineTrainer() override {}
void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) override;
void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) override;
void InitOtherEnv(const ProgramDesc& main_program) override {}
void Run() override;
void Finalize() override;
protected:
int section_num_;
int pipeline_num_;
int scope_queue_size_;
int sync_steps_;
SectionWorkerParameter pipeline_config_;
// The in/output var names for each section
std::vector<std::unique_ptr<std::vector<std::string>>> in_var_names_;
std::vector<std::unique_ptr<std::vector<std::string>>> out_var_names_;
// Counter for the running thread
std::vector<std::vector<int*>> worker_count_;
std::vector<std::vector<std::unique_ptr<std::mutex>>> worker_count_mutex_;
// worker: [section_id][pipeline_id][thread_id]
std::vector<std::vector<
std::vector<std::shared_ptr<paddle::framework::DeviceWorker>>>>
workers_;
std::vector<std::thread> section_threads_;
// We use scope to maintain context info, and scopes
// will be delivered between different sections.
std::vector<std::vector<std::unique_ptr<ScopeQueue>>> scope_queues_;
std::vector<Scope*> pipeline_scopes_;
// The parameters that should be synchronized between different cards using
// nccl all-reduce
std::shared_ptr<std::vector<std::string>> param_need_sync_;
std::vector<std::unique_ptr<SyncFunctor>> sync_functors_;
std::shared_ptr<platform::NCCLContextMap> nccl_ctx_map_;
std::vector<std::shared_ptr<DataFeed>> readers_;
void InitFirstScopeQueue(ScopeQueue* scope_queue, int pipeline_id,
const ProgramDesc& main_program);
void CopyParameters(const Scope& root_scope, int pipeline_id);
void construct_sync_functor();
};
#endif
} // namespace framework
} // namespace paddle

@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
import "data_feed.proto";
import "framework.proto";
package paddle.framework;
message TrainerDesc {
@ -36,6 +38,7 @@ message TrainerDesc {
optional HogwildWorkerParameter hogwild_param = 101;
optional DownpourWorkerParameter downpour_param = 103;
optional PullDenseWorkerParameter pull_dense_param = 102;
optional SectionWorkerParameter section_param = 104;
// datafeed desc
optional DataFeedDesc data_desc = 201;
}
@ -51,6 +54,30 @@ message DownpourWorkerParameter {
optional bool push_dense = 6 [ default = true ];
}
// Configuration for pipeline (section-based) training.
message SectionWorkerParameter {
// One entry per pipeline section, in execution order.
repeated SectionConfig section_config = 1;
// Capacity of the scope queues connecting adjacent sections.
optional int32 queue_size = 2 [ default = 1 ];
// Synchronize parameters across cards every sync_steps steps.
optional int64 sync_steps = 3 [ default = 1 ];
// First CPU core id used when pinning section worker threads.
optional int32 start_cpu_core_id = 4 [ default = 1 ];
// Names of parameters that must be kept in sync across cards.
repeated string param_need_sync = 5;
}
// Description of a single pipeline section.
message SectionConfig {
// Device kind the section runs on.
enum Place {
CPUPlace = 0;
CUDAPlace = 1;
CUDAPinnedPlace = 2;
}
// FIXME: How to use proto::ProgramDesc
// required string program_desc_str = 1;
// The sub-program executed by this section.
optional proto.ProgramDesc program_desc = 1;
optional Place place = 2;
// Number of concurrent worker threads for this section.
optional int32 concurrency = 3 [ default = 1 ];
// Variables consumed from / produced to the inter-section scope queues.
repeated string section_in_var_names = 4;
repeated string section_out_var_names = 5;
}
message FetchConfig {
enum Method { PRINT = 0; }
repeated string fetch_var_names = 1;

@ -63,5 +63,8 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
REGISTER_TRAINER_CLASS(MultiTrainer);
REGISTER_TRAINER_CLASS(DistMultiTrainer);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_TRAINER_CLASS(PipelineTrainer);
#endif
} // namespace framework
} // namespace paddle

@ -50,7 +50,7 @@ class Timer {
struct timeval _start;
struct timeval _now;
int _count;
int _elapsed;
int64_t _elapsed;
bool _paused;
// get us difference between start and now

@ -21,7 +21,7 @@ __all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
class DatasetFactory(object):
"""
DatasetFactory is a factory which create dataset by its name,
you can create "QueueDataset" or "InMemoryDataset",
you can create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset",
the default is "QueueDataset".
Example:
@ -38,7 +38,7 @@ class DatasetFactory(object):
def create_dataset(self, datafeed_class="QueueDataset"):
"""
Create "QueueDataset" or "InMemoryDataset",
Create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset",
the default is "QueueDataset".
Args:
@ -450,3 +450,36 @@ class QueueDataset(DatasetBase):
raise NotImplementedError(
"QueueDataset does not support global shuffle, "
"please use InMemoryDataset for global_shuffle")
class FileInstantDataset(DatasetBase):
    """
    FileInstantDataset: it processes data in a streaming manner, parsing
    instances directly from the input files.

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("FileInstantDataset")
    """

    def __init__(self):
        """
        Initialize the dataset and select MultiSlotFileInstantDataFeed as
        the underlying data feed.
        """
        super(FileInstantDataset, self).__init__()
        self.proto_desc.name = "MultiSlotFileInstantDataFeed"

    def local_shuffle(self):
        """
        Local shuffle is not supported, since data is consumed straight
        from the input files.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "FileInstantDataset does not support local shuffle, "
            "please use InMemoryDataset for local_shuffle")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle is not supported, since data is consumed straight
        from the input files.

        Args:
            fleet: unused; kept for interface compatibility.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "FileInstantDataset does not support global shuffle, "
            "please use InMemoryDataset for global_shuffle")

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section']
class DeviceWorker(object):
@ -181,6 +181,58 @@ class DownpourSGD(DeviceWorker):
downpour.push_sparse = False
class Section(DeviceWorker):
    """
    Device worker that fills in the trainer configuration for pipeline
    (section-based) training, where the program has been split into
    sections by a pipeline optimizer.
    """

    def __init__(self):
        """
        Init.
        """
        super(Section, self).__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generator worker desc, which device worker is SectionWorker.

        Copies the pipeline options attached to the program
        (``_pipeline_opt``) into ``trainer_desc.section_param``: queue
        size, sync steps, CPU core pinning, parameters to sync, and one
        SectionConfig per section program.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        from google.protobuf import text_format
        from . import core
        trainer_desc.device_worker_name = "SectionWorker"
        # Pipeline options are attached to the program by the pipeline
        # optimizer (see Program._pipeline_opt).
        pipeline_opt = self._program._pipeline_opt
        section_param = trainer_desc.section_param
        section_param.queue_size = pipeline_opt["queue_size"]
        section_param.sync_steps = pipeline_opt["sync_steps"]
        section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
        for e in pipeline_opt["param_need_sync"]:
            section_param.param_need_sync.append(e)
        for i, program in enumerate(pipeline_opt["section_program_list"]):
            cfg = section_param.section_config.add()
            # Serialize/parse round-trip copies the section's ProgramDesc
            # into the config proto.
            cfg.program_desc.ParseFromString(program["program"]._get_desc()
                                             .serialize_to_string())
            # TODO: why does not work
            #cfg.program_desc.CopyFrom(program.program._get_desc())
            place = pipeline_opt["place_list"][i]
            if isinstance(place, core.CPUPlace):
                cfg.place = cfg.CPUPlace
            elif isinstance(place, core.CUDAPlace):
                cfg.place = cfg.CUDAPlace
            elif isinstance(place, core.CUDAPinnedPlace):
                cfg.place = cfg.CUDAPinnedPlace
            else:
                raise NotImplementedError(
                    "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
                )
            cfg.concurrency = pipeline_opt["concurrency_list"][i]
            for var in program["input_set"]:
                cfg.section_in_var_names.append(var)
            for var in program["output_set"]:
                cfg.section_out_var_names.append(var)
class DeviceWorkerFactory(object):
def _create_device_worker(self, worker_type):
classname = worker_type.capitalize()

@ -781,12 +781,23 @@ class Executor(object):
assert len(fetch_list) == len(fetch_info)
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory()._create_trainer(program._fleet_opt)
# TODO: Need a better way to distinguish and specify different execution mode
if program._pipeline_opt:
trainer = TrainerFactory()._create_trainer(
program._pipeline_opt)
else:
trainer = TrainerFactory()._create_trainer(program._fleet_opt)
trainer._set_program(program)
else:
trainer = TrainerFactory()._create_trainer(
program.program._fleet_opt)
if program._pipeline_opt:
trainer = TrainerFactory()._create_trainer(
program.program._pipeline_opt)
else:
trainer = TrainerFactory()._create_trainer(
program.program._fleet_opt)
trainer._set_program(program.program)
# The following thread_num-determined logic will be deprecated
if thread <= 0:
if dataset.thread_num <= 0:
raise RuntimeError(
@ -796,6 +807,26 @@ class Executor(object):
trainer._set_thread(dataset.thread_num)
else:
trainer._set_thread(thread)
# Adjust the reader size for small file num
if program._pipeline_opt:
dataset.set_thread(thread *
program._pipeline_opt["concurrency_list"][0])
file_size = len(dataset.dataset.get_filelist())
if file_size < thread:
thread = file_size
print(
"Pipeline: setting the pipeline num to %d is enough because there are only %d files"
% (file_size, file_size))
if file_size < thread * program._pipeline_opt["concurrency_list"][
0]:
print(
"Pipeline: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
% (file_size / thread, file_size))
program._pipeline_opt["concurrency_list"][
0] = file_size / thread
dataset.set_thread(
program._pipeline_opt["concurrency_list"][0] * thread)
trainer._set_debug(debug)
trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return scope, trainer

@ -2785,6 +2785,9 @@ class Program(object):
self._fleet_opt = None
self._program_config = None
# assigned if this program has been parsed by a pipeline optimizer
self._pipeline_opt = None
@property
def _is_mem_optimized(self):
# if the program is optimized, operator input/outputs

@ -23,7 +23,7 @@ from paddle.fluid.framework import Program, Variable, name_scope, default_main_p
from . import framework
from . import layers
from . import unique_name
from .backward import append_backward
from .backward import append_backward, _some_in_set_, _append_grad_suffix_
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from .initializer import Constant
@ -43,7 +43,7 @@ __all__ = [
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
'ExponentialMovingAverage'
'ExponentialMovingAverage', 'PipelineOptimizer'
]
@ -2607,3 +2607,230 @@ class ExponentialMovingAverage(object):
executor (Executor): The Executor to execute restoring.
"""
executor.run(self.restore_program)
class PipelineOptimizer(object):
    """Wrap another optimizer to train a program in pipeline-parallel mode.

    The program is split at the variables listed in ``cut_list`` into
    consecutive "sections"; each section is assigned a place (CPU/GPU) from
    ``place_list`` and a worker concurrency from ``concurrency_list``.  After
    ``minimize`` runs, the split result is stored on the program as
    ``program._pipeline_opt`` for the PipelineTrainer/SectionWorker to consume.
    """

    def __init__(self,
                 optimizer,
                 cut_list=None,
                 place_list=None,
                 concurrency_list=None,
                 queue_size=30,
                 sync_steps=1,
                 start_cpu_core_id=0):
        # optimizer: the underlying optimizer that builds the backward pass
        #     and optimize ops (its minimize() is called first).
        # cut_list: list of lists of Variables; each inner list names the
        #     variables at which the forward graph is cut into sections.
        # place_list: one core.*Place per section.
        # concurrency_list: number of worker threads per section.
        # queue_size: depth of the queues between sections
        #     — presumably in batches; verify against the C++ SectionWorker.
        # sync_steps: parameter-sync interval between pipelines.
        # start_cpu_core_id: first CPU core id used for thread binding.
        # TODO: check properties
        self._optimizer = optimizer
        self._cut_list = cut_list
        self._place_list = place_list
        self._concurrency_list = concurrency_list
        self._queue_size = queue_size
        self._sync_steps = sync_steps
        self._start_cpu_core_id = start_cpu_core_id

    def create_vars(self, block, main_program):
        """Clone into ``block`` every variable referenced by its ops.

        Variable metadata is looked up in block 0 of ``main_program``; each
        distinct name is cloned exactly once (``used_var_set`` dedupes).
        """
        used_var_set = set()
        for op_idx in range(block.desc.op_size()):
            op_desc = block.desc.op(op_idx)
            vars = op_desc.input_arg_names() + op_desc.output_arg_names()
            for var in vars:
                if var in used_var_set:
                    continue
                used_var_set.add(var)
                source_var = main_program.block(0).var(str(var))
                # False: do not force the clone to be persistable.
                block._clone_variable(source_var, False)

    def extract_section_opt_ops(self, ops, cut_point_name):
        """
        Extract opt ops in the given section.

        Walks ``ops`` backwards, keeping every op that (transitively)
        produces a name in ``cut_point_name``; the kept ops' inputs are
        added to the wanted set so their producers are kept too.
        Returns the kept ops in their original order.
        """
        output_names = set(cut_point_name)
        relevant_op_flags = [True] * len(ops)
        for i, op in reversed(list(enumerate(ops))):
            if _some_in_set_(op.desc.output_arg_names(), output_names):
                for name in op.desc.input_arg_names():
                    output_names.add(name)
            else:
                relevant_op_flags[i] = False
        op_path = [ops[i] for i in range(len(ops)) if relevant_op_flags[i]]
        return op_path

    def find_input_output(self, ops, name, is_forward=True):
        """
        Find the inputs or outputs of a section.

        With ``is_forward=True`` returns the names read by ``ops`` but never
        written by them (the section's external inputs); with
        ``is_forward=False`` returns names written but never read (the
        section's external outputs).  ``name`` is unused here.
        """
        all_set = set()
        part_set = set()
        for op in ops:
            if is_forward:
                part_set.update(op.desc.output_arg_names())
            else:
                part_set.update(op.desc.input_arg_names())
            all_set.update(op.desc.output_arg_names())
            all_set.update(op.desc.input_arg_names())
        return all_set - part_set

    def find_persistable_vars(self, ops, whole_parameters):
        """
        find the persistable input vars in current section

        Returns the subset of ``whole_parameters`` (parameter names of the
        whole program) that appear as inputs of ``ops``.
        """
        res = set()
        for op in ops:
            vars = op.desc.input_arg_names()
            for var in vars:
                if var in whole_parameters:
                    res.add(var)
        return res

    def _is_opt_role_op(self, op):
        # True if the op's OpRole attribute has the Optimize bit set
        # (bitwise test: roles may be combined, e.g. Optimize|LRSched).
        op_maker = core.op_proto_and_checker_maker
        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
        if op_maker.kOpRoleAttrName() in op.attr_names and \
                int(op.all_attrs()[op_maker.kOpRoleAttrName()]) & int(optimize_role) != 0:
            return True
        return False

    def _is_lr_role_op(self, op):
        # True only if the op's role is exactly LRSched (equality, not a
        # bitwise test as in _is_opt_role_op).
        op_maker = core.op_proto_and_checker_maker
        optimize_role = core.op_proto_and_checker_maker.OpRole.LRSched
        if op_maker.kOpRoleAttrName() in op.attr_names and \
                int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
            return True
        return False

    def extract_section_ops(self, ops, cut_point_name):
        """
        Extract ops in the given section.

        Like extract_section_opt_ops, but skips optimizer-role ops and keeps
        a "print" op whenever its first input is already wanted (without
        pulling the print op's inputs into the wanted set).
        """
        output_names = set(cut_point_name)
        relevant_op_flags = [True] * len(ops)
        for i, op in reversed(list(enumerate(ops))):
            if not self._is_opt_role_op(op) and _some_in_set_(
                    op.desc.output_arg_names(), output_names):
                for name in op.desc.input_arg_names():
                    output_names.add(name)
            elif op.desc.type() == "print" and op.desc.input_arg_names()[
                    0] in output_names:
                continue
            else:
                relevant_op_flags[i] = False
        op_path = [ops[i] for i in range(len(ops)) if relevant_op_flags[i]]
        return op_path

    def find_section_opt(self, ops, params):
        # Thin wrapper: the optimize ops belonging to a section are the ops
        # that feed that section's parameters.
        res = self.extract_section_opt_ops(ops, params)
        return res

    def split_program(self, main_program, cut_list):
        """Split ``main_program`` into pipeline sections.

        Sections are laid out as forward sections (one per cut), then the
        matching backward sections in reverse order, then a final section
        with whatever ops remain (parameter updates for section 0).  Each
        element of the returned list is a dict with keys "program"
        (a new Program holding the section's ops), "input_set" and
        "output_set" (variable names crossing the section boundary).
        """
        programs = []
        block = main_program.block(0)
        whole_parameters = [e.name for e in block.all_parameters()]
        cut_var_names = []
        cut_len = len(cut_list)
        sec_params = []
        # Forward cut points: the variables named in each cut.
        for i, cut_vars in enumerate(cut_list[:-1]):
            cut_var_names.append([cut_var.name for cut_var in cut_vars])
        # Backward cut points, mirrored: the gradients of the same variables,
        # in reverse cut order.  The last backward section (i == 0) also cuts
        # at the variables of the final cut entry.
        for i, cut_vars in reversed(list(enumerate(cut_list[:-1]))):
            cut_var_names.append(
                [_append_grad_suffix_(cut_var.name) for cut_var in cut_vars])
            if i == 0:
                cut_var_names[-1] += [var.name for var in cut_list[-1]]
        # Work on a copy: extracted ops are removed from ``ops`` as we go.
        ops = block.ops[:]
        for i, cut_vars in enumerate(cut_var_names):
            program = {
                "program": Program(),
                "input_set": set(),
                "output_set": set()
            }
            cur_ops = self.extract_section_ops(ops, cut_vars)
            if i == 0:
                # LR-schedule ops run once, in the first section.
                for op in ops:
                    if self._is_lr_role_op(op):
                        cur_ops.append(op)
            #prevent inplace in/out
            program["input_set"].update(
                self.find_input_output(
                    cur_ops, [], is_forward=True))
            for e in cur_ops:
                ops.remove(e)
            if i < cut_len:
                # Forward sections: remember which parameters each uses.
                sec_params.append(
                    self.find_persistable_vars(cur_ops, whole_parameters))
            if i >= cut_len - 1:
                # Backward sections: pull in the optimize ops for the
                # parameters of the mirrored forward section
                # (index 2*cut_len - 2 - i maps backward i to its forward twin).
                opt_ops = self.find_section_opt(ops,
                                                sec_params[2 * cut_len - 2 - i])
                for e in opt_ops:
                    ops.remove(e)
                cur_ops += opt_ops
            # Copy the section's op descs into its own Program.
            op_descs = [op.desc for op in cur_ops]
            for op_desc in op_descs:
                ap_op = program["program"].block(0).desc.append_op()
                ap_op.copy_from(op_desc)
            program["input_set"].update(
                self.find_input_output(
                    cur_ops, cut_vars, is_forward=True))
            program["input_set"].update(sec_params[min(i, 2 * cut_len - 2 - i)])
            program["output_set"].update(
                self.find_input_output(
                    cur_ops, cut_vars, is_forward=False))
            programs.append(program)
        # Final section: everything not claimed above (updates for the
        # first section's parameters).
        program = {
            "program": Program(),
            "input_set": set(),
            "output_set": set()
        }
        op_descs = [op.desc for op in ops]
        for op_desc in op_descs:
            ap_op = program["program"].block(0).desc.append_op()
            ap_op.copy_from(op_desc)
        program["input_set"].update(
            [cut_var.name + "@GRAD" for cut_var in cut_list[0]])
        program["input_set"].update(
            self.find_input_output(
                ops, [], is_forward=True))
        program["input_set"].update(sec_params[0])
        programs.append(program)
        # Prune outputs nobody downstream consumes: walk sections from last
        # to first, keeping only outputs that appear in a later section's
        # input_set.
        inputs = set()
        for program in reversed(list(programs)):
            output_list = list(program["output_set"])
            for output in output_list:
                if output not in inputs:
                    program["output_set"].remove(output)
            inputs.update(program["input_set"])
        return programs

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """Run the wrapped optimizer, split the program into sections, and
        attach the pipeline configuration to ``loss.block.program`` as
        ``_pipeline_opt``.

        NOTE(review): unlike regular optimizers this does not return
        (optimize_ops, params_grads) — the inner optimizer's result is
        discarded.
        """
        self._optimizer.minimize(loss, startup_program, parameter_list,
                                 no_grad_set)
        program = loss.block.program
        program_list = self.split_program(program, self._cut_list)
        # Make every variable a section touches exist in that section's block.
        for p in program_list:
            self.create_vars(p["program"].block(0), program)
        whole_parameters = [e.name for e in program.block(0).all_parameters()]
        # Parameters living on a GPU section must be synced across pipelines.
        param_need_sync = []
        for i, section_p in enumerate(program_list):
            if not isinstance(self._place_list[i], core.CUDAPlace):
                continue
            section_var = [e for e in section_p["program"].block(0).vars]
            for p in section_var:
                if p in whole_parameters:
                    param_need_sync.append(p)
        program._pipeline_opt = {
            "trainer": "PipelineTrainer",
            "device_worker": "Section",
            "section_program_list": program_list,
            "place_list": self._place_list,
            "concurrency_list": self._concurrency_list,
            "queue_size": self._queue_size,
            "start_cpu_core_id": self._start_cpu_core_id,
            "sync_steps": self._sync_steps,
            "param_need_sync": param_need_sync
        }

File diff suppressed because it is too large Load Diff

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer']
# can be initialized from train_desc,
@ -66,7 +66,7 @@ class TrainerDesc(object):
def _desc(self):
from google.protobuf import text_format
return text_format.MessageToString(self.proto_desc)
return self.proto_desc.SerializeToString()
class MultiTrainer(TrainerDesc):
@ -102,3 +102,22 @@ class DistMultiTrainer(TrainerDesc):
self._device_worker._set_infer(self._infer)
self._device_worker._set_program(self._program)
self._device_worker._gen_worker_desc(self.proto_desc)
class PipelineTrainer(TrainerDesc):
    """Trainer descriptor for pipeline-parallel training.

    Mirrors DistMultiTrainer: holds the main program and a device worker
    (a Section worker), and serializes both into the trainer proto desc.
    """

    def __init__(self):
        super(PipelineTrainer, self).__init__()

    def _set_program(self, program):
        """Record ``program`` both in the base desc and locally.

        The local reference lets _gen_trainer_desc validate that a program
        was set before generating the worker description.
        """
        super(PipelineTrainer, self)._set_program(program)
        self._program = program

    def _gen_trainer_desc(self):
        """Fill ``self.proto_desc`` for a PipelineTrainer run.

        Raises:
            RuntimeError: if no program has been set via _set_program.
        """
        super(PipelineTrainer, self)._gen_trainer_desc()
        self.proto_desc.class_name = "PipelineTrainer"
        # PEP 8: compare to the None singleton with ``is``, not ``==``.
        if self._program is None:
            raise RuntimeError("None Program")
        # Delegate the section-level configuration to the device worker,
        # which appends its SectionWorkerParameter to the proto desc.
        self._device_worker._set_infer(self._infer)
        self._device_worker._set_program(self._program)
        self._device_worker._gen_worker_desc(self.proto_desc)

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
from .device_worker import Hogwild, DownpourSGD, Section
__all__ = ["TrainerFactory"]
@ -35,8 +35,9 @@ class TrainerFactory(object):
device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]()
device_worker._set_fleet_desc(opt_info["fleet_desc"])
if "fleet_desc" in opt_info:
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_use_cvm(opt_info["use_cvm"])
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_use_cvm(opt_info["use_cvm"])
return trainer

Loading…
Cancel
Save