|
|
|
@ -64,64 +64,61 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
|
|
|
|
|
MS_ASSERT(dst_node != nullptr);
|
|
|
|
|
// add quant param
|
|
|
|
|
dst_node->quantType = primitive->GetQuantType();
|
|
|
|
|
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
|
|
|
|
|
|| dst_node->quantType == schema::QuantType_WeightQuant) {
|
|
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
|
|
|
|
|
// activation
|
|
|
|
|
auto input_quant_params = primitive->GetInputQuantParams();
|
|
|
|
|
auto node_type = (schema::PrimitiveType)primitive->Type();
|
|
|
|
|
if (input_quant_params.empty()) {
|
|
|
|
|
MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
|
|
|
|
|
return RET_OK;
|
|
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
|
|
|
|
|
// activation
|
|
|
|
|
auto input_quant_params = primitive->GetInputQuantParams();
|
|
|
|
|
auto node_type = (schema::PrimitiveType)primitive->Type();
|
|
|
|
|
if (input_quant_params.empty()) {
|
|
|
|
|
MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < input_quant_params.size(); i++) {
|
|
|
|
|
if (i >= dst_node->inputIndex.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
|
|
|
|
|
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < input_quant_params.size(); i++) {
|
|
|
|
|
if (i >= dst_node->inputIndex.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
|
|
|
|
|
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto activate_index = dst_node->inputIndex[i];
|
|
|
|
|
auto tensor_input = meta_graph->allTensors[activate_index].get();
|
|
|
|
|
if (tensor_input->quantParams.empty()) {
|
|
|
|
|
for (auto input_quant_param : input_quant_params[i]) {
|
|
|
|
|
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
|
|
|
|
|
std::make_unique<schema::QuantParamT>(input_quant_param);
|
|
|
|
|
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
|
|
|
|
|
<< " zp: " << input_quant_param_ptr->zeroPoint;
|
|
|
|
|
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
|
|
|
|
|
}
|
|
|
|
|
auto activate_index = dst_node->inputIndex[i];
|
|
|
|
|
auto tensor_input = meta_graph->allTensors[activate_index].get();
|
|
|
|
|
if (tensor_input->quantParams.empty()) {
|
|
|
|
|
for (auto input_quant_param : input_quant_params[i]) {
|
|
|
|
|
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
|
|
|
|
|
std::make_unique<schema::QuantParamT>(input_quant_param);
|
|
|
|
|
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
|
|
|
|
|
<< " zp: " << input_quant_param_ptr->zeroPoint;
|
|
|
|
|
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// output
|
|
|
|
|
auto output_index = dst_node->outputIndex[0];
|
|
|
|
|
auto tensor_output = meta_graph->allTensors[output_index].get();
|
|
|
|
|
auto output_quant_params = primitive->GetOutputQuantParams();
|
|
|
|
|
if (output_quant_params.empty()) {
|
|
|
|
|
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (auto output_quant_param : output_quant_params[0]) {
|
|
|
|
|
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
|
|
|
|
|
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
|
|
|
|
|
std::make_unique<schema::QuantParamT>(output_quant_param);
|
|
|
|
|
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
|
|
|
|
|
<< " zp: " << output_quant_param_ptr->zeroPoint;
|
|
|
|
|
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
|
|
|
|
|
}
|
|
|
|
|
// output
|
|
|
|
|
auto output_index = dst_node->outputIndex[0];
|
|
|
|
|
auto tensor_output = meta_graph->allTensors[output_index].get();
|
|
|
|
|
auto output_quant_params = primitive->GetOutputQuantParams();
|
|
|
|
|
if (output_quant_params.empty()) {
|
|
|
|
|
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (auto output_quant_param : output_quant_params[0]) {
|
|
|
|
|
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
|
|
|
|
|
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
|
|
|
|
|
std::make_unique<schema::QuantParamT>(output_quant_param);
|
|
|
|
|
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
|
|
|
|
|
<< " zp: " << output_quant_param_ptr->zeroPoint;
|
|
|
|
|
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (dst_node->quantType == schema::QuantType_PostTraining) {
|
|
|
|
|
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
}
|
|
|
|
|
if (dst_node->quantType == schema::QuantType_PostTraining) {
|
|
|
|
|
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
|
|
|
|
|
tensor_output->dataType = kNumberTypeInt8;
|
|
|
|
|
} else {
|
|
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
|
|
|
|
|
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
|
|
|
|
|
MS_ASSERT(primc != nullptr);
|
|
|
|
|
if (primc->GetDstT() != kNumberTypeFloat32) {
|
|
|
|
|
tensor_output->dataType = kNumberTypeInt8;
|
|
|
|
|
} else {
|
|
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
|
|
|
|
|
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
|
|
|
|
|
MS_ASSERT(primc != nullptr);
|
|
|
|
|
if (primc->GetDstT() != kNumberTypeFloat32) {
|
|
|
|
|
tensor_output->dataType = kNumberTypeInt8;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|