add support for pipeline (#24560)

* add device_worker for pipeline, test=develop
fix_copy_if_different
lilong12 committed 6 years ago via GitHub
parent 0dcb87546e
commit e39aa70ec7

@@ -51,10 +51,6 @@ bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
class FleetWrapper;
#define SEC_LOG \
VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \
<< "]: "
class PullDenseWorker {
public:
virtual ~PullDenseWorker() {}
@@ -311,40 +307,9 @@ class DownpourWorkerOpt : public DownpourWorker {
};
#if defined(PADDLE_WITH_NCCL)
using ScopeQueue = operators::reader::BlockingQueue<Scope*>;
class SyncFunctor {
public:
SyncFunctor(int rank_id, int rank_num, int sync_steps);
virtual ~SyncFunctor() {}
void SetSyncParam(const std::vector<std::string>& sync_param) {
sync_param_ = &sync_param;
}
void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) {
nccl_ctx_map_ = nccl_ctx_map;
}
int operator()(Scope* scope);
static std::vector<Scope*> pipeline_scopes_;
static uint64_t sync_flag_;
protected:
const int rank_id_;
const int rank_num_;
const std::vector<std::string>* sync_param_ = nullptr;
platform::NCCLContextMap* nccl_ctx_map_ = nullptr;
uint64_t sync_signal_;
const int sync_steps_;
int counter_;
void Synchronize();
};
class SectionWorker : public DeviceWorker {
public:
SectionWorker() {}
SectionWorker() { local_batch_id_ = 0; }
~SectionWorker() override {}
void Initialize(const TrainerDesc& desc) override;
@@ -360,50 +325,39 @@ class SectionWorker : public DeviceWorker {
const platform::Place& place() const { return place_; }
void SetSectionIndex(int section_id) { section_id_ = section_id; }
void SetDeviceIndex(int tid) override { pipeline_id_ = tid; }
void SetDeviceIndex(int tid) override {}
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetVarNames(const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
in_var_names_ = &in_var_names;
out_var_names_ = &out_var_names;
}
void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) {
in_scope_queue_ = in_scope_queue;
out_scope_queue_ = out_scope_queue;
void SetMicrobatchNum(int num) { num_microbatches_ = num; }
void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
microbatch_scopes_ = scope;
}
void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; }
void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; }
void SetSectionNum(int section_num) { section_num_ = section_num; }
void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; }
void SetNextSectionPlace(const paddle::platform::Place& place) {
next_section_place_ = place;
void SetMinibatchScope(const Scope* scope) { minibatch_scope_ = scope; }
void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars;
}
SyncFunctor* sync_func_ = nullptr;
void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }
static std::atomic<int> cpu_id_;
protected:
void AutoSetCPUAffinity(bool reuse);
int section_id_;
int pipeline_id_;
int section_num_;
int pipeline_num_;
int thread_id_;
// This worker will consume scope from in_scope_queue_
// and produce scope to out_scope_queue_
ScopeQueue* in_scope_queue_ = nullptr;
ScopeQueue* out_scope_queue_ = nullptr;
const std::vector<std::string>* in_var_names_ = nullptr;
const std::vector<std::string>* out_var_names_ = nullptr;
std::mutex* worker_count_mutex_ = nullptr;
int* worker_count_ = nullptr;
paddle::platform::Place next_section_place_;
int num_microbatches_;
std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
static std::mutex thread_mutex;
static std::condition_variable thread_condition;
static bool threads_completed;
std::shared_ptr<framework::ProgramDesc> program_;
static uint64_t batch_id_;
uint64_t local_batch_id_;
platform::DeviceContext* dev_ctx_ = nullptr;
};
#endif
} // namespace framework
} // namespace paddle
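The reworked SectionWorker drops the scope queues and the SyncFunctor and instead runs its op list once per microbatch, holding one scope per microbatch and advancing the static batch_id_ once per minibatch. Below is a purely illustrative Python sketch of that per-minibatch control flow; the dicts and lambdas are stand-ins for framework::Scope and OperatorBase, not the actual API.

num_microbatches = 2
microbatch_scopes = [{} for _ in range(num_microbatches)]   # stand-ins for framework::Scope
ops = [lambda scope: scope.update(ran=True)]                # stand-ins for OperatorBase

for microbatch_id in range(num_microbatches):
    for op in ops:
        # roughly op->Run(*microbatch_scopes_[microbatch_id], place_) in the C++ worker
        op(microbatch_scopes[microbatch_id])
batch_id = 1  # the worker's static batch_id_ advances once per finished minibatch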

(File diff suppressed because it is too large)

(File diff suppressed because it is too large)

@@ -137,49 +137,31 @@ class PipelineTrainer : public TrainerBase {
virtual Scope* GetWorkerScope(int thread_id);
void InitDumpEnv() override;
virtual std::string GetDumpPath(int tid);
void GetSkipVars(int section_id, const ProgramDesc& main_program);
protected:
int section_num_;
int pipeline_num_;
int scope_queue_size_;
int sync_steps_;
int num_microbatches_;
int start_cpu_core_id_;
std::vector<std::string> feed_var_names_;
std::vector<platform::Place> places_;
std::vector<std::vector<std::string>> skip_vars_;
TrainerDesc trainer_desc_;
SectionWorkerParameter pipeline_config_;
// The in/output var names for each section
std::vector<std::unique_ptr<std::vector<std::string>>> in_var_names_;
std::vector<std::unique_ptr<std::vector<std::string>>> out_var_names_;
// Counter for the running thread
std::vector<std::vector<int*>> worker_count_;
std::vector<std::vector<std::unique_ptr<std::mutex>>> worker_count_mutex_;
// worker: [section_id][pipeline_id][thread_id]
std::vector<std::vector<
std::vector<std::shared_ptr<paddle::framework::DeviceWorker>>>>
workers_;
std::vector<std::thread> section_threads_;
// We use scope to maintain context info, and scopes
// will be delivered between different sections.
std::vector<std::vector<std::unique_ptr<ScopeQueue>>> scope_queues_;
std::vector<Scope*> pipeline_scopes_;
// The parameters that should be synchronized between different cards using
// nccl all-reduce
std::shared_ptr<std::vector<std::string>> param_need_sync_;
std::vector<std::string> persistable_vars_;
std::vector<std::unique_ptr<SyncFunctor>> sync_functors_;
std::shared_ptr<platform::NCCLContextMap> nccl_ctx_map_;
std::vector<DataFeed*> readers_;
void InitFirstScopeQueue(ScopeQueue* scope_queue, int pipeline_id,
const ProgramDesc& main_program,
const Scope& root_scope);
void CopyParameters(const Scope& root_scope, int pipeline_id);
void construct_sync_functor();
// worker: [section_id]
std::vector<std::shared_ptr<paddle::framework::DeviceWorker>> workers_;
// minibatch_scopes_: [section_id]
std::vector<Scope*> minibatch_scopes_;
// microbatch_scopes_: [section_id][microbatch_id]
std::vector<std::vector<Scope*>> microbatch_scopes_;
void CopyParameters(int section_id, int microbatch_id,
const ProgramDesc& program, const platform::Place& place);
bool isPersistableVarGrad(std::string name);
bool isPersistable(VarDesc* var);
};
#endif
} // namespace framework
} // namespace paddle
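PipelineTrainer now keeps exactly one worker per section, one minibatch scope per section, and a microbatch scope per (section, microbatch) pair; CopyParameters fills these scopes per microbatch from the section program. A hedged sketch of that scope layout, using plain dicts in place of framework::Scope (all names here are illustrative only):

num_sections, num_microbatches = 2, 4
root_scope = {"parameters": {}}
minibatch_scopes = []        # indexed by [section_id]
microbatch_scopes = []       # indexed by [section_id][microbatch_id]
for section_id in range(num_sections):
    mb = {"parent": root_scope, "persistable": {}}   # persistable vars land here
    minibatch_scopes.append(mb)
    microbatch_scopes.append(
        [{"parent": mb, "locals": {}} for _ in range(num_microbatches)])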

@@ -83,6 +83,7 @@ message SectionWorkerParameter {
optional int64 sync_steps = 3 [ default = 1 ];
optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5;
optional int32 num_microbatches = 6;
}
message SectionConfig {
@@ -99,6 +100,7 @@ message SectionConfig {
optional int32 concurrency = 3 [ default = 1 ];
repeated string section_in_var_names = 4;
repeated string section_out_var_names = 5;
optional int32 place_id = 6 [ default = -1 ];
}
message FetchConfig {
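The two new proto fields above are what the Python Section device worker fills in further down. A hedged example of setting them on the generated message; the import path is assumed to be the fluid-generated trainer_desc_pb2 module.

from paddle.fluid.proto import trainer_desc_pb2   # module path assumed

desc = trainer_desc_pb2.TrainerDesc()
desc.device_worker_name = "SectionWorker"
desc.section_param.num_microbatches = 2   # new field 6 of SectionWorkerParameter
cfg = desc.section_param.section_config.add()
cfg.place_id = 0                          # new field 6 of SectionConfig, default -1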

@@ -403,11 +403,8 @@ class Section(DeviceWorker):
trainer_desc.device_worker_name = "SectionWorker"
pipeline_opt = self._program._pipeline_opt
section_param = trainer_desc.section_param
section_param.queue_size = pipeline_opt["queue_size"]
section_param.sync_steps = pipeline_opt["sync_steps"]
section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
for e in pipeline_opt["param_need_sync"]:
section_param.param_need_sync.append(e)
for i, program in enumerate(pipeline_opt["section_program_list"]):
cfg = section_param.section_config.add()
cfg.program_desc.ParseFromString(program["program"]._get_desc()
@@ -415,6 +412,7 @@ class Section(DeviceWorker):
# TODO: why doesn't this work?
# cfg.program_desc.CopyFrom(program.program._get_desc())
place = pipeline_opt["place_list"][i]
place_id = pipeline_opt["place_id_list"][i]
if isinstance(place, core.CPUPlace):
cfg.place = cfg.CPUPlace
elif isinstance(place, core.CUDAPlace):
@@ -425,12 +423,7 @@ class Section(DeviceWorker):
raise NotImplementedError(
"SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
)
cfg.concurrency = pipeline_opt["concurrency_list"][i]
for var in program["input_set"]:
cfg.section_in_var_names.append(var)
for var in program["output_set"]:
cfg.section_out_var_names.append(var)
cfg.place_id = place_id
class DeviceWorkerFactory(object):

@@ -4474,7 +4474,7 @@ class PipelineOptimizer(object):
"place_list": place_list,
"place_id_list": place_id_list,
"sync_steps": -1,
"queue_size": self._num_microbatches,
"num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id,
}
return optimize_ops, params_grads, program_list
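With the queue_size and param_need_sync plumbing removed, num_microbatches is the scheduling knob that PipelineOptimizer forwards to the Section device worker. A minimal hedged sketch of the wrapping; minimize() must still be called on a loss built under fluid.device_guard, as in the test below.

import paddle.fluid as fluid

inner_opt = fluid.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
pipeline_opt = fluid.optimizer.PipelineOptimizer(inner_opt, num_microbatches=2)
# After pipeline_opt.minimize(loss) on a device_guard-annotated program, the dict
# shown above (num_microbatches, place_list, place_id_list, ...) is attached to the
# main program as program._pipeline_opt and consumed by Section._gen_worker_desc.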

@@ -100,7 +100,7 @@ def build_network(input, layers=50, class_dim=1000):
pool_type='max')
if layers >= 50:
for block in range(len(depth)):
with fluid.device_guard("cpu"):
with fluid.device_guard("gpu:0"):
for i in range(depth[block]):
conv = bottleneck_block(
input=conv,
@@ -118,7 +118,7 @@ def build_network(input, layers=50, class_dim=1000):
initializer=fluid.initializer.Uniform(-stdv, stdv)))
else:
for block in range(len(depth)):
with fluid.device_guard("cpu"):
with fluid.device_guard("gpu:0"):
for i in range(depth[block]):
conv = basic_block(
input=conv,
@@ -140,38 +140,68 @@ class TestPipeline(unittest.TestCase):
class TestPipeline(unittest.TestCase):
""" TestCases for Pipeline Training. """
def _run(self, debug):
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
with fluid.device_guard("cpu"):
image = fluid.layers.data(
name="image", shape=[3, 224, 224], dtype="float32")
label = fluid.layers.data(
name="label", shape=[1], dtype="int64")
data_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=False)
fc = build_network(image, layers=50)
with fluid.device_guard("gpu:0"):
out, prob = fluid.layers.softmax_with_cross_entropy(
logits=fc, label=label, return_softmax=True)
loss = fluid.layers.mean(out)
acc_top1 = fluid.layers.accuracy(input=prob, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=prob, label=label, k=5)
base_lr = 0.1
passes = [30, 60, 80, 90]
total_images = 1281167
steps_per_pass = total_images // 128
bd = [steps_per_pass * p for p in passes]
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.MomentumOptimizer(
lr_val,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer, num_microbatches=2)
optimizer.minimize(loss)
def train_reader():
for _ in range(4):
img = np.random.random(size=[3, 224, 224]).astype('float32')
label = np.random.random(size=[1]).astype('int64')
yield img, label
data_loader.set_sample_generator(train_reader, batch_size=1)
place = fluid.CPUPlace()
# The following dataset is only used by the
# 'train_from_dataset' interface and has no
# actual meaning.
dataset = fluid.DatasetFactory().create_dataset('FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['/tmp/tmp_2.txt'])
dataset.set_use_var([image, label])
exe = fluid.Executor(place)
exe.run(startup_prog)
data_loader.start()
exe.train_from_dataset(main_prog, dataset, debug=debug)
def test_pipeline(self):
with fluid.device_guard("cpu"):
image = fluid.layers.data(
name="image", shape=[3, 224, 224], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=False)
fc = build_network(image, layers=50)
with fluid.device_guard("gpu:0"):
out, prob = fluid.layers.softmax_with_cross_entropy(
logits=fc, label=label, return_softmax=True)
loss = fluid.layers.mean(out)
acc_top1 = fluid.layers.accuracy(input=prob, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=prob, label=label, k=5)
base_lr = 0.1
passes = [30, 60, 80, 90]
total_images = 1281167
steps_per_pass = total_images // 128
bd = [steps_per_pass * p for p in passes]
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.Momentum(
lr_val,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer, num_microbatches=2)
optimizer.minimize(loss)
self._run(False)
self._run(True)
def test_pipeline_noneoptimizer(self):
with fluid.device_guard("gpu:0"):
