|
|
|
@ -26,59 +26,58 @@ namespace ir {
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
void CheckProgram(const ProgramDesc &program) {
|
|
|
|
|
std::map<int, bool> visit;
|
|
|
|
|
#define _INT(role) static_cast<int>(role)
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < program.Size(); ++i) {
|
|
|
|
|
for (OpDesc *op : program.Block(i).AllOps()) {
|
|
|
|
|
// For backward compatibility, some program doesn't have role added.
|
|
|
|
|
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
|
|
|
|
|
int role_id = boost::get<int>(
|
|
|
|
|
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
|
|
|
|
|
visit[role_id] = true;
|
|
|
|
|
switch (role_id) {
|
|
|
|
|
case _INT(OpRole::kForward):
|
|
|
|
|
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
|
|
|
|
|
LOG(ERROR)
|
|
|
|
|
<< "Cannot add backward operator before forward operator %s."
|
|
|
|
|
<< op->Type();
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kBackward):
|
|
|
|
|
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add backward operator %s before optimize operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
|
|
|
|
|
_INT(OpRole::kLoss)) == visit.end(),
|
|
|
|
|
"Cannot add backward|loss operator before "
|
|
|
|
|
"forward|loss operator %s.",
|
|
|
|
|
op->Type());
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add forward|loss operator %s after optimize operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kOptimize):
|
|
|
|
|
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
|
|
|
|
|
"Optimize operators %s must follow backward operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kLRSched):
|
|
|
|
|
case _INT(OpRole::kDist):
|
|
|
|
|
case _INT(OpRole::kRPC):
|
|
|
|
|
case _INT(OpRole::kNotSpecified):
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Unknown operator role. Don't add new role because "
|
|
|
|
|
"you don't know what you are doing.";
|
|
|
|
|
}
|
|
|
|
|
std::map<int, bool> visit;
|
|
|
|
|
for (OpDesc *op : program.Block(0).AllOps()) {
|
|
|
|
|
// For backward compatibility, some program doesn't have role added.
|
|
|
|
|
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
|
|
|
|
|
int role_id =
|
|
|
|
|
boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
|
|
|
|
|
visit[role_id] = true;
|
|
|
|
|
switch (role_id) {
|
|
|
|
|
case _INT(OpRole::kForward):
|
|
|
|
|
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
|
|
|
|
|
LOG(ERROR)
|
|
|
|
|
<< "Cannot add backward operator before forward operator %s."
|
|
|
|
|
<< op->Type();
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kBackward):
|
|
|
|
|
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add backward operator %s after optimize operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
|
|
|
|
|
_INT(OpRole::kLoss)) == visit.end(),
|
|
|
|
|
"Cannot add backward|loss operator before "
|
|
|
|
|
"forward|loss operator %s.",
|
|
|
|
|
op->Type());
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add forward|loss operator %s after optimize operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kOptimize):
|
|
|
|
|
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
|
|
|
|
|
"Optimize operators %s must follow backward operator.",
|
|
|
|
|
op->Type());
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kLRSched):
|
|
|
|
|
case _INT(OpRole::kDist):
|
|
|
|
|
case _INT(OpRole::kRPC):
|
|
|
|
|
case _INT(OpRole::kNotSpecified):
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Unknown operator role. Don't add new role because "
|
|
|
|
|
"you don't know what you are doing.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#undef _INT
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|