!10153 cast fp16: support uint8

From: @zhaozhenlong
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/10153/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 502e1cbadd

@ -29,12 +29,16 @@ int DoDequantizeInt8ToFp16(int8_t *quant_values, float16_t *real_values, float s
return NNACL_OK;
}
int DoQuantizeToInt8FromFp16(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
int DoQuantizeFp16ToInt8(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
if (quant_values == NULL || real_values == NULL) {
return NNACL_PARAM_INVALID;
}
for (int i = 0; i < size; ++i) {
if (isinf(real_values[i])) {
quant_values[i] = 127;
continue;
}
float temp = round((float)real_values[i] / scale + zp);
if (temp > 127) {
quant_values[i] = 127;
@ -46,3 +50,37 @@ int DoQuantizeToInt8FromFp16(float16_t *real_values, int8_t *quant_values, float
}
return NNACL_OK;
}
// Dequantize a uint8 buffer to float16: real = (q - zp) * scale.
// Returns NNACL_PARAM_INVALID on NULL input/output, NNACL_OK otherwise.
int DoDequantizeUInt8ToFp16(uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) {
  // Validate parameters before doing any other work (consistent with the
  // sibling Do* kernels, which all check pointers first).
  if (quant_values == NULL || real_values == NULL) {
    return NNACL_PARAM_INVALID;
  }
  // Narrow the zero point once outside the loop; a uint8 quantization
  // zero point is expected to lie in [0, 255] — TODO confirm at call sites.
  uint8_t zp_ = (uint8_t)zp;
  for (int i = 0; i < size; ++i) {
    real_values[i] = (quant_values[i] - zp_) * scale;
  }
  return NNACL_OK;
}
// Quantize a float16 buffer to uint8: q = clamp(round(real / scale + zp), 0, 255).
// Returns NNACL_PARAM_INVALID on NULL input/output, NNACL_OK otherwise.
int DoQuantizeFp16ToUInt8(float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) {
  if (quant_values == NULL || real_values == NULL) {
    return NNACL_PARAM_INVALID;
  }
  for (int i = 0; i < size; ++i) {
    float real = (float)real_values[i];
    if (isnan(real)) {
      // A float->uint8 cast of NaN is undefined behavior; map NaN to the
      // zero point (the quantized representation of real 0), clamped to range.
      int32_t z = zp < 0 ? 0 : (zp > 255 ? 255 : zp);
      quant_values[i] = (uint8_t)z;
      continue;
    }
    if (isinf(real)) {
      // Saturate each infinity to the matching end of the uint8 range.
      // (Previously -inf was mapped to 255; it must clamp to the minimum, 0.)
      quant_values[i] = real > 0 ? 255 : 0;
      continue;
    }
    float temp = round(real / scale + zp);
    if (temp > 255.0f) {
      quant_values[i] = 255;
    } else if (temp < 0.0f) {
      quant_values[i] = 0;
    } else {
      quant_values[i] = (uint8_t)temp;
    }
  }
  return NNACL_OK;
}

@ -27,7 +27,10 @@
extern "C" {
#endif
// Dequantize int8 -> float16: real = (q - zp) * scale.
int DoDequantizeInt8ToFp16(int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size);
// NOTE(review): old name — this change renames it to DoQuantizeFp16ToInt8 below;
// the diff view shows both lines. Only one should remain in the final header.
int DoQuantizeToInt8FromFp16(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
// Quantize float16 -> int8 (upper clamp to 127 visible in the impl;
// lower bound presumably -128 — confirm against the full implementation).
int DoQuantizeFp16ToInt8(float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
// Dequantize uint8 -> float16: real = (q - zp) * scale.
int DoDequantizeUInt8ToFp16(uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size);
// Quantize float16 -> uint8 with saturation to [0, 255].
int DoQuantizeFp16ToUInt8(float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
#ifdef __cplusplus
}
#endif

@ -47,13 +47,29 @@ int QuantDTypeCastFp16CPUKernel::Init() {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = false;
int_to_float_ = false;
is_uint8_ = false;
} else if (param->srcT == kNumberTypeInt8) {
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
inverse_ = true;
int_to_float_ = true;
is_uint8_ = false;
} else if (param->dstT == kNumberTypeUInt8) {
if (in_tensor->data_type() != kNumberTypeFloat16 || out_tensor->data_type() != kNumberTypeUInt8) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
int_to_float_ = false;
is_uint8_ = true;
} else if (param->srcT == kNumberTypeUInt8) {
if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
return RET_ERROR;
}
int_to_float_ = true;
is_uint8_ = true;
} else {
MS_LOG(ERROR) << "param data type not supported:"
<< " src: " << param->srcT << " dst: " << param->dstT;
@ -87,14 +103,26 @@ int QuantDTypeCastFp16CPUKernel::QuantDTypeCast(int task_id) {
auto quant_arg = !out_tensors_.front()->quant_params().empty() ? out_tensors_.front()->quant_params().front()
: in_tensors_.front()->quant_params().front();
int ret;
MS_ASSERT(int8_ptr_);
MS_ASSERT(float16_ptr_);
if (inverse_) {
ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
if (!is_uint8_) {
MS_ASSERT(int8_ptr_);
if (int_to_float_) {
ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
} else {
ret = DoQuantizeFp16ToInt8(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
}
} else {
ret = DoQuantizeToInt8FromFp16(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
// uint8
MS_ASSERT(uint8_ptr_);
if (int_to_float_) {
ret = DoDequantizeUInt8ToFp16(uint8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
} else {
ret = DoQuantizeFp16ToUInt8(float16_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale,
quant_arg.zeroPoint, num_unit_thread);
}
}
if (ret != RET_OK) {
@ -123,6 +151,14 @@ int QuantDTypeCastFp16CPUKernel::Run() {
out_tensors_.at(0)->data_type() == TypeId::kNumberTypeInt8) {
float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
} else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8 &&
out_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16) {
uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.at(0)->data_c());
float16_ptr_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
} else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16 &&
out_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8) {
float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->data_c());
} else {
MS_LOG(ERROR) << "QuantDTypeCastFp16 not support input or output type";
return RET_ERROR;

@ -41,8 +41,10 @@ class QuantDTypeCastFp16CPUKernel : public LiteKernel {
int thread_n_stride_;
int num_unit_;
// Exactly one of the two integer data pointers is bound per Run(),
// selected by is_uint8_; float16_ptr_ holds the fp16 side of the cast.
int8_t *int8_ptr_;
uint8_t *uint8_ptr_;
float16_t *float16_ptr_;
// NOTE(review): inverse_ is the old flag renamed to int_to_float_ in this
// change (diff view shows both lines); only one should remain.
bool inverse_;
bool int_to_float_;  // true: dequantize intN -> fp16; false: quantize fp16 -> intN
bool is_uint8_;      // true: uint8 path; false: int8 path
};
} // namespace mindspore::kernel

Loading…
Cancel
Save