|
|
|
@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasOutput(Node *op, const std::string &argument) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op->IsOp(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"First parameter of function HasOuput must be Node::Op"));
|
|
|
|
|
auto const &names = op->Op()->OutputNames();
|
|
|
|
|
if (std::find(names.begin(), names.end(), argument) == names.end())
|
|
|
|
|
return false;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
var->IsVar(), true,
|
|
|
|
@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
|
|
|
|
|
op->IsOp(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Second parameter of function IsNthOutput must be Node::Op"));
|
|
|
|
|
if (op->Op()->Output(argument).size() <= nth) return false;
|
|
|
|
|
if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
|
|
|
|
|
return false;
|
|
|
|
|
return var->Name() == op->Op()->Output(argument)[nth];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|