|
|
|
|
@ -24,23 +24,13 @@ namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace fusion_group {
|
|
|
|
|
|
|
|
|
|
static std::unordered_set<std::string> binary_op_types;
|
|
|
|
|
static std::unordered_set<std::string> unary_op_types;
|
|
|
|
|
static std::unordered_set<std::string> elementwise_op_types;
|
|
|
|
|
|
|
|
|
|
static std::unordered_set<std::string>& GetBinaryOpTypes() {
|
|
|
|
|
if (binary_op_types.empty()) {
|
|
|
|
|
binary_op_types =
|
|
|
|
|
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 2);
|
|
|
|
|
static std::unordered_set<std::string>& GetElementwiseOpTypes() {
|
|
|
|
|
if (elementwise_op_types.empty()) {
|
|
|
|
|
elementwise_op_types = OperationMap::Instance().Find(/* type= */ 0);
|
|
|
|
|
}
|
|
|
|
|
return binary_op_types;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_set<std::string>& GetUnaryOpTypes() {
|
|
|
|
|
if (unary_op_types.empty()) {
|
|
|
|
|
unary_op_types =
|
|
|
|
|
OperationMap::Instance().Find(/* type= */ 0, /* num_operands= */ 1);
|
|
|
|
|
}
|
|
|
|
|
return unary_op_types;
|
|
|
|
|
return elementwise_op_types;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
|
|
|
|
|
@ -70,13 +60,8 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
|
|
|
|
|
return l.size() != 0U && r.size() != 0U && l == r;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool IsBinaryOp(const Node* n) {
|
|
|
|
|
if (IsSpecifiedOp(GetBinaryOpTypes(), n)) {
|
|
|
|
|
if ((!IsGradOp(n) && n->inputs.size() != 2U) || n->inputs.size() == 0U) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The shape of all inputs should be the same.
|
|
|
|
|
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
|
|
|
|
|
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
|
|
|
|
|
std::vector<int64_t> shape_0;
|
|
|
|
|
for (size_t i = 0; i < n->inputs.size(); ++i) {
|
|
|
|
|
auto* in_i = n->inputs[i];
|
|
|
|
|
@ -98,14 +83,6 @@ static bool IsBinaryOp(const Node* n) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool IsUnaryOp(const Node* n) {
|
|
|
|
|
return IsSpecifiedOp(GetUnaryOpTypes(), n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
|
|
|
|
|
return IsBinaryOp(n) || IsUnaryOp(n);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
|
|
|
|
|
Graph* graph) {
|
|
|
|
|
auto teller = [&](const Node* n) -> bool { return IsElementwiseOp(n); };
|
|
|
|
|
|