diff --git a/src/pixie/fileformats/jpeg.nim b/src/pixie/fileformats/jpeg.nim index 2fb59e2..3bb484d 100644 --- a/src/pixie/fileformats/jpeg.nim +++ b/src/pixie/fileformats/jpeg.nim @@ -79,7 +79,7 @@ type componentOrder: seq[int] progressive: bool restartInterval: int - todo: int + todoBeforeRestart: int eobRun: int when defined(release): @@ -91,7 +91,11 @@ template failInvalid(reason = "unable to load") = template clampByte(x: int32): uint8 = ## Clamp integer into byte range. - clamp(x, 0, 0xFF).uint8 + # clamp(x, 0, 0xFF).uint8 + let + signBit = (cast[uint32](x) shr 31) + value = cast[uint32](x) and (signBit - 1) + min(value, 255).uint8 proc readUint8(state: var DecoderState): uint8 = ## Reads a byte from the input stream. @@ -117,7 +121,7 @@ proc skipChunk(state: var DecoderState) = proc decodeDRI(state: var DecoderState) = ## Decode Define Restart Interval - var len = state.readUint16be() - 2 + let len = state.readUint16be() - 2 if len != 2: failInvalid("invalid DRI length") state.restartInterval = state.readUint16be().int @@ -212,7 +216,7 @@ proc decodeDHT(state: var DecoderState) = proc decodeSOF0(state: var DecoderState) = ## Decode start of Frame - var len = state.readUint16be() - 2 + var len = state.readUint16be().int - 2 let precision = state.readUint8() if precision != 8: @@ -230,6 +234,8 @@ proc decodeSOF0(state: var DecoderState) = if numComponents notin {1, 3}: failInvalid("unsupported component count, must be 1 or 3") + len -= 6 + for i in 0 ..< numComponents: var component = Component() component.id = state.readUint8() @@ -250,6 +256,8 @@ proc decodeSOF0(state: var DecoderState) = component.quantizationTableId = quantizationTableId state.components.add(component) + len -= 3 * numComponents + for component in state.components.mitems: state.maxXScale = max(state.maxXScale, component.xScale) state.maxYScale = max(state.maxYScale, component.yScale) @@ -292,6 +300,9 @@ proc decodeSOF0(state: var DecoderState) = component.widthStride * component.heightStride ) + if len != 0: + failInvalid() + proc decodeSOF1(state: var DecoderState) = failInvalid("unsupported extended sequential DCT format") @@ -302,16 +313,16 @@ proc decodeSOF2(state: var DecoderState) = state.progressive = true proc reset(state: var DecoderState) = - ## Rests the decoder state need for reset markers. + ## Rests the decoder state need for restart markers. state.bitBuffer = 0 state.bitsBuffered = 0 for component in 0 ..< state.components.len: state.components[component].dcPred = 0 state.hitEnd = false if state.restartInterval != 0: - state.todo = state.restartInterval + state.todoBeforeRestart = state.restartInterval else: - state.todo = 0x7fffffff + state.todoBeforeRestart = int.high state.eobRun = 0 proc decodeSOS(state: var DecoderState) = @@ -474,9 +485,8 @@ proc decodeRegularBlock( ) = ## Decodes a whole block. let t = state.huffmanDecode(0, state.components[component].huffmanDC).int - if t < 0: - failInvalid() - + if t > 15: + failInvalid("bad huffman code") let diff = if t == 0: @@ -488,7 +498,7 @@ proc decodeRegularBlock( data[0] = cast[int16](dc) var i = 1 - while true: + while i < 64: let rs = state.huffmanDecode(1, state.components[component].huffmanAC) s = rs and 15 @@ -499,15 +509,12 @@ proc decodeRegularBlock( i += 16 else: i += r.int - if i notin 0 ..< 64: + if i >= 64: failInvalid() let zig = deZigZag[i] data[zig] = cast[int16](state.getBitsAsSignedInt(s.int)) inc i - if not(i < 64): - break - proc decodeProgressiveBlock( state: var DecoderState, component: int, data: var array[64, int16] ) = @@ -517,18 +524,17 @@ proc decodeProgressiveBlock( if state.successiveApproxHigh == 0: let t = state.huffmanDecode(0, state.components[component].huffmanDC).int - if t < 0 or t > 15: - failInvalid() - let - diff = if t != 0: - state.getBitsAsSignedInt(t) - else: - 0 + if t > 15: + failInvalid("bad huffman code") let + diff = + if t > 0: + state.getBitsAsSignedInt(t) + else: + 0 dc = state.components[component].dcPred + diff state.components[component].dcPred = dc data[0] = cast[int16](dc * (1 shl state.successiveApproxLow)) - else: if getBit(state) != 0: data[0] = cast[int16](data[0] + (1 shl state.successiveApproxLow)) @@ -548,12 +554,9 @@ proc decodeProgressiveContinuationBlock( return var k = state.spectralStart - while true: + while k <= state.spectralEnd: let rs = state.huffmanDecode(1, state.components[component].huffmanAC) - if rs < 0: - failInvalid("bad huffman code") - let s = rs and 15 r = rs.int shr 4 if s == 0: @@ -566,7 +569,7 @@ proc decodeProgressiveContinuationBlock( k += 16 else: k += r.int - if k notin 0 ..< 64: + if k >= 64: failInvalid() let zig = deZigZag[k] inc k @@ -574,9 +577,6 @@ proc decodeProgressiveContinuationBlock( failInvalid() data[zig] = cast[int16](state.getBitsAsSignedInt(s.int) * (1 shl shift)) - if not(k <= state.spectralEnd): - break - else: var bit = 1 shl state.successiveApproxLow @@ -593,11 +593,8 @@ proc decodeProgressiveContinuationBlock( data[zig] = cast[int16](data[zig] - bit) else: var k = state.spectralStart - while true: - let - rs = state.huffmanDecode(1, state.components[component].huffmanAC) - if rs < 0: - failInvalid("bad huffman code") + while k <= state.spectralEnd: + let rs = state.huffmanDecode(1, state.components[component].huffmanAC) var s = rs.int and 15 r = rs.int shr 4 @@ -633,9 +630,6 @@ proc decodeProgressiveContinuationBlock( break dec r - if not (k <= state.spectralEnd): - break - template idct1D(s0, s1, s2, s3, s4, s5, s6, s7: int32) = ## Inverse discrete cosine transform 1D template f2f(x: float32): int32 = (x * 4096 + 0.5).int32 @@ -766,18 +760,18 @@ proc decodeBlock(state: var DecoderState, comp, row, column: int) = else: state.decodeRegularBlock(comp, data) -proc checkReset(state: var DecoderState) = - ## Check if we might have run into a reset marker, then deal with it. - dec state.todo - if state.todo <= 0: +proc checkRestart(state: var DecoderState) = + ## Check if we might have run into a restart marker, then deal with it. + dec state.todoBeforeRestart + if state.todoBeforeRestart <= 0: if state.bitsBuffered < 24: state.fillBitBuffer() if state.buffer[state.pos] == 0xFF.char: - if state.buffer[state.pos+1] in {0xD0.char .. 0xD7.char}: + if state.buffer[state.pos + 1] in {0xD0.char .. 0xD7.char}: state.pos += 2 else: - failInvalid("did not get expected reset marker") + failInvalid("did not get expected restart marker") state.reset() proc decodeBlocks(state: var DecoderState) = @@ -791,7 +785,7 @@ proc decodeBlocks(state: var DecoderState) = for column in 0 ..< h: for row in 0 ..< w: state.decodeBlock(comp, row, column) - state.checkReset() + state.checkRestart() else: # Interleaved regular component pass. for mcuY in 0 ..< state.numMcuHigh: @@ -803,7 +797,7 @@ proc decodeBlocks(state: var DecoderState) = row = (mcuX * state.components[comp].yScale + compX) col = (mcuY * state.components[comp].xScale + compY) state.decodeBlock(comp, row, col) - state.checkReset() + state.checkRestart() proc quantizationAndIDCTPass(state: var DecoderState) = ## Does quantization and IDCT. @@ -811,16 +805,14 @@ proc quantizationAndIDCTPass(state: var DecoderState) = let w = (state.components[comp].width + 7) shr 3 h = (state.components[comp].height + 7) shr 3 + qTableId = state.components[comp].quantizationTableId + if qTableId.int notin 0 ..< state.quantizationTables.len: + failInvalid() for column in 0 ..< h: for row in 0 ..< w: - var data = state.components[comp].blocks[row][column] - + var data {.byaddr.} = state.components[comp].blocks[row][column] for i in 0 ..< 64: - let qTableId = state.components[comp].quantizationTableId - if qTableId.int notin 0 ..< state.quantizationTables.len: - failInvalid() data[i] = cast[int16](data[i] * state.quantizationTables[qTableId][i].int32) - state.components[comp].idctBlock( state.components[comp].widthStride * column * 8 + row * 8, data @@ -909,15 +901,11 @@ proc buildImage(state: var DecoderState): Image = cb.unsafe[x, y], cr.unsafe[x, y], ) - elif state.components.len == 1: - let - cy = state.components[0].channel + let cy = state.components[0].channel for y in 0 ..< state.imageHeight: for x in 0 ..< state.imageWidth: - result.unsafe[x, y] = grayScaleToRgbx( - cy.unsafe[x, y], - ) + result.unsafe[x, y] = grayScaleToRgbx(cy.unsafe[x, y]) else: failInvalid() @@ -952,8 +940,8 @@ proc decodeJpeg*(data: string): Image {.raises: [PixieError].} = # EOI - End of Image break of 0xD0 .. 0xD7: - # Reset markers - failInvalid("invalid reset marker") + # Restart markers + failInvalid("invalid restart marker") of 0xDB: # Define Quantization Table(s) state.decodeDQT()