|
|
|
@ -47,13 +47,29 @@ int QuantDTypeCastFp16CPUKernel::Init() {
|
|
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
inverse_ = false;
|
|
|
|
|
int_to_float_ = false;
|
|
|
|
|
is_uint8_ = false;
|
|
|
|
|
} else if (param->srcT == kNumberTypeInt8) {
|
|
|
|
|
if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
|
|
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
inverse_ = true;
|
|
|
|
|
int_to_float_ = true;
|
|
|
|
|
is_uint8_ = false;
|
|
|
|
|
} else if (param->dstT == kNumberTypeUInt8) {
|
|
|
|
|
if (in_tensor->data_type() != kNumberTypeFloat16 || out_tensor->data_type() != kNumberTypeUInt8) {
|
|
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
int_to_float_ = false;
|
|
|
|
|
is_uint8_ = true;
|
|
|
|
|
} else if (param->srcT == kNumberTypeUInt8) {
|
|
|
|
|
if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeFloat16) {
|
|
|
|
|
MS_LOG(ERROR) << "param data type and tensor data type do not match.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
int_to_float_ = true;
|
|
|
|
|
is_uint8_ = true;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "param data type not supported:"
|
|
|
|
|
<< " src: " << param->srcT << " dst: " << param->dstT;
|
|
|
|
@ -87,14 +103,26 @@ int QuantDTypeCastFp16CPUKernel::QuantDTypeCast(int task_id) {
|
|
|
|
|
auto quant_arg = !out_tensors_.front()->quant_params().empty() ? out_tensors_.front()->quant_params().front()
|
|
|
|
|
: in_tensors_.front()->quant_params().front();
|
|
|
|
|
int ret;
|
|
|
|
|
MS_ASSERT(int8_ptr_);
|
|
|
|
|
MS_ASSERT(float16_ptr_);
|
|
|
|
|
if (inverse_) {
|
|
|
|
|
ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
|
|
|
|
|
if (!is_uint8_) {
|
|
|
|
|
MS_ASSERT(int8_ptr_);
|
|
|
|
|
if (int_to_float_) {
|
|
|
|
|
ret = DoDequantizeInt8ToFp16(int8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
|
|
|
|
|
quant_arg.zeroPoint, num_unit_thread);
|
|
|
|
|
} else {
|
|
|
|
|
ret = DoQuantizeFp16ToInt8(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
|
|
|
|
|
quant_arg.zeroPoint, num_unit_thread);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
ret = DoQuantizeToInt8FromFp16(float16_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
|
|
|
|
|
quant_arg.zeroPoint, num_unit_thread);
|
|
|
|
|
// uint8
|
|
|
|
|
MS_ASSERT(uint8_ptr_);
|
|
|
|
|
if (int_to_float_) {
|
|
|
|
|
ret = DoDequantizeUInt8ToFp16(uint8_ptr_ + thread_offset, float16_ptr_ + thread_offset, quant_arg.scale,
|
|
|
|
|
quant_arg.zeroPoint, num_unit_thread);
|
|
|
|
|
} else {
|
|
|
|
|
ret = DoQuantizeFp16ToUInt8(float16_ptr_ + thread_offset, uint8_ptr_ + thread_offset, quant_arg.scale,
|
|
|
|
|
quant_arg.zeroPoint, num_unit_thread);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
@ -123,6 +151,14 @@ int QuantDTypeCastFp16CPUKernel::Run() {
|
|
|
|
|
out_tensors_.at(0)->data_type() == TypeId::kNumberTypeInt8) {
|
|
|
|
|
float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
|
|
|
|
|
int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
|
|
|
|
|
} else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8 &&
|
|
|
|
|
out_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16) {
|
|
|
|
|
uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_.at(0)->data_c());
|
|
|
|
|
float16_ptr_ = reinterpret_cast<float16_t *>(out_tensors_.at(0)->data_c());
|
|
|
|
|
} else if (in_tensors_.at(0)->data_type() == TypeId::kNumberTypeFloat16 &&
|
|
|
|
|
out_tensors_.at(0)->data_type() == TypeId::kNumberTypeUInt8) {
|
|
|
|
|
float16_ptr_ = reinterpret_cast<float16_t *>(in_tensors_.at(0)->data_c());
|
|
|
|
|
uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->data_c());
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "QuantDTypeCastFp16 not support input or output type";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|