This commit is contained in:
treeform 2022-05-07 11:18:58 -07:00
parent f4cd87b505
commit bc635bc3da

View file

@ -103,10 +103,6 @@ template clampByte(x): uint8 =
## Clamp integer into byte range. ## Clamp integer into byte range.
clamp(x, 0, 0xFF).uint8 clamp(x, 0, 0xFF).uint8
template clampInt16(x): int16 =
## Clamp integer into byte range.
clamp(x, -32768, 32767).int16
proc readUint8(state: var DecoderState): uint8 = proc readUint8(state: var DecoderState): uint8 =
## Reads a byte from the input stream. ## Reads a byte from the input stream.
if state.pos >= state.buffer.len: if state.pos >= state.buffer.len:
@ -243,8 +239,8 @@ proc decodeSOF0(state: var DecoderState) =
failInvalid("unsupported component count, must be 1 or 3") failInvalid("unsupported component count, must be 1 or 3")
for i in 0 ..< numComponents: for i in 0 ..< numComponents:
state.components.add(Component()) var component = Component()
state.components[i].id = state.readUint8() component.id = state.readUint8()
let let
info = state.readUint8() info = state.readUint8()
vertical = info and 15 vertical = info and 15
@ -257,18 +253,19 @@ proc decodeSOF0(state: var DecoderState) =
if vertical == 0 or vertical > 4 or horizontal == 0 or horizontal > 4: if vertical == 0 or vertical > 4 or horizontal == 0 or horizontal > 4:
failInvalid("invalid component scaling factor") failInvalid("invalid component scaling factor")
state.components[i].xScale = vertical.int component.xScale = vertical.int
state.components[i].yScale = horizontal.int component.yScale = horizontal.int
state.components[i].quantizationTableId = quantizationTableId component.quantizationTableId = quantizationTableId
state.components.add(component)
for i in 0 ..< state.components.len: for component in state.components.mitems:
state.maxXScale = max( state.maxXScale = max(
state.maxXScale, state.maxXScale,
state.components[i].xScale component.xScale
) )
state.maxYScale = max( state.maxYScale = max(
state.maxYScale, state.maxYScale,
state.components[i].yScale component.yScale
) )
state.mcuWidth = state.maxYScale * 8 state.mcuWidth = state.maxYScale * 8
@ -278,40 +275,40 @@ proc decodeSOF0(state: var DecoderState) =
state.numMcuHigh = state.numMcuHigh =
(state.imageHeight + state.mcuHeight - 1) div state.mcuHeight (state.imageHeight + state.mcuHeight - 1) div state.mcuHeight
for i in 0 ..< state.components.len: for component in state.components.mitems:
state.components[i].width = ( component.width = (
state.imageWidth * state.imageWidth *
state.components[i].yScale + component.yScale +
state.maxYScale - 1 state.maxYScale - 1
) div state.maxYScale ) div state.maxYScale
state.components[i].height = ( component.height = (
state.imageHeight * state.imageHeight *
state.components[i].xScale + component.xScale +
state.maxXScale - 1 state.maxXScale - 1
) div state.maxXScale ) div state.maxXScale
# Allocate block data structures. # Allocate block data structures.
state.components[i].blocks = newSeqWith( component.blocks = newSeqWith(
state.components[i].width, component.width,
newSeq[array[64, int16]]( newSeq[array[64, int16]](
state.components[i].height component.height
) )
) )
state.components[i].widthStride = component.widthStride =
state.numMcuWide * state.components[i].yScale * 8 state.numMcuWide * component.yScale * 8
state.components[i].heightStride = component.heightStride =
state.numMcuHigh * state.components[i].xScale * 8 state.numMcuHigh * component.xScale * 8
state.components[i].channel = newMask( component.channel = newMask(
state.components[i].widthStride, state.components[i].heightStride component.widthStride, component.heightStride
) )
if state.progressive: if state.progressive:
state.components[i].widthCoeff = state.components[i].widthStride div 8 component.widthCoeff = component.widthStride div 8
state.components[i].heightCoeff = state.components[i].heightStride div 8 component.heightCoeff = component.heightStride div 8
state.components[i].coeff.setLen( component.coeff.setLen(
state.components[i].widthStride * state.components[i].heightStride component.widthStride * component.heightStride
) )
proc decodeSOF1(state: var DecoderState) = proc decodeSOF1(state: var DecoderState) =
@ -509,7 +506,7 @@ proc decodeRegularBlock(
state.getBitsAsSignedInt(t) state.getBitsAsSignedInt(t)
dc = state.components[component].dcPred + diff dc = state.components[component].dcPred + diff
state.components[component].dcPred = dc state.components[component].dcPred = dc
data[0] = clampInt16(dc) data[0] = dc.int16
var i = 1 var i = 1
while true: while true:
@ -551,7 +548,7 @@ proc decodeProgressiveBlock(
let let
dc = state.components[component].dcPred + diff dc = state.components[component].dcPred + diff
state.components[component].dcPred = dc state.components[component].dcPred = dc
data[0] = clampInt16(dc * (1 shl state.successiveApproxLow)) data[0] = (dc * (1 shl state.successiveApproxLow)).int16
else: else:
if getBit(state) != 0: if getBit(state) != 0:
@ -596,7 +593,7 @@ proc decodeProgressiveContinuationBlock(
inc k inc k
if s >= 15: if s >= 15:
failInvalid() failInvalid()
data[zig] = clampInt16(state.getBitsAsSignedInt(s.int) * (1 shl shift)) data[zig] = (state.getBitsAsSignedInt(s.int) * (1 shl shift)).int16
if not(k <= state.spectralEnd): if not(k <= state.spectralEnd):
break break
@ -708,7 +705,7 @@ proc idctBlock(component: var Component, offset: int, data: array[64, int16]) =
data[i + 40] == 0 and data[i + 40] == 0 and
data[i + 48] == 0 and data[i + 48] == 0 and
data[i + 56] == 0: data[i + 56] == 0:
let dcterm = clampInt16(data[i].int * 4.int) let dcterm = data[i] * 4
values[i + 0] = dcterm values[i + 0] = dcterm
values[i + 8] = dcterm values[i + 8] = dcterm
values[i + 16] = dcterm values[i + 16] = dcterm
@ -854,7 +851,7 @@ proc quantizationAndIDCTPass(state: var DecoderState) =
let qTableId = state.components[comp].quantizationTableId let qTableId = state.components[comp].quantizationTableId
if qTableId.int notin 0 ..< state.quantizationTables.len: if qTableId.int notin 0 ..< state.quantizationTables.len:
failInvalid() failInvalid()
data[i] = clampInt16(data[i] * state.quantizationTables[qTableId][i].int) data[i] = data[i] * state.quantizationTables[qTableId][i].int16
state.components[comp].idctBlock( state.components[comp].idctBlock(
state.components[comp].widthStride * column * 8 + row * 8, state.components[comp].widthStride * column * 8 + row * 8,