## Half-precision (IEEE 754 binary16) floating-point type with bit-level
## conversions to and from `float32`.

from std/bitops import countLeadingZeroBits

type Float16* = distinct uint16

# Wrap-around subtraction on uint16 (no overflow checks), used by the
# branch-free bit manipulation below.
template `-%`(x, y: uint16): uint16 =
    cast[uint16](cast[int16](x) -% cast[int16](y))

type Denormals* = enum
    Ignore     ## fastest: denormal values are not treated specially and convert incorrectly
    SetZero    ## flush denormal values to (signed) zero
    Calculate  ## convert denormal values properly (slower)

proc toFloat16*[T](n: T, denormals: static[Denormals] = Ignore, clamp: static[bool] = false): Float16 =
    ## Converts `n` to half precision. With `clamp = true` the value is first
    ## clamped to ±65504, the largest finite half-precision value.
    when clamp:
        # fully qualified so the call is not confused with the `clamp` parameter
        var fltInt32 = cast[uint32](system.clamp(n.float32, -65504.0'f32, 65504.0'f32))
    else:
        var fltInt32 = cast[uint32](n.float32)                 # raw IEEE 754 single-precision bits
    var fltInt16 = cast[uint16]((fltInt32 shr 31) shl 5)       # sign bit, parked at bit 5 for now
    var tmp: uint16 = cast[uint16](fltInt32 shr 23) and 0xff   # 8-bit biased float32 exponent
    when denormals == Calculate:
        # extra right shift needed when the result is a half-precision denormal
        let shift = 113'u16 -% min(tmp, 113)
        let implicit = (shift != 0).uint16 shl 10              # re-insert the implicit leading 1
    else:
        const shift = 0'u16
        const implicit = 0'u16
    # re-bias the exponent (127 -> 15); the mask zeroes it when the value underflows
    tmp = (tmp -% 0x70) and cast[uint16](cast[uint32]((0x70'i32 -% cast[int32](tmp)) shr 4) shr 27)
    fltInt16 = (fltInt16 or tmp) shl 10                        # sign and exponent in final position
    # keep the top 10 mantissa bits; denormals also get the implicit bit, shifted into place
    var r = fltInt16 or (((fltInt32 shr 13) and 0x3ff or implicit) shr shift)
    when denormals == SetZero:
        if (r and 0x7C00) == 0:
            r = r and 0x8000                                   # flush denormal results to signed zero
    return cast[Float16](r)

proc toFloat32*(n: Float16, denormals: static[Denormals] = Ignore): float32 =
    ## Converts half-precision `n` back to single precision.
    var u = (cast[uint16](n).uint32 and 0x7fff'u32) shl 13      # exponent and mantissa, aligned for float32
    if u != 0:
        if u < 0x00800000:
            # half-precision denormal (exponent field is zero)
            when denormals == Calculate:
                # normalise: move the leading mantissa bit out and rebuild the exponent
                let c = cast[uint16](countLeadingZeroBits(u)) -% 8
                u = (u shl c) and 0x007FFFFF
                u = u or cast[uint32]((1'i32 -% cast[int32](c)) shl 23)
            elif denormals == SetZero:
                # flush to signed zero, skipping the exponent re-bias below
                return cast[float32]((cast[uint16](n).uint32 and 0x8000'u32) shl 16)
        u = u + 0x38000000'u32                                  # re-bias the exponent (15 -> 127)
        if u >= ((127 + 16).uint32 shl 23):
            u = u or (255'u32 shl 23)                           # half inf/NaN -> float32 inf/NaN
    u = u or (cast[uint16](n).uint32 and 0x8000'u32) shl 16     # put the sign bit back
    return cast[float32](u)

template toFloat64*(n: Float16, denormals: static[Denormals] = Ignore): float64 =
    n.toFloat32(denormals).float64

template toFloat*(n: Float16, denormals: static[Denormals] = Ignore): float =
    n.toFloat32(denormals).float

template `$`*(n: Float16): string = $(n.toFloat32)

# from std/strutils import tohex
# template `tohex`*(n: Float16): string = n.uint16.tohex

# proc debug_f32(n: uint32): string =
#     let s = n.int.tobin(32)
#     return s[0] & " " & s[1 ..< 9] & " " & s[9 ..< 32]
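
# Illustrative sketch, not part of the library proper: a few round trips
# exercising the conversions above. Printed values are whatever the procs
# compute; the denormal example uses 1e-5, which lies below the smallest
# normal half-precision value (2^-14, about 6.1e-5).
when isMainModule:
    let x = 1.5'f32
    let h = x.toFloat16()
    echo $h                                  # "1.5"
    assert h.toFloat32() == 1.5'f32          # 1.5 is exactly representable in half precision

    let tiny = 1e-5'f32                      # denormal in half precision
    echo tiny.toFloat16(denormals = Calculate).toFloat32(denormals = Calculate)
    echo tiny.toFloat16(denormals = SetZero).toFloat32(denormals = SetZero)   # prints 0.0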