jpg checkpoint

This commit is contained in:
Ryan Oldenburg 2020-11-27 15:23:29 -06:00
parent 41edec4711
commit a8d667f879
3 changed files with 347 additions and 172 deletions

View file

@ -1,216 +1,391 @@
import flatty/binny, pixie/common, pixie/images
import pixie/common, pixie/images, strutils
# See https://github.com/nothings/stb/blob/master/stb_image.h
# See http://www.vip.sugovica.hu/Sardi/kepnezo/JPEG%20File%20Layout%20and%20Format.htm
const
jpgStartOfImage* = [0xFF.uint8, 0xD8]
deZigZag = [
0.uint8, 1, 8, 16, 9, 2, 3, 10,
17, 24, 32, 25, 18, 11, 4, 5,
12, 19, 26, 33, 40, 48, 41, 34,
27, 20, 13, 6, 7, 14, 21, 28,
35, 42, 49, 56, 57, 50, 43, 36,
29, 22, 15, 23, 30, 37, 44, 51,
58, 59, 52, 45, 38, 31, 39, 46,
53, 60, 61, 54, 47, 55, 62, 63
]
bitmasks = [ # (1 shr n) - 1
0.uint32, 1, 3, 7, 15, 31, 63, 127, 255, 511,
1023, 2047, 4095, 8191, 16383, 32767, 65535
]
biases = [ # (-1 shl n) + 1
0.int32, -1, -3, -7, -15, -31, -63, -127, -255,
-511, -1023, -2047, -4095, -8191, -16383, -32767
]
type
Component = object
id, samplingFactors, quantizationTable: uint8
Huffman = object
symbols: array[256, uint8]
deltas: array[17, int]
maxCodes: array[18, int]
Jpg = object
Component = object
id, quantizationTable: uint8
verticalSamplingFactor, horizontalSamplingFactor: int
width, height: int
components: array[3, Component]
w2, h2: int # TODO what are these
huffmanDC, huffmanAC: int
dcPred: int
DecoderState = object
buffer: seq[uint8]
pos, bitCount: int
bits: uint32
imageHeight, imageWidth: int
quantizationTables: array[4, array[64, uint8]]
huffmanTables: array[2, array[4, Huffman]] # 0 = DC, 1 = AC
components: array[3, Component]
maxHorizontalSamplingFactor, maxVerticalSamplingFactor: int
mcuWidth, mcuHeight, numMcuWide, numMcuHigh: int
componentOrder: array[3, int]
hitEOI: bool
template failInvalid() =
raise newException(PixieError, "Invalid JPG buffer, unable to load")
proc readSegmentLen(data: seq[uint8], pos: int): int =
if pos + 2 > data.len:
proc readUint8(state: var DecoderState): uint8 {.inline.} =
if state.pos >= state.buffer.len:
failInvalid()
result = state.buffer[state.pos]
inc state.pos
let segmentLen = data.readUint16(pos).swap().int
if pos + segmentLen > data.len:
proc readUint16be(state: var DecoderState): uint16 =
(state.readUint8().uint16 shl 8) or state.readUint8()
proc skipBytes(state: var DecoderState, n: int) =
if state.pos + n > state.buffer.len:
failInvalid()
state.pos += n
segmentLen
proc seekToMarker(state: var DecoderState): uint8 =
var x = state.readUint8()
while x != 0xFF:
x = state.readUint8()
while x == 0xFF:
x = state.readUint8()
x
proc skipSegment(data: seq[uint8], pos: var int) {.inline.} =
pos += readSegmentLen(data, pos)
proc decodeSOF(jpg: var Jpg, data: seq[uint8], pos: var int) =
let segmentLen = readSegmentLen(data, pos)
pos += 2
if pos + 6 > data.len:
failInvalid()
let
precision = data[pos].int
height = data.readUint16(pos + 1).swap().int
width = data.readUint16(pos + 3).swap().int
components = data[pos + 5].int
pos += 6
if width <= 0:
raise newException(PixieError, "Invalid JPG width")
if height <= 0:
raise newException(PixieError, "Invalid JPG height")
if precision != 8:
raise newException(PixieError, "Unsupported JPG bit depth")
if components != 3:
raise newException(PixieError, "Unsupported JPG channel count")
jpg.width = width
jpg.height = height
if 8 + components * 3 != segmentLen:
failInvalid()
for i in 0 ..< 3:
jpg.components[i] = Component(
id: data[pos],
samplingFactors: data[pos + 1],
quantizationTable: data[pos + 2]
)
pos += 3
proc decodeDHT(data: seq[uint8], pos: var int) =
# skipSegment(data, pos)
# debugEcho pos
let
segmentLen = readSegmentLen(data, pos)
stop = pos + segmentLen
pos += 2
while stop - pos >= 17:
let info = data[pos]
if (info and 0b11100000) != 0:
failInvalid()
var counts: array[17, int]
for codeLen in 1 .. 16:
counts[codeLen] = data[pos + codeLen].int
debugEcho counts
pos += 17
for codeLen in 1 .. 16:
discard
break
pos = stop
proc decodeDQT(jpg: var Jpg, data: seq[uint8], pos: var int) =
let
segmentLen = readSegmentLen(data, pos)
stop = pos + segmentLen
pos += 2
while stop - pos >= 65:
proc decodeDQT(state: var DecoderState) =
var len = state.readUint16be() - 2
while len > 0:
let
info = data[pos]
qt = info and 0b00001111
precision = info and 0b11110000
if qt > 3:
failInvalid()
info = state.readUint8()
table = info and 15
precision = info shr 4
if precision != 0:
raise newException(
PixieError, "Unsuppored JPG qantization table precision"
)
inc pos
for i in 0 ..< 64:
jpg.quantizationTables[qt][i] = data[pos + i]
pos += 64
proc decodeSOS(data: seq[uint8], pos: var int) =
let segmentLen = readSegmentLen(data, pos)
pos += 2
if segmentLen != 12:
failInvalid()
let components = data[pos]
if components != 3:
raise newException(PixieError, "Unsupported JPG channel count")
for i in 0 ..< 3:
discard
pos += 10
pos += 3 # Skip 3 more bytes
while true:
if pos >= data.len:
if table > 3:
failInvalid()
if data[pos] == 0xFF:
if pos + 1 == data.len:
failInvalid()
if data[pos + 1] == 0xD9: # End of Image:
pos += 2
for i in 0 ..< 64:
state.quantizationTables[table][deZigZag[i]] = state.readUint8()
len -= 65
if len != 0:
failInvalid()
proc decodeDHT(state: var DecoderState) =
proc buildHuffman(huffman: var Huffman, counts: array[16, uint8]) =
var sizes: array[257, uint8]
block:
var k: int
for i in 0.uint8 ..< 16:
for j in 0.uint8 ..< counts[i]:
sizes[k] = i + 1
inc k
sizes[k] = 0
var code, j: int
for i in 1.uint8 .. 16:
huffman.deltas[i] = j - code
if sizes[j] == i:
while sizes[j] == i:
inc code
inc j
if code - 1 >= 1 shl i:
failInvalid()
huffman.maxCodes[i] = code shl (16 - i)
code = code shl 1
huffman.maxCodes[17] = int.high
var len = state.readUint16be() - 2
while len > 0:
let
info = state.readUint8()
table = info and 15
tableCurrent = info shr 4 # DC or AC
if tableCurrent > 1 or table > 3:
failInvalid()
var
counts: array[16, uint8]
numSymbols: uint8
for i in 0 ..< 16:
counts[i] = state.readUint8()
numSymbols += counts[i]
len -= 17
state.huffmanTables[tableCurrent][table].buildHuffman(counts)
for i in 0.uint8 ..< numSymbols:
state.huffmanTables[tableCurrent][table].symbols[i] = state.readUint8()
len -= numSymbols
if len != 0:
failInvalid()
proc decodeSegment(state: var DecoderState, marker: uint8) =
case marker:
of 0xDB: # Define Quantanization Table(s)
state.decodeDQT()
of 0xC4: # Define Huffman Tables
state.decodeDHT()
else:
if (marker >= 0xE0 and marker <= 0xEF) or marker == 0xFE:
let len = state.readUint16be() - 2
state.skipBytes(len.int)
else:
raise newException(
PixieError, "Unexpected JPG segment marker " & toHex(marker)
)
proc decodeSOF(state: var DecoderState) =
var len = state.readUint16be() - 2
let precision = state.readUint8()
if precision != 8:
raise newException(PixieError, "Unsupported JPG bit depth, must be 8")
state.imageHeight = state.readUint16be().int
state.imageWidth = state.readUint16be().int
if state.imageHeight == 0 or state.imageWidth == 0:
failInvalid()
let components = state.readUint8()
if components != 3:
raise newException(PixieError, "Unsupported JPG component count, must be 3")
len -= 15
if len != 0:
failInvalid()
for i in 0 ..< 3:
state.components[i].id = state.readUint8()
let
info = state.readUint8()
vertical = info and 15
horizontal = info shr 4
quantizationTable = state.readUint8()
if quantizationTable > 3:
failInvalid()
if vertical == 0 or vertical > 4 or horizontal == 0 or horizontal > 4:
failInvalid()
state.components[i].verticalSamplingFactor = vertical.int
state.components[i].horizontalSamplingFactor = horizontal.int
state.components[i].quantizationTable = quantizationTable
for i in 0 ..< 3:
state.maxVerticalSamplingFactor = max(
state.maxVerticalSamplingFactor,
state.components[i].verticalSamplingFactor
)
state.maxHorizontalSamplingFactor = max(
state.maxHorizontalSamplingFactor,
state.components[i].horizontalSamplingFactor
)
state.mcuWidth = state.maxHorizontalSamplingFactor * 8
state.mcuHeight = state.maxVerticalSamplingFactor * 8
state.numMcuWide =
(state.imageWidth + state.mcuWidth - 1) div state.mcuWidth
state.numMcuHigh =
(state.imageHeight + state.mcuHeight - 1) div state.mcuHeight
for i in 0 ..< 3:
state.components[i].width = (
state.imageWidth *
state.components[i].horizontalSamplingFactor +
state.maxHorizontalSamplingFactor - 1
) div state.maxHorizontalSamplingFactor
state.components[i].height = (
state.imageHeight *
state.components[i].verticalSamplingFactor +
state.maxVerticalSamplingFactor - 1
) div state.maxVerticalSamplingFactor
state.components[i].w2 =
state.numMcuWide * state.components[i].horizontalSamplingFactor * 8
state.components[i].h2 =
state.numMcuHigh * state.components[i].verticalSamplingFactor * 8
proc decodeSOS(state: var DecoderState) =
var len = state.readUint16be() - 2
let components = state.readUint8()
if components != 3:
raise newException(PixieError, "Unsupported JPG component count, must be 3")
for i in 0 ..< 3:
let
id = state.readUint8()
info = state.readUint8()
huffmanAC = info and 15
huffmanDC = info shr 4
if huffmanAC > 3 or huffmanDC > 3:
failInvalid()
var component: int
while component < 3:
if state.components[component].id == id:
break
elif data[pos + 1] == 0x00:
discard # Skip the 0x00 byte
inc component
if component == 3:
failInvalid() # Not found
state.components[component].huffmanAC = huffmanAC.int
state.components[component].huffmanDC = huffmanDC.int
state.componentOrder[i] = component
# Skip 3 bytes
for i in 0 ..< 3:
discard state.readUint8()
len -= 10
if len != 0:
failInvalid()
proc fillBits(state: var DecoderState) =
while state.bitCount <= 24:
let b = if state.hitEOI: 0.uint32 else: state.readUint8().uint32
if b == 0xFF:
let c = state.readUint8()
if c == 0:
discard
elif c == 0xD9:
state.hitEOI = true
else:
failInvalid()
else:
discard
state.bits = state.bits or (b shl (24 - state.bitCount))
state.bitCount += 8
inc pos
proc huffmanDecode(state: var DecoderState, tableCurrent, table: int): uint8 =
if state.bitCount < 16:
state.fillBits()
var
tmp = (state.bits shr 16).int
i = 1
while i < state.huffmanTables[tableCurrent][table].maxCodes.len:
if tmp < state.huffmanTables[tableCurrent][table].maxCodes[i]:
break
inc i
if i == 17 or i > state.bitCount:
failInvalid()
let symbolId = (state.bits shr (32 - i)).int +
state.huffmanTables[tableCurrent][table].deltas[i]
result = state.huffmanTables[tableCurrent][table].symbols[symbolId]
state.bits = state.bits shl i
state.bitCount -= i
echo "post-decode: ", state.bitCount, " ", state.bits
template lrot(value: uint32, shift: int): uint32 =
(value shl shift) or (value shr (32 - shift))
proc extendReceive(state: var DecoderState, t: int): int =
if state.bitCount < t:
state.fillBits()
let sign = (state.bits shr 31).int32
var k = lrot(state.bits, t)
state.bits = k and (not bitmasks[t])
k = k and bitmasks[t]
state.bitCount -= t
result = k.int + (biases[t] and (not sign))
echo "sgn: ", sign
echo "post: ", state.bits
proc decodeImageBlock(state: var DecoderState, component: int): array[64, int16] =
let t = state.huffmanDecode(0, state.components[component].huffmanDC).int
if t < 0:
failInvalid()
echo "t: ", t
let
diff = if t == 0: 0 else: state.extendReceive(t)
dc = state.components[component].dcPred + diff
state.components[component].dcPred = dc
result[0] = (dc * state.quantizationTables[
state.components[component].quantizationTable
][0].int).int16
echo "data[0]: ", result[0]
proc decodeImageData(state: var DecoderState) =
for y in 0 ..< state.numMcuHigh:
for x in 0 ..< state.numMcuWide:
for component in state.componentOrder:
for j in 0 ..< state.components[component].verticalSamplingFactor:
for i in 0 ..< state.components[component].horizontalSamplingFactor:
let data = state.decodeImageBlock(component)
return
proc decodeJpg*(data: seq[uint8]): Image =
## Decodes the JPEG into an Image.
if data.len < 4:
var state = DecoderState()
state.buffer = data
if state.readUint8() != 0xFF or state.readUint8() != 0xD8: # SOI
failInvalid()
if data.readUint16(0) != cast[uint16](jpgStartOfImage):
failInvalid()
var marker = state.seekToMarker()
while marker != 0xC0: # SOF
state.decodeSegment(marker)
marker = state.seekToMarker()
var
jpg: Jpg
pos: int
while true:
if pos + 2 > data.len:
failInvalid()
state.decodeSOF()
let marker = [data[pos], data[pos + 1]]
pos += 2
marker = state.seekToMarker()
while marker != 0xDA: # Start of Scan
state.decodeSegment(marker)
marker = state.seekToMarker()
if marker[0] != 0xFF:
failInvalid()
state.decodeSOS()
case marker[1]:
of 0xD8: # Start of Image
discard
of 0xC0: # Start of Frame
jpg.decodeSOF(data, pos)
of 0xC2: # Start of Frame
raise newException(PixieError, "Progressive JPG not supported")
of 0xC4: # Define Huffman Tables
decodeDHT(data, pos)
of 0xDB: # Define Quantanization Table(s)
jpg.decodeDQT(data, pos)
# of 0xDD: # Define Restart Interval
of 0xDA: # Start of Scan
decodeSOS(data, pos)
break
of 0xFE: # Comment
skipSegment(data, pos)
of 0xD9: # End of Image
failInvalid() # Not expected here
else:
if (marker[1] and 0xF0) == 0xE0:
# Skip APPn segments
skipSegment(data, pos)
else:
raise newException(PixieError, "Unsupported JPG segment")
state.decodeImageData()
raise newException(PixieError, "Decoding JPG not supported yet")
# raise newException(PixieError, "Decoding JPG not supported yet")
proc decodeJpg*(data: string): Image {.inline.} =
decodeJpg(cast[seq[uint8]](data))

View file

@ -1,7 +1,7 @@
import pixie/fileformats/png, stb_image/read as stbi, stb_image/write as stbr,
fidget/opengl/perf, nimPNG
let data = readFile("tests/data/lenna.png")
let data = readFile("tests/images/lenna.png")
timeIt "pixie decode":
for i in 0 ..< 100:

View file

@ -2,4 +2,4 @@ import pixie/fileformats/jpg
let original = readFile("tests/images/jpg/jpeg420exif.jpg")
# discard decodeJpg(original)
discard decodeJpg(original)