|
|
|
|
@ -39,13 +39,13 @@ void SectionWorker::RunForward(
|
|
|
|
|
int op_role = op->Attr<int>(std::string("op_role"));
|
|
|
|
|
// We run op with op_role = kLRSched only for the first microbatch
|
|
|
|
|
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
|
|
|
|
|
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
|
|
|
|
|
op_role == (static_cast<int>(OpRole::kForward) |
|
|
|
|
|
static_cast<int>(OpRole::kLoss)) ||
|
|
|
|
|
op_role == static_cast<int>(OpRole::kLRSched);
|
|
|
|
|
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
|
|
|
|
|
op_role == (static_cast<int>(OpRole::kForward) |
|
|
|
|
|
static_cast<int>(OpRole::kLoss));
|
|
|
|
|
bool run_first_mbatch = (op_role == static_cast<int>(OpRole::kForward)) ||
|
|
|
|
|
(op_role == (static_cast<int>(OpRole::kForward) |
|
|
|
|
|
static_cast<int>(OpRole::kLoss))) ||
|
|
|
|
|
(op_role == static_cast<int>(OpRole::kLRSched));
|
|
|
|
|
bool run_others = (op_role == static_cast<int>(OpRole::kForward)) ||
|
|
|
|
|
(op_role == (static_cast<int>(OpRole::kForward) |
|
|
|
|
|
static_cast<int>(OpRole::kLoss)));
|
|
|
|
|
if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
|
|
|
|
|
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
|
|
|
|
|
<< micro_id;
|
|
|
|
|
@ -64,9 +64,9 @@ void SectionWorker::RunBackward(
|
|
|
|
|
&unused_vars_) {
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
|
int op_role = op->Attr<int>(std::string("op_role"));
|
|
|
|
|
if (op_role == static_cast<int>(OpRole::kBackward) ||
|
|
|
|
|
op_role == (static_cast<int>(OpRole::kBackward) |
|
|
|
|
|
static_cast<int>(OpRole::kLoss))) {
|
|
|
|
|
if ((op_role == static_cast<int>(OpRole::kBackward)) ||
|
|
|
|
|
(op_role == (static_cast<int>(OpRole::kBackward) |
|
|
|
|
|
static_cast<int>(OpRole::kLoss)))) {
|
|
|
|
|
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
|
|
|
|
|
<< micro_id;
|
|
|
|
|
op->Run(*microbatch_scopes_[micro_id], place_);
|
|
|
|
|
|