fix quantized rounding

pull/10479/head
fuzhiye 4 years ago
parent 8938c2f5ee
commit 6d86efc1d8

@ -53,10 +53,22 @@ ConvDwInt8PostAlign4:
sqrdmulh v2.4s, v2.4s, v27.4s
sqrdmulh v3.4s, v3.4s, v27.4s
sqrshl v0.4s, v0.4s, v28.4s
sqrshl v1.4s, v1.4s, v28.4s
sqrshl v2.4s, v2.4s, v28.4s
sqrshl v3.4s, v3.4s, v28.4s
and v4.16b, v0.16b, v28.16b
sshr v4.4s, v4.4s, #31
sqadd v0.4s, v0.4s, v4.4s
srshl v0.4s, v0.4s, v28.4s
and v5.16b, v1.16b, v28.16b
sshr v5.4s, v5.4s, #31
sqadd v1.4s, v1.4s, v5.4s
srshl v1.4s, v1.4s, v28.4s
and v6.16b, v2.16b, v28.16b
sshr v6.4s, v6.4s, #31
sqadd v2.4s, v2.4s, v6.4s
srshl v2.4s, v2.4s, v28.4s
and v7.16b, v3.16b, v28.16b
sshr v7.4s, v7.4s, #31
sqadd v3.4s, v3.4s, v7.4s
srshl v3.4s, v3.4s, v28.4s
AddZpDepth16:
add v0.4s, v0.4s, v29.4s
@ -109,8 +121,14 @@ ConvDwInt8PostAlign4:
RightShiftDepth8:
sqrdmulh v0.4s, v0.4s, v27.4s
sqrdmulh v1.4s, v1.4s, v27.4s
sqrshl v0.4s, v0.4s, v28.4s
sqrshl v1.4s, v1.4s, v28.4s
and v4.16b, v0.16b, v28.16b
sshr v4.4s, v4.4s, #31
sqadd v0.4s, v0.4s, v4.4s
srshl v0.4s, v0.4s, v28.4s
and v5.16b, v1.16b, v28.16b
sshr v5.4s, v5.4s, #31
sqadd v1.4s, v1.4s, v5.4s
srshl v1.4s, v1.4s, v28.4s
AddZpDepth8:
add v0.4s, v0.4s, v29.4s
@ -140,7 +158,10 @@ ConvDwInt8PostAlign4:
sqshl v0.4s, v0.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
sqrshl v0.4s, v0.4s, v28.4s
and v4.16b, v0.16b, v28.16b
sshr v4.4s, v4.4s, #31
sqadd v0.4s, v0.4s, v4.4s
srshl v0.4s, v0.4s, v28.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
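Note on the change above: the removed sqrshl instructions perform a saturating rounding shift right when v28 holds a negative shift count, and that rounding adds half before shifting, so ties round toward +infinity. The new and / sshr #31 / sqadd / srshl sequence first subtracts 1 from negative lanes, so ties round away from zero and the vector path matches the scalar RoundingDivideByPOT; the per-channel kernel below gets the same fixup with its per-channel shift vectors v6/v7. A minimal intrinsics sketch of the new sequence, assuming an AArch64 target with <arm_neon.h> (illustration only, not the kernel code):

#include <arm_neon.h>
#include <stdio.h>

/* v28 corresponds to `shift` below: a vector holding -exponent. */
static int32x4_t rounding_divide_by_pot(int32x4_t x, int exponent) {
  const int32x4_t shift = vdupq_n_s32(-exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift), 31); /* and + sshr #31: -1 for negative lanes */
  const int32x4_t fixed = vqaddq_s32(x, fixup);                 /* sqadd: x - 1 for negative lanes */
  return vrshlq_s32(fixed, shift);                              /* srshl: rounding shift right */
}

int main(void) {
  int32x4_t r = rounding_divide_by_pot(vdupq_n_s32(-3), 1); /* -3 / 2 = -1.5 */
  printf("%d\n", vgetq_lane_s32(r, 0));                     /* -2; the old sqrshl path would give -1 */
  return 0;
}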

@ -43,8 +43,14 @@ ConvDwInt8PostAlign4PerChannel:
sqrdmulh v0.4s, v0.4s, v4.4s
sqrdmulh v1.4s, v1.4s, v5.4s
sqrshl v0.4s, v0.4s, v6.4s
sqrshl v1.4s, v1.4s, v7.4s
and v16.16b, v0.16b, v6.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v6.4s
and v17.16b, v1.16b, v7.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v7.4s
add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s
@ -80,7 +86,10 @@ ConvDwInt8PostAlign4PerChannel:
sqrdmulh v0.4s, v0.4s, v4.4s
ld1 {v6.4s}, [x6], #16
sqrshl v0.4s, v0.4s, v6.4s
and v16.16b, v0.16b, v6.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v6.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s

@ -29,17 +29,24 @@ int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float
return NNACL_OK;
}
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
bool uint8_flag) {
if (quant_values == NULL || real_values == NULL) {
return NNACL_PARAM_INVALID;
}
if (uint8_flag) {
zp += 128;
}
const float inverse_scale = 1.0f / scale;
for (int i = 0; i < size; ++i) {
if (isinf(real_values[i])) {
quant_values[i] = 127;
} else {
int temp = round(real_values[i] * inverse_scale + zp);
if (uint8_flag) {
temp -= 128;
}
temp = temp < 127 ? temp : 127;
temp = temp > -128 ? temp : -128;
quant_values[i] = (int8_t)temp;
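Note on the uint8_flag: the zero point is lifted by 128 so rounding happens on the uint8 grid, and 128 is subtracted afterwards so the stored int8 result equals the corresponding uint8 value minus 128. Because round() rounds halves away from zero, this differs from quantizing directly on the int8 grid only when the scaled value lands exactly on a .5 boundary below the zero point. A small sketch of that corner case (numbers are illustrative):

#include <math.h>
#include <stdio.h>

int main(void) {
  float scaled = -2.5f;  /* real_value / scale; zero point 0 for clarity */
  int zp = 0;
  int direct_int8 = (int)round(scaled + zp);            /* int8 grid:  round(-2.5)  -> -3 */
  int via_uint8 = (int)round(scaled + zp + 128) - 128;  /* uint8 grid: round(125.5) -> -2 */
  printf("%d %d\n", direct_int8, via_uint8);
  return 0;
}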

@ -29,7 +29,8 @@ typedef struct QuantDTypeCastParameter {
extern "C" {
#endif
int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
bool uint8_flag);
int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size);

@ -80,6 +80,12 @@ typedef struct OpParameter {
typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType;
typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode;
typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode;
typedef enum CalFixedMultiplierMode {
Method_No,
Method_SinglePrecision,
Method_DoublePrecision
} CalFixedMultiplierMode;
#ifdef ENABLE_ARM
#define MS_FLOAT32X4 float32x4_t

@ -42,7 +42,7 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) {
}
// division by a 2^exponent with rounding
// or arithmetic right shift with rouding
// or arithmetic right shift with rounding
int RoundingDivideByPOT(int x, int exponent) {
const int mask = (1ll << exponent) - 1;
const int remainder = x & mask;
@ -50,10 +50,23 @@ int RoundingDivideByPOT(int x, int exponent) {
return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
int UpwardRounding(int x, int exponent) {
const int32_t rounding_offset = (exponent > 0) ? (1 << (exponent - 1)) : 0;
if (x > INT32_MAX - rounding_offset) {
return 1 << (31 - exponent);
}
return (x + rounding_offset) >> exponent;
}
int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
}
int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
int32_t right_shift) {
return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
}
int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) {
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift);
}
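Note on the new helpers: UpwardRounding adds 2^(exponent-1) and shifts, so ties round toward +infinity, with an overflow guard near INT32_MAX, whereas RoundingDivideByPOT rounds ties away from zero; MultiplyByQuantizedMultiplierWithUpwardRounding simply pairs the upward rule with the doubling high multiply, and is presumably what the new Rounding_Up mode selects. A standalone comparison of the two tie rules (local copies for illustration, not the NNACL functions):

#include <stdio.h>

static int divide_by_pot_away_from_zero(int x, int exponent) {  /* RoundingDivideByPOT behavior */
  const int mask = (1 << exponent) - 1;
  const int remainder = x & mask;
  const int threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

static int divide_by_pot_upward(int x, int exponent) {          /* UpwardRounding behavior */
  const int offset = (exponent > 0) ? (1 << (exponent - 1)) : 0;
  return (x + offset) >> exponent;
}

int main(void) {
  /* -3 / 2 = -1.5: away-from-zero gives -2, upward gives -1 */
  printf("%d %d\n", divide_by_pot_away_from_zero(-3, 1), divide_by_pot_upward(-3, 1));
  return 0;
}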

@ -40,8 +40,13 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b);
// or arithmetic right shift with rouding
int RoundingDivideByPOT(int x, int exponent);
int UpwardRounding(int x, int exponent);
int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift);
int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
int32_t right_shift);
int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift);
int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent);

@ -15,6 +15,7 @@
*/
#include "nnacl/quantization/quantize.h"
#include <stdio.h>
const uint64_t dSignMask = 1ull << 63;
const uint64_t dExponentMask = 0x7ffull << 52;
@ -35,8 +36,8 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantiz
*right_shift = -shift;
}
void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift) {
void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift) {
int shift = 0;
QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift);
shift = -shift;
@ -49,6 +50,29 @@ void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multipl
}
}
void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift) {
int shift = 0;
const uint32_t scale_bits = (uint32_t)(double_multiplier);
/* multiplier is in [0x40000000, 0x7FFFFF80] range */
*quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) {
printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n",
quantized_multiplier[0]);
return;
}
/* shift is in [0, 31] range */
shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23);
shift = -shift;
if (shift < 0) {
*left_shift = 0;
*right_shift = shift;
} else {
*left_shift = shift;
*right_shift = 0;
}
}
uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
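Note on the single-precision path: the 23-bit mantissa with its implicit leading 1, shifted left by 7, gives a Q31 multiplier in [0x40000000, 0x7FFFFF80] (i.e. 1.m * 2^30), and the exponent field determines how the remaining power of two is split into left_shift / right_shift. Those masks only make sense if the argument carries an IEEE-754 single-precision bit pattern rather than a numeric value, so the sketch below starts from an actual float and extracts the bits with memcpy (helper name and usage are ours):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical helper: the same mantissa/exponent extraction, starting from a float. */
static void single_precision_multiplier(float real_multiplier, int32_t *quant_multiplier, int *shift) {
  uint32_t bits;
  memcpy(&bits, &real_multiplier, sizeof(bits));  /* IEEE-754 bit pattern */
  *quant_multiplier = (int32_t)(((bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
  /* unbiased exponent E = (bits >> 23) - 127; the product then needs a net shift of E + 1 */
  *shift = (int)((bits >> 23) & 0xFF) - 127 + 1;
}

int main(void) {
  int32_t m;
  int shift;
  single_precision_multiplier(0.75f, &m, &shift);
  /* 0.75 = 1.5 * 2^-1: expect m = 0x60000000 (0.75 in Q31) and a net shift of 0 */
  printf("0x%08X %d\n", (uint32_t)m, shift);
  return 0;
}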

@ -34,6 +34,8 @@ typedef struct QuantArg {
} QuantArg;
typedef struct ConvQuantArg {
RoundingMode round_mode_;
CalFixedMultiplierMode quant_multiplier_mode_;
QuantArg *input_quant_args_;
QuantArg *filter_quant_args_;
QuantArg *output_quant_args_;
@ -46,7 +48,6 @@ typedef struct ConvQuantArg {
size_t input_arg_num_;
size_t filter_arg_num_;
size_t output_arg_num_;
uint8_t asymmetric_;
uint8_t per_channel_;
} ConvQuantArg;
@ -282,7 +283,11 @@ void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier,
void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift);
void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, int *right_shift);
void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift);
void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
int *right_shift);
uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp);

@ -40,6 +40,8 @@ table QuantParam {
varCorr: float = 1;
meanCorr: float = 0;
dstDtype: int = 32;
roundType: int = 1;
multiplier: int = -1; // calculate fixed point multiplier method
}
table Tensor {

@ -69,6 +69,9 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit
quant_arg.var_corr = quant_params->Get(j)->varCorr();
quant_arg.mean_corr = quant_params->Get(j)->meanCorr();
quant_arg.inited = quant_params->Get(j)->inited();
quant_arg.roundType = quant_params->Get(j)->roundType();
quant_arg.multiplier = quant_params->Get(j)->multiplier();
quant_arg.dstDtype = quant_params->Get(j)->dstDtype();
dst_tensor->AddQuantParam(quant_arg);
}
}

@ -261,12 +261,43 @@ int ConvolutionBaseCPUKernel::SetQuantMultiplier() {
static_cast<double>(conv_quant_arg_->input_quant_args_[0].scale_ * conv_quant_arg_->filter_quant_args_[i].scale_);
double real_multiplier = in_scale / static_cast<double>(conv_quant_arg_->output_quant_args_[0].scale_);
conv_quant_arg_->real_multiplier_[i] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
&conv_quant_arg_->right_shift_[i]);
if (conv_quant_arg_->quant_multiplier_mode_ == Method_SinglePrecision) {
QuantizeRoundParameterWithSinglePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
&conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
} else if (conv_quant_arg_->quant_multiplier_mode_ == Method_DoublePrecision) {
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
&conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
}
}
return RET_OK;
}
void ConvolutionBaseCPUKernel::SetRoundingAndMultipilerMode() {
auto input_quant_arg = in_tensors_.at(kInputIndex)->quant_params().front();
int round_type = input_quant_arg.roundType;
switch (round_type) {
case 1:
conv_quant_arg_->round_mode_ = Rounding_Away_from_zero;
break;
case 2:
conv_quant_arg_->round_mode_ = Rounding_Up;
break;
default:
conv_quant_arg_->round_mode_ = Rounding_No;
}
int cal_multiplier_type = input_quant_arg.multiplier;
switch (cal_multiplier_type) {
case 0:
conv_quant_arg_->quant_multiplier_mode_ = Method_SinglePrecision;
break;
case 1:
conv_quant_arg_->quant_multiplier_mode_ = Method_DoublePrecision;
break;
default:
conv_quant_arg_->quant_multiplier_mode_ = Method_No;
}
}
int ConvolutionBaseCPUKernel::SetQuantParam() {
auto ret = MallocQuantParam();
if (ret != RET_OK) {
@ -288,13 +319,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed.";
return ret;
}
ret = SetIfPerChannel();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set if per tensor channel failed.";
return ret;
}
SetRoundingAndMultipilerMode();
ret = SetQuantMultiplier();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
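Note on the mode selection above: SetRoundingAndMultipilerMode maps the new roundType / multiplier fields of the input tensor's first quant param onto RoundingMode and CalFixedMultiplierMode, and SetQuantMultiplier then chooses the single- or double-precision parameterization per channel. For orientation, a self-contained sketch of how one int32 accumulator would be requantized with such a (quant_multiplier, left_shift, right_shift) triple; the helpers are local copies for illustration and the numbers are made up:

#include <stdint.h>
#include <stdio.h>

static int32_t sat_rounding_doubling_high_mul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  /* only overflow case */
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return (int32_t)((ab + nudge) / (1ll << 31));
}

static int32_t rounding_divide_by_pot(int32_t x, int exponent) {
  const int32_t mask = (int32_t)((1ll << exponent) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

int main(void) {
  /* Made-up values: accumulator 1000, real multiplier ~0.002 encoded as
   * quant_multiplier = 0x41800000 (~0.5117 in Q31) with right_shift = -8. */
  int32_t acc = 1000;
  int32_t quant_multiplier = 0x41800000;
  int left_shift = 0, right_shift = -8, out_zp = 5;
  int32_t v = rounding_divide_by_pot(
      sat_rounding_doubling_high_mul(acc * (1 << left_shift), quant_multiplier), -right_shift);
  v += out_zp;
  if (v > 127) v = 127;
  if (v < -128) v = -128;
  printf("%d\n", v);  /* requantized int8 output: 7 */
  return 0;
}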

@ -53,6 +53,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
int SetFilterTensorQuantParam();
int SetOutputTensorQuantParam();
int SetQuantMultiplier();
void SetRoundingAndMultipilerMode();
int CheckResizeValid();
void FreeQuantParam();

@ -120,8 +120,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
} else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
bool from_uint8_src = false;
if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) {
from_uint8_src = true;
}
ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
quant_arg.zeroPoint, num_unit_thread, from_uint8_src);
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
} else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
@ -138,8 +142,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
input_quant_arg.scale, input_quant_arg.zeroPoint);
if (ret) {
auto output_quant_arg = out_tensors_.front()->quant_params().front();
bool from_uint8_src = false;
if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) {
from_uint8_src = true;
}
ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale,
output_quant_arg.zeroPoint, num_unit_thread);
output_quant_arg.zeroPoint, num_unit_thread, from_uint8_src);
}
}

@ -254,8 +254,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
conv_quant_arg_->real_multiplier_[i] = real_multiplier;
QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
&conv_quant_arg_->right_shift_[i]);
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
&conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
}
// now only consider per tensor for output

@ -132,8 +132,8 @@ int FullconnectionInt8CPUKernel::Init() {
for (int i = 0; i < weight_quant_num; ++i) {
const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]);
double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_);
QuantizeRoundParameter(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
&quant_.right_shift_[i]);
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
&quant_.right_shift_[i]);
}
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,

@ -138,8 +138,8 @@ int MatmulInt8CPUKernel::ReSize() {
}
}
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
&quant_params_.right_shift);
return RET_OK;
}

@ -39,7 +39,8 @@ int ReluXInt8CPUKernel::Init() {
quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint;
const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_;
QuantizeRoundParameter(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, &quant_arg_.right_shift_);
QuantizeRoundParameterWithDoublePrecision(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_,
&quant_arg_.right_shift_);
return RET_OK;
}

@ -86,8 +86,8 @@ int ResizeInt8CPUKernel::Init() {
quant_out_->zp_ = output->quant_params().front().zeroPoint;
quant_out_->scale_ = output->quant_params().front().scale;
QuantizeRoundParameter(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_, &multiplier_->left_shift_,
&multiplier_->right_shift_);
QuantizeRoundParameterWithDoublePrecision(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_,
&multiplier_->left_shift_, &multiplier_->right_shift_);
if (!InferShapeDone()) {
return RET_OK;
}

@ -38,6 +38,9 @@ struct QuantArg {
bool inited;
std::vector<float> clusters{};
int bitNum;
int roundType;
int multiplier;
int dstDtype;
};
class Tensor : public mindspore::tensor::MSTensor {

@ -118,7 +118,7 @@ TEST_F(TestMatmulInt8, simple) {
int a_sums[ROW4] = {0};
int bias[COL4] = {0};
int multiplier, ls, rs;
QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
QuantizeRoundParameterWithDoublePrecision(1.0f, &multiplier, &ls, &rs);
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls,
&rs, ROW, COL, COL, false);

@ -121,6 +121,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
input_quant_param_ptr->dstDtype = tensor_input->dataType;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}
}
@ -151,6 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
std::make_unique<schema::QuantParamT>(channel_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
output_quant_param_ptr->dstDtype = output_tensor->dataType;
output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
}
@ -258,6 +260,9 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
auto subgraph_name = func_graph->get_attr("graph_name");
MS_ASSERT(subgraph_name != nullptr);
sub_graphT->name = GetValue<std::string>(subgraph_name);
auto fmk = func_graph->get_attr("fmk");
MS_ASSERT(fmk != nullptr);
meta_graphT->fmkType = GetValue<int>(fmk);
auto cnodes = func_graph->GetOrderedCnodes();
for (const auto &cnode : cnodes) {

@ -448,6 +448,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
toAddTensor->dataType = prim->dstT;
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
}
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
@ -491,6 +493,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
toAddTensor->dataType = prim->dstT;
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
preTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
}
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
@ -552,8 +556,10 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
MS_ASSERT(prim != nullptr);
postTensor->dataType = prim->srcT;
toAddTensor->dataType = prim->dstT;
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
postTensor->quantParams.front()->zeroPoint += 128;
}
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
@ -624,6 +630,8 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
toAddTensor->dataType = prim->dstT;
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
toAddTensor->quantParams.front()->zeroPoint += 128;
} else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
postTensor->quantParams.front()->zeroPoint += 128;
}
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
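Note on the corrected branches above: the +128 zero-point adjustment is now applied to whichever tensor around the inserted QuantDTypeCast node is uint8-typed, relying on the standard int8/uint8 offset: adding 128 to both the stored value and the zero point leaves the represented real value scale * (q - zp) unchanged. A tiny check with made-up numbers:

#include <stdio.h>

int main(void) {
  float scale = 0.1f;
  int zp_int8 = -3, q_int8 = 57;  /* made-up int8 quantization */
  int zp_uint8 = zp_int8 + 128;   /* zero point after the +128 adjustment */
  int q_uint8 = q_int8 + 128;     /* value as stored in the uint8 tensor */
  /* both print 6.0: the common offset cancels out */
  printf("%.1f %.1f\n", scale * (q_int8 - zp_int8), scale * (q_uint8 - zp_uint8));
  return 0;
}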

@ -38,6 +38,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
dstQuantParam->max = srcQuantParam->max;
dstQuantParam->narrowRange = srcQuantParam->narrowRange;
dstQuantParam->numBits = srcQuantParam->numBits;
dstQuantParam->dstDtype = srcQuantParam->dstDtype;
dstQuantParam->multiplier = srcQuantParam->multiplier;
return dstQuantParam;
}

@ -71,6 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
return nullptr;
}
graph->set_attr("graph_name", MakeValue("main_graph"));
graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
} else {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;

Some files were not shown because too many files have changed in this diff.