|
|
|
@ -88,38 +88,52 @@ bool GroupDetector::CheckPrecondition(const Node* n) {
|
|
|
|
|
return true;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return n && n->IsOp() && n->Op() && check_data_type(n->inputs) &&
|
|
|
|
|
check_data_type(n->outputs);
|
|
|
|
|
auto check_running_on_cpu = [&](const Node* n) -> bool {
|
|
|
|
|
if (n && n->IsOp() && n->Op()) {
|
|
|
|
|
auto* op = n->Op();
|
|
|
|
|
bool is_run_on_cpu = false;
|
|
|
|
|
if (op->HasAttr("force_cpu") &&
|
|
|
|
|
op->GetAttrType("force_cpu") == proto::AttrType::BOOLEAN) {
|
|
|
|
|
is_run_on_cpu = op->GetAttrIfExists<bool>("force_cpu");
|
|
|
|
|
}
|
|
|
|
|
if (op->HasAttr("op_device")) {
|
|
|
|
|
is_run_on_cpu = op->GetAttrIfExists<std::string>("op_device") == "cpu";
|
|
|
|
|
}
|
|
|
|
|
return is_run_on_cpu;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return n && n->IsOp() && n->Op() && !check_running_on_cpu(n) &&
|
|
|
|
|
check_data_type(n->inputs) && check_data_type(n->outputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
|
|
|
|
|
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
|
|
|
|
|
// Check whether all inputs have the same shape.
|
|
|
|
|
bool is_first = true;
|
|
|
|
|
std::vector<int64_t> shape_0;
|
|
|
|
|
for (size_t i = 0; i < n->inputs.size(); ++i) {
|
|
|
|
|
auto* in_i = n->inputs[i];
|
|
|
|
|
if (!(in_i && in_i->IsVar() && in_i->Var())) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto* in_i : n->inputs) {
|
|
|
|
|
if (in_i && in_i->IsVar() && in_i->Var()) {
|
|
|
|
|
std::vector<int64_t> shape_i = in_i->Var()->GetShape();
|
|
|
|
|
if (i == 0U) {
|
|
|
|
|
if (is_first) {
|
|
|
|
|
shape_0 = shape_i;
|
|
|
|
|
is_first = false;
|
|
|
|
|
} else {
|
|
|
|
|
if (!IsEqualAndNotEmpty(shape_0, shape_i)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
auto op = n->Op();
|
|
|
|
|
std::vector<std::string> output_names =
|
|
|
|
|
OperationMap::Instance().Get(op->Type()).output_names;
|
|
|
|
|
|
|
|
|
|
for (auto& name : output_names) {
|
|
|
|
|
if (op->Output(name).size() != 1) return false;
|
|
|
|
|
if (op->Output(name).size() < 1U) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|