|
|
|
@ -39,10 +39,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
|
|
|
|
|
|
|
|
|
|
STATUS QuantCast::Run(FuncGraphPtr graph) {
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
|
|
|
|
|
|
auto cnodes = graph->GetOrderedCnodes();
|
|
|
|
|
bool first = true;
|
|
|
|
|
|
|
|
|
|
for (auto &cnode : cnodes) {
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
|
|
|
|
auto curnode_quant_type = schema::QuantType_QUANT_NONE;
|
|
|
|
@ -51,34 +48,30 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
|
|
|
|
|
} else {
|
|
|
|
|
curnode_quant_type = primitive_c->GetQuantType();
|
|
|
|
|
}
|
|
|
|
|
if (first) {
|
|
|
|
|
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
|
|
|
|
|
auto value_node =
|
|
|
|
|
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front());
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
|
|
|
|
|
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
|
|
|
|
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
|
|
|
|
|
cnode->set_input(1, quant_cast_cnode);
|
|
|
|
|
MS_LOG(DEBUG) << "Add quant cast at front. "
|
|
|
|
|
<< "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type;
|
|
|
|
|
}
|
|
|
|
|
first = false;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
|
if (!input_node->isa<CNode>()) {
|
|
|
|
|
continue;
|
|
|
|
|
auto is_graph_input = false;
|
|
|
|
|
if (input_node->isa<Parameter>()) {
|
|
|
|
|
if (!input_node->cast<ParameterPtr>()->has_default()) {
|
|
|
|
|
is_graph_input = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
|
|
|
|
|
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
|
|
|
|
if (input_cnode_primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
|
|
|
|
|
<< " PrimitiveC is null";
|
|
|
|
|
if (!input_node->isa<CNode>() && !is_graph_input) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType();
|
|
|
|
|
auto input_cnode_quant_type = schema::QuantType_QUANT_NONE;
|
|
|
|
|
std::shared_ptr<PrimitiveC> input_cnode_primitive_c = nullptr;
|
|
|
|
|
if (!is_graph_input) {
|
|
|
|
|
auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
|
|
|
|
|
input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
|
|
|
|
if (input_cnode_primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
|
|
|
|
|
<< " PrimitiveC is null";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
input_cnode_quant_type = input_cnode_primitive_c->GetQuantType();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (curnode_quant_type != input_cnode_quant_type) {
|
|
|
|
|
ValueNodePtr value_node = nullptr;
|
|
|
|
@ -94,22 +87,22 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
|
|
|
|
|
if (value_node == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "value_node is null! "
|
|
|
|
|
<< "cur_node: " << cnode->fullname_with_scope() << " quant_type: "
|
|
|
|
|
<< " input_" << i << ": " << input_cnode->fullname_with_scope()
|
|
|
|
|
<< " input_" << i << ": "
|
|
|
|
|
<< " quant_type:" << input_cnode_quant_type;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode};
|
|
|
|
|
std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
|
|
|
|
|
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
|
|
|
|
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(i));
|
|
|
|
|
cnode->set_input(i, quant_cast_cnode);
|
|
|
|
|
MS_LOG(DEBUG) << "Add quant cast. "
|
|
|
|
|
<< "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
|
|
|
|
|
<< " input_" << i << ": " << input_cnode->fullname_with_scope()
|
|
|
|
|
<< " input_" << i << ": "
|
|
|
|
|
<< " quant_type:" << input_cnode_quant_type;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(DEBUG) << "No need to add quant cast. "
|
|
|
|
|
<< "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
|
|
|
|
|
<< " input_" << i << ": " << input_cnode->fullname_with_scope()
|
|
|
|
|
<< " input_" << i << ": "
|
|
|
|
|
<< " quant_type:" << input_cnode_quant_type;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|