|
|
|
@ -23,8 +23,59 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace {
|
|
|
|
|
void CheckProgram(const ProgramDesc &program) {
|
|
|
|
|
std::map<int, bool> visit;
|
|
|
|
|
#define _INT(role) static_cast<int>(role)
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < program.Size(); ++i) {
|
|
|
|
|
for (OpDesc *op : program.Block(i).AllOps()) {
|
|
|
|
|
int role_id = boost::get<int>(
|
|
|
|
|
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
|
|
|
|
|
visit[role_id] = true;
|
|
|
|
|
switch (role_id) {
|
|
|
|
|
case _INT(OpRole::kForward):
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kBackward)) == visit.end(),
|
|
|
|
|
"Cannot add forward operator before backward operator.");
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kBackward):
|
|
|
|
|
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add backward operator before optimize operator.");
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
|
|
|
|
|
_INT(OpRole::kLoss)) == visit.end(),
|
|
|
|
|
"Cannot add backward|loss operator before "
|
|
|
|
|
"forward|loss operator.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
|
|
|
|
|
"Cannot add backward operator before optimize operator.");
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kOptimize):
|
|
|
|
|
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
|
|
|
|
|
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
|
|
|
|
|
"Optimize operators must follow backward operator.");
|
|
|
|
|
break;
|
|
|
|
|
case _INT(OpRole::kLRSched):
|
|
|
|
|
case _INT(OpRole::kDist):
|
|
|
|
|
case _INT(OpRole::kRPC):
|
|
|
|
|
case _INT(OpRole::kNotSpecified):
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Unknown operator role. Don't add new role because "
|
|
|
|
|
"you don't know what you are doing.";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#undef _INT
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
Graph::Graph(const ProgramDesc &program) : program_(program) {
|
|
|
|
|
CheckProgram(program_);
|
|
|
|
|
// Make the nodes id start from 0.
|
|
|
|
|
Node::ResetId();
|
|
|
|
|
auto var_nodes = InitFromProgram(program_);
|
|
|
|
|