diff --git a/src/model.cpp b/src/model.cpp
index d23b97fac..2c708ed6a 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -162,43 +162,7 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
 }
 
 uint16_t f8_e5m2_to_f16(uint8_t fp8) {
-    uint8_t sign     = (fp8 >> 7) & 0x1;
-    uint8_t exponent = (fp8 >> 2) & 0x1F;
-    uint8_t mantissa = fp8 & 0x3;
-
-    uint16_t fp16_sign = sign << 15;
-    uint16_t fp16_exponent;
-    uint16_t fp16_mantissa;
-
-    if (exponent == 0 && mantissa == 0) {  // zero
-        return fp16_sign;
-    }
-
-    if (exponent == 0x1F) {  // NAN and INF
-        fp16_exponent = 0x1F;
-        fp16_mantissa = mantissa ? (mantissa << 8) : 0;
-        return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
-    }
-
-    if (exponent == 0) {  // subnormal numbers
-        fp16_mantissa = (mantissa << 8);
-        return fp16_sign | fp16_mantissa;
-    }
-
-    // normal numbers
-    int16_t true_exponent = (int16_t)exponent - 15 + 15;
-    if (true_exponent <= 0) {
-        fp16_exponent = 0;
-        fp16_mantissa = (mantissa << 8);
-    } else if (true_exponent >= 0x1F) {
-        fp16_exponent = 0x1F;
-        fp16_mantissa = 0;
-    } else {
-        fp16_exponent = (uint16_t)true_exponent;
-        fp16_mantissa = mantissa << 8;
-    }
-
-    return fp16_sign | (fp16_exponent << 10) | fp16_mantissa;
+    return static_cast<uint16_t>(fp8) << 8;
 }
 
 void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {