@@ -22,15 +22,79 @@ class TrainerDesc;
 
 uint64_t SectionWorker::batch_id_(0);
 
-void SectionWorker::Initialize(const TrainerDesc& desc) {
+void SectionWorker::Initialize(const TrainerDesc &desc) {
   dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
   program_.reset(
       new ProgramDesc(desc.section_param().section_config().program_desc()));
-  for (auto& op_desc : program_->Block(0).AllOps()) {
+  for (auto &op_desc : program_->Block(0).AllOps()) {
     ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
 }
 
+void SectionWorker::RunForward(
+    int micro_id, std::unique_ptr<GarbageCollector> &gc,
+    std::unordered_map<const OperatorBase *, std::vector<std::string>>
+        &unused_vars_) {
+  for (auto &op : ops_) {
+    int op_role = op->Attr<int>(std::string("op_role"));
+    // We run op with op_role = kLRSched only for the first microbatch
+    // to avoid increasing the @LR_DECAY_STEP@ multiple times.
+    bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
+                            op_role == (static_cast<int>(OpRole::kForward) |
+                                        static_cast<int>(OpRole::kLoss)) ||
+                            op_role == static_cast<int>(OpRole::kLRSched);
+    bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
+                      op_role == (static_cast<int>(OpRole::kForward) |
+                                  static_cast<int>(OpRole::kLoss));
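+    // Both branches accept plain forward ops and ops tagged
+    // (kForward | kLoss); micro-batch 0 additionally runs kLRSched ops.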
+    if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
+      VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
+              << micro_id;
+      op->Run(*microbatch_scopes_[micro_id], place_);
+      if (gc) {
+        DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
+                            unused_vars_, gc.get());
+      }
+    }
+  }
+}
 
+void SectionWorker::RunBackward(
+    int micro_id, std::unique_ptr<GarbageCollector> &gc,
+    std::unordered_map<const OperatorBase *, std::vector<std::string>>
+        &unused_vars_) {
+  for (auto &op : ops_) {
+    int op_role = op->Attr<int>(std::string("op_role"));
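+    // Accept plain backward ops as well as ops tagged (kBackward | kLoss).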
+    if (op_role == static_cast<int>(OpRole::kBackward) ||
+        op_role == (static_cast<int>(OpRole::kBackward) |
+                    static_cast<int>(OpRole::kLoss))) {
+      VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
+              << micro_id;
+      op->Run(*microbatch_scopes_[micro_id], place_);
+      if (gc) {
+        DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
+                            unused_vars_, gc.get());
+      }
+    }
+  }
+}
 
+void SectionWorker::RunUpdate(
+    std::unique_ptr<GarbageCollector> &gc,
+    std::unordered_map<const OperatorBase *, std::vector<std::string>>
+        &unused_vars_) {
+  for (auto &op : ops_) {
+    int op_role = op->Attr<int>(std::string("op_role"));
+    if (op_role == static_cast<int>(OpRole::kOptimize)) {
+      VLOG(3) << "Update: running op " << op->Type();
+      op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
+      if (gc) {
+        DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
+                            op.get(), unused_vars_, gc.get());
+      }
+    }
+  }
+}
 
 void SectionWorker::TrainFiles() {
   VLOG(5) << "begin section_worker TrainFiles";
@@ -48,69 +112,56 @@ void SectionWorker::TrainFiles() {
 #endif
   }
 
-  for (int i = 0; i < num_microbatches_; ++i) {
-    for (auto& op : ops_) {
-      int op_role = op->Attr<int>(std::string("op_role"));
-      // We run op with op_role = kLRSched only for the first microbatch
-      // to avoid increasing the @LR_DECAY_STEP@ multiple times.
-      bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
-                              op_role == (static_cast<int>(OpRole::kForward) |
-                                          static_cast<int>(OpRole::kLoss)) ||
-                              op_role == static_cast<int>(OpRole::kLRSched);
-      bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
-                        op_role == (static_cast<int>(OpRole::kForward) |
-                                    static_cast<int>(OpRole::kLoss));
-      if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
-        VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
-                << i;
-        op->Run(*microbatch_scopes_[i], place_);
-        if (gc) {
-          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
-                              gc.get());
-        }
-      }
+  if (schedule_mode_ == 0) {
+    // F-then-B scheduler which runs Forward phase for all microbatches,
+    // then runs Backward phase for all microbatches.
+    // step1: run forward
+    for (int i = 0; i < num_microbatches_; ++i) {
+      RunForward(i, gc, unused_vars_);
     }
-#ifdef PADDLE_WITH_RCCL
-    hipDeviceSynchronize();
-#else
-    cudaDeviceSynchronize();
-#endif
-  }
-
-  // backward pass
-  for (int i = 0; i < num_microbatches_; ++i) {
-    for (auto& op : ops_) {
-      int op_role = op->Attr<int>(std::string("op_role"));
-      if (op_role == static_cast<int>(OpRole::kBackward) ||
-          op_role == (static_cast<int>(OpRole::kBackward) |
-                      static_cast<int>(OpRole::kLoss))) {
-        VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
-                << i;
-        op->Run(*microbatch_scopes_[i], place_);
-        if (gc) {
-          DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
-                              gc.get());
-        }
-      }
+    // step2: run backward
+    for (int i = 0; i < num_microbatches_; ++i) {
+      RunBackward(i, gc, unused_vars_);
     }
-#ifdef PADDLE_WITH_RCCL
-    hipDeviceSynchronize();
-#else
-    cudaDeviceSynchronize();
-#endif
-  }
-
-  // update pass
-  for (auto& op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    if (op_role == static_cast<int>(OpRole::kOptimize)) {
-      VLOG(3) << "Update: running op " << op->Type();
-      op->Run(*microbatch_scopes_[0], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
-                            gc.get());
-      }
-    }
-  }
+    // step3: run update
+    RunUpdate(gc, unused_vars_);
+  } else {
+    // 1F1B scheduler, which runs forward phase and backward phase alternately
+    // after the startup phase. For a stage, the number of microbatches for
+    // startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
+    // num_pipeline_stages_ is the total number of pipeline stages and
+    // pipeline_stage_ is the pipeline stage of the current device.
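+    // For example, with num_pipeline_stages_ = 4, stage 0 runs
+    // 4 - 0 - 1 = 3 startup steps while the last stage (pipeline_stage_ = 3)
+    // runs 4 - 3 - 1 = 0 and enters the 1F1B phase immediately.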
+    auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
+    VLOG(3) << "startup_steps:" << startup_steps
+            << ", num_stages: " << num_pipeline_stages_
+            << ", stage:" << pipeline_stage_;
+    PADDLE_ENFORCE_GT(
+        num_microbatches_, startup_steps,
+        platform::errors::InvalidArgument(
+            "To use pipeline with 1F1B scheduler, please make sure number of "
+            "microbatches (%d) is larger than startup steps (%d).",
+            num_microbatches_, startup_steps));
+    int fw_step = 0;
+    int bw_step = 0;
+    // startup phase
+    while (fw_step < startup_steps) {
+      RunForward(fw_step, gc, unused_vars_);
+      fw_step += 1;
+    }
+    // 1f1b phase
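+    // Steady state: each iteration runs one forward and one backward step,
+    // so at most startup_steps + 1 micro-batches are in flight on this stage.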
+    while (fw_step < num_microbatches_) {
+      RunForward(fw_step, gc, unused_vars_);
+      fw_step += 1;
+      RunBackward(bw_step, gc, unused_vars_);
+      bw_step += 1;
+    }
+    // backward phase
+    while (bw_step < num_microbatches_) {
+      RunBackward(bw_step, gc, unused_vars_);
+      bw_step += 1;
+    }
+    RunUpdate(gc, unused_vars_);
+  }
   dev_ctx_->Wait();
   ++batch_id_;
 }