fix the bug of assert_is_op_output. test=develop (#22262)

revert-22710-feature/integrated_ps_api
Zhen Wang 6 years ago committed by GitHub
parent a46bb2e6ab
commit e40cfb1010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -656,6 +656,17 @@ bool HasInput(Node *op, const std::string &argument) {
return true;
}
bool HasOutput(Node *op, const std::string &argument) {
PADDLE_ENFORCE_EQ(
op->IsOp(), true,
platform::errors::InvalidArgument(
"First parameter of function HasOuput must be Node::Op"));
auto const &names = op->Op()->OutputNames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
PADDLE_ENFORCE_EQ(
var->IsVar(), true,
@ -665,7 +676,8 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
op->IsOp(), true,
platform::errors::InvalidArgument(
"Second parameter of function IsNthOutput must be Node::Op"));
if (op->Op()->Output(argument).size() <= nth) return false;
if (!HasOutput(op, argument) || op->Op()->Output(argument).size() <= nth)
return false;
return var->Name() == op->Op()->Output(argument)[nth];
}

@ -318,6 +318,9 @@ bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
// Check whether the op node has input of given name.
bool HasInput(Node* op, const std::string& argument);
// Check whether the op node has output of given name.
bool HasOutput(Node* op, const std::string& argument);
// Tell whether a var node is a op node's nth output.
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);

Loading…
Cancel
Save