|
|
|
@ -38,29 +38,35 @@ void CheckProgram(const ProgramDesc &program) {
|
|
|
|
|
visit[role_id] = true;
|
|
|
|
|
switch (role_id) {
|
|
|
|
|
case _INT(OpRole::kForward):
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kBackward)) == visit.end(),
|
|
|
|
|
"Cannot add forward operator before backward operator.");
|
|
|
|
|
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
|
|
|
|
|
LOG(ERROR)
|
|
|
|
|
<< "Cannot add backward operator before forward operator %s."
|
|
|
|
|
<< op->Type();
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kBackward):
|
|
|
|
|
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add backward operator before optimize operator.");
|
|
|
|
|
"Cannot add backward operator %s before optimize operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
|
|
|
|
|
_INT(OpRole::kLoss)) == visit.end(),
|
|
|
|
|
"Cannot add backward|loss operator before "
|
|
|
|
|
"forward|loss operator.");
|
|
|
|
|
"forward|loss operator %s.",
|
|
|
|
|
op->Type());
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add backward operator before optimize operator.");
|
|
|
|
|
"Cannot add forward|loss operator %s after optimize operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kOptimize):
|
|
|
|
|
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
|
|
|
|
|
"Optimize operators must follow backward operator.");
|
|
|
|
|
"Optimize operators %s must follow backward operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kLRSched):
|
|
|
|
|
case _INT(OpRole::kDist):
|
|
|
|
|