|
|
|
@ -383,6 +383,63 @@ ExecutorPy::~ExecutorPy() {
|
|
|
|
|
ConfigManager::GetInstance().ResetConfig();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weight_node,
|
|
|
|
|
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> *fake_quant_table) {
|
|
|
|
|
std::string weight_name;
|
|
|
|
|
auto x = root_node->input(1);
|
|
|
|
|
if (IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
|
|
|
|
|
weight_name = weight_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
|
|
|
|
|
} else {
|
|
|
|
|
weight_name = weight_node->cast<ParameterPtr>()->name();
|
|
|
|
|
}
|
|
|
|
|
// find the fakequant from input
|
|
|
|
|
int64_t count = 0;
|
|
|
|
|
const int64_t max_depth = 5;
|
|
|
|
|
CNodePtr cnode = nullptr;
|
|
|
|
|
auto is_quant_cnode = [](const AnfNodePtr &node) {
|
|
|
|
|
return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
|
|
|
|
|
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
|
|
|
|
|
};
|
|
|
|
|
while (!is_quant_cnode(x)) {
|
|
|
|
|
if (count >= max_depth) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
cnode = x->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr || cnode->size() <= 1) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
x = cnode->input(1);
|
|
|
|
|
count += 1;
|
|
|
|
|
}
|
|
|
|
|
if (x->isa<Parameter>() || IsPrimitiveCNode(x, prim::kPrimLoad)) {
|
|
|
|
|
(*fake_quant_table)[weight_name] = std::make_pair(nullptr, "input");
|
|
|
|
|
}
|
|
|
|
|
// get the fakequant parameter minq's name
|
|
|
|
|
if (!is_quant_cnode(x)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
cnode = x->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr || IsPrimitiveCNode(cnode, prim::kPrimLoad) || cnode->size() != 4) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto fakequant_min_node = cnode->input(2);
|
|
|
|
|
if (!fakequant_min_node->isa<Parameter>() && !IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::string fakequant_min_node_name;
|
|
|
|
|
if (IsPrimitiveCNode(fakequant_min_node, prim::kPrimLoad)) {
|
|
|
|
|
fakequant_min_node_name = fakequant_min_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>()->name();
|
|
|
|
|
} else {
|
|
|
|
|
fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
|
|
|
|
|
}
|
|
|
|
|
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!quant_op_value->isa<PrimitivePy>()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
|
|
|
|
|
(*fake_quant_table)[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
|
|
|
|
|
const std::string &phase_s) {
|
|
|
|
|
FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
|
|
|
|
@ -399,58 +456,21 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
|
|
|
|
|
IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
|
|
|
|
|
};
|
|
|
|
|
for (const auto &node : nodes) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr || cnode->size() != 3) {
|
|
|
|
|
auto root_node = node->cast<CNodePtr>();
|
|
|
|
|
if (root_node == nullptr || root_node->size() != 3) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto x = cnode->input(1);
|
|
|
|
|
auto weight = cnode->input(2);
|
|
|
|
|
auto weight = root_node->input(2);
|
|
|
|
|
if (!is_quant_cnode(weight)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// get parameter weight's name
|
|
|
|
|
cnode = weight->cast<CNodePtr>();
|
|
|
|
|
auto cnode = weight->cast<CNodePtr>();
|
|
|
|
|
auto weight_node = cnode->input(2);
|
|
|
|
|
if (!weight_node->isa<Parameter>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto weight_name = weight_node->cast<ParameterPtr>()->name();
|
|
|
|
|
// find the fakequant from input
|
|
|
|
|
int64_t count = 0;
|
|
|
|
|
const int64_t max_depth = 5;
|
|
|
|
|
while (!is_quant_cnode(x)) {
|
|
|
|
|
if (count >= max_depth) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
cnode = x->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr || cnode->size() <= 1) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
x = cnode->input(1);
|
|
|
|
|
count += 1;
|
|
|
|
|
}
|
|
|
|
|
if (x->isa<Parameter>()) {
|
|
|
|
|
fake_quant_table[weight_name] = std::make_pair(nullptr, "input");
|
|
|
|
|
}
|
|
|
|
|
// get the fakequant parameter minq's name
|
|
|
|
|
if (!is_quant_cnode(x)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
cnode = x->cast<CNodePtr>();
|
|
|
|
|
if (cnode == nullptr || cnode->size() != 4) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto fakequant_min_node = cnode->input(2);
|
|
|
|
|
if (!fakequant_min_node->isa<Parameter>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
|
|
|
|
|
auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!quant_op_value->isa<PrimitivePy>()) {
|
|
|
|
|
if (!weight_node->isa<Parameter>() && !IsPrimitiveCNode(weight_node, prim::kPrimLoad)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
|
|
|
|
|
fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
|
|
|
|
|
GetWeightInfo(root_node, weight_node, &fake_quant_table);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return fake_quant_table;
|
|
|
|
|