From 5c0dc4e2ce2ca8d1ad80ab6ee84ff0e039839096 Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Fri, 4 Sep 2020 11:47:18 +0800
Subject: [PATCH] rewrite fp16 to fp32 and fp32 to fp16

---
Review notes on this revision:
* Float32ToShort reads the bits of the float through float32_bits; the
  submitted "(unsigned int)(*psrcValue)" converted the value numerically.
* Rounding is round-to-nearest-even and its carry reaches the exponent; the
  sign survives zero/underflow; a NaN can no longer collapse into infinity.
* The explanatory comments of ShortToFloat32 stay; only its parameter is renamed.
* Blob ids on the "index" lines are carried over unchanged from the submission.

 mindspore/lite/nnacl/fp32/common_func.c | 106 +++++++++++++++-------------
 mindspore/lite/nnacl/fp32/common_func.h |   4 +-
 2 files changed, 62 insertions(+), 48 deletions(-)

diff --git a/mindspore/lite/nnacl/fp32/common_func.c b/mindspore/lite/nnacl/fp32/common_func.c
index fb0286000f..4110547294 100644
--- a/mindspore/lite/nnacl/fp32/common_func.c
+++ b/mindspore/lite/nnacl/fp32/common_func.c
@@ -87,68 +87,82 @@ union float32_bits {
 };
 typedef union float32_bits float32_bits;
 
-float ShortToFloat32(uint16_t srcValue) {
+float ShortToFloat32(uint16_t src_value) {
   const float32_bits magic = {113 << 23};
   const unsigned int shifted_exp = 0x7c00 << 13;  // exponent mask after shift
   float32_bits o;
 
-  o.u = (srcValue & 0x7fff) << 13;       // exponent/mantissa bits
+  o.u = (src_value & 0x7fff) << 13;      // exponent/mantissa bits
   unsigned int exp = shifted_exp & o.u;  // just the exponent
   o.u += (127 - 15) << 23;               // exponent adjust
 
   // handle exponent special cases
   if (exp == shifted_exp) {   // Inf/NaN?
     o.u += (128 - 16) << 23;  // extra exp adjust
   } else if (exp == 0) {      // Zero/Denormal?
     o.u += 1 << 23;           // extra exp adjust
     o.f -= magic.f;           // renormalize
   }
 
-  o.u |= (srcValue & 0x8000) << 16;  // sign bit
+  o.u |= (src_value & 0x8000) << 16;  // sign bit
   return o.f;
 }
 
-uint16_t Float32ToShort(float srcValue) {
-  float32_bits f;
-  f.f = srcValue;
-
-  const float32_bits f32infty = {255 << 23};
-  const float32_bits f16max = {(127 + 16) << 23};
-  const float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
-  unsigned int sign_mask = 0x80000000u;
-  uint16_t o;
-
-  unsigned int sign = f.u & sign_mask;
-  f.u ^= sign;
-
-  // NOTE all the integer compares in this function can be safely
-  // compiled into signed compares since all operands are below
-  // 0x80000000. Important if you want fast straight SSE2 code
-  // (since there's no unsigned PCMPGTD).
-
-  if (f.u >= f16max.u) {                       // result is Inf or NaN (all exponent bits set)
-    o = (f.u > f32infty.u) ? 0x7e00 : 0x7c00;  // NaN->qNaN and Inf->Inf
-  } else {                                     // (De)normalized number or zero
-    if (f.u < (113 << 23)) {                   // resulting FP16 is subnormal or zero
-      // use a magic value to align our 10 mantissa bits at the bottom of
-      // the float. as long as FP addition is round-to-nearest-even this
-      // just works.
-      f.f += denorm_magic.f;
-
-      // and one integer subtract of the bias later, we have our final float!
-      o = (uint16_t)(f.u - denorm_magic.u);
-    } else {
-      unsigned int mant_odd = (f.u >> 13) & 1;  // resulting mantissa is odd
-
-      // update exponent, rounding bias part 1
-      f.u += ((unsigned int)(15 - 127) << 23) + 0xfff;
-      // rounding bias part 2
-      f.u += mant_odd;
-      // take the bits!
-      o = (uint16_t)(f.u >> 13);
-    }
-  }
-
-  o |= (uint16_t)(sign >> 16);
-  return o;
+static const unsigned int FP32_BIT_SIZE = 32;
+static const unsigned int FP32_EXPONENT_BIAS = 127;
+static const unsigned int FP32_SIGNIFICAND = 23;
+
+static const unsigned int FP32_EXPONENT_MAX = 255;
+
+static const unsigned int FP16_BIT_SIZE = 16;
+static const unsigned int FP16_EXPONENT_BIAS = 15;
+static const unsigned int FP16_SIGNIFICAND = 10;
+
+static const int FP16_EXPONENT_MAX = 30;
+static const int FP16_EXPONENT_MIN = -10;
+
+// Converts a single precision float to the bit pattern of the nearest half precision float (ties to even).
+// Too large magnitudes become infinity, too small ones become signed zero, and any NaN becomes the quiet NaN
+// 0x7E00 carrying the sign of the input.
+uint16_t Float32ToShort(float src_value) {
+  // Read the raw bits through the union; a value cast would convert the number instead of copying bits.
+  float32_bits src;
+  src.f = src_value;
+  const unsigned int sign = (src.u >> (FP32_BIT_SIZE - FP16_BIT_SIZE)) & 0x8000;
+  const unsigned int exp32 = (src.u >> FP32_SIGNIFICAND) & FP32_EXPONENT_MAX;
+  unsigned int mantissa = src.u & 0x007FFFFF;
+
+  if (exp32 == FP32_EXPONENT_MAX) {
+    // infinity keeps an empty significand, NaN sets the quiet bit so it can never collapse into infinity
+    return (uint16_t)(sign | (mantissa == 0 ? 0x7C00 : 0x7E00));
+  }
+
+  // exponent re-biased for half precision
+  int exp = (int)exp32 - (int)FP32_EXPONENT_BIAS + (int)FP16_EXPONENT_BIAS;
+  if (exp > FP16_EXPONENT_MAX) {
+    // exponent overflow - return infinity half
+    return (uint16_t)(sign | 0x7C00);
+  }
+  if (exp < FP16_EXPONENT_MIN) {
+    // below half of the smallest half subnormal (also zero and single subnormals) - return signed zero
+    return (uint16_t)sign;
+  }
+
+  unsigned int shift = FP32_SIGNIFICAND - FP16_SIGNIFICAND;
+  if (exp <= 0) {
+    // result is a half subnormal: make the hidden bit explicit and shift it down to a zero exponent
+    mantissa |= 0x00800000;
+    shift += (unsigned int)(1 - exp);
+    exp = 0;
+  }
+
+  // round to nearest, ties to even
+  const unsigned int half_ulp = 1u << (shift - 1);
+  const unsigned int dropped = mantissa & ((1u << shift) - 1);
+  unsigned int magnitude = ((unsigned int)exp << FP16_SIGNIFICAND) + (mantissa >> shift);
+  if (dropped > half_ulp || (dropped == half_ulp && (magnitude & 1) != 0)) {
+    // a carry out of the significand moves into the exponent, giving the next binade or infinity
+    magnitude++;
+  }
+  return (uint16_t)(sign | magnitude);
 }
diff --git a/mindspore/lite/nnacl/fp32/common_func.h b/mindspore/lite/nnacl/fp32/common_func.h
index 873768f508..300149c492 100644
--- a/mindspore/lite/nnacl/fp32/common_func.h
+++ b/mindspore/lite/nnacl/fp32/common_func.h
@@ -31,9 +31,9 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi
                         size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
 void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
                         size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
-float ShortToFloat32(uint16_t srcValue);
+float ShortToFloat32(uint16_t src_value);
 
-uint16_t Float32ToShort(float srcValue);
+uint16_t Float32ToShort(float src_value);
 
 #ifdef ENABLE_ARM
 void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,