|
|
|
@ -599,10 +599,19 @@ bool VarLinksToOp(Node *node, const std::string &op_type) {
|
|
|
|
|
bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) {
|
|
|
|
|
PADDLE_ENFORCE(var->IsVar());
|
|
|
|
|
PADDLE_ENFORCE(op->IsOp());
|
|
|
|
|
if (op->Op()->Input(argument).size() <= nth) return false;
|
|
|
|
|
if (!HasInput(op, argument) || op->Op()->Input(argument).size() <= nth)
|
|
|
|
|
return false;
|
|
|
|
|
return var->Name() == op->Op()->Input(argument)[nth];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasInput(Node *op, const std::string &argument) {
|
|
|
|
|
PADDLE_ENFORCE(op->IsOp());
|
|
|
|
|
auto const &names = op->Op()->InputNames();
|
|
|
|
|
if (std::find(names.begin(), names.end(), argument) == names.end())
|
|
|
|
|
return false;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
|
|
|
|
|
PADDLE_ENFORCE(var->IsVar());
|
|
|
|
|
PADDLE_ENFORCE(op->IsOp());
|
|
|
|
@ -1082,8 +1091,15 @@ PDNode *patterns::Conv::operator()() {
|
|
|
|
|
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
|
|
|
|
|
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
|
|
|
|
|
|
|
|
|
|
if (!with_residual_data)
|
|
|
|
|
conv_op->assert_op_attr("fuse_residual_connection", false);
|
|
|
|
|
if (!with_residual_data) {
|
|
|
|
|
conv_op->assert_more([&](Node *x) {
|
|
|
|
|
auto node_names = x->Op()->InputNames();
|
|
|
|
|
if (!HasInput(x, "ResidualData") ||
|
|
|
|
|
x->Op()->Input("ResidualData").size() == 0)
|
|
|
|
|
return true;
|
|
|
|
|
return false;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_var = pattern->NewNode(conv_input_repr())
|
|
|
|
|
->AsInput()
|
|
|
|
|