jpg checkpoint
This commit is contained in:
parent
41edec4711
commit
a8d667f879
3 changed files with 347 additions and 172 deletions
|
@ -1,216 +1,391 @@
|
|||
import flatty/binny, pixie/common, pixie/images
|
||||
import pixie/common, pixie/images, strutils
|
||||
|
||||
# See https://github.com/nothings/stb/blob/master/stb_image.h
|
||||
# See http://www.vip.sugovica.hu/Sardi/kepnezo/JPEG%20File%20Layout%20and%20Format.htm
|
||||
|
||||
const
|
||||
jpgStartOfImage* = [0xFF.uint8, 0xD8]
|
||||
deZigZag = [
|
||||
0.uint8, 1, 8, 16, 9, 2, 3, 10,
|
||||
17, 24, 32, 25, 18, 11, 4, 5,
|
||||
12, 19, 26, 33, 40, 48, 41, 34,
|
||||
27, 20, 13, 6, 7, 14, 21, 28,
|
||||
35, 42, 49, 56, 57, 50, 43, 36,
|
||||
29, 22, 15, 23, 30, 37, 44, 51,
|
||||
58, 59, 52, 45, 38, 31, 39, 46,
|
||||
53, 60, 61, 54, 47, 55, 62, 63
|
||||
]
|
||||
bitmasks = [ # (1 shl n) - 1
|
||||
0.uint32, 1, 3, 7, 15, 31, 63, 127, 255, 511,
|
||||
1023, 2047, 4095, 8191, 16383, 32767, 65535
|
||||
]
|
||||
biases = [ # (-1 shl n) + 1
|
||||
0.int32, -1, -3, -7, -15, -31, -63, -127, -255,
|
||||
-511, -1023, -2047, -4095, -8191, -16383, -32767
|
||||
]
|
||||
|
||||
type
|
||||
Component = object
|
||||
id, samplingFactors, quantizationTable: uint8
|
||||
Huffman = object
|
||||
symbols: array[256, uint8]
|
||||
deltas: array[17, int]
|
||||
maxCodes: array[18, int]
|
||||
|
||||
Jpg = object
|
||||
Component = object
|
||||
id, quantizationTable: uint8
|
||||
verticalSamplingFactor, horizontalSamplingFactor: int
|
||||
width, height: int
|
||||
components: array[3, Component]
|
||||
w2, h2: int # TODO what are these
|
||||
huffmanDC, huffmanAC: int
|
||||
dcPred: int
|
||||
|
||||
DecoderState = object
|
||||
buffer: seq[uint8]
|
||||
pos, bitCount: int
|
||||
bits: uint32
|
||||
imageHeight, imageWidth: int
|
||||
quantizationTables: array[4, array[64, uint8]]
|
||||
huffmanTables: array[2, array[4, Huffman]] # 0 = DC, 1 = AC
|
||||
components: array[3, Component]
|
||||
maxHorizontalSamplingFactor, maxVerticalSamplingFactor: int
|
||||
mcuWidth, mcuHeight, numMcuWide, numMcuHigh: int
|
||||
componentOrder: array[3, int]
|
||||
hitEOI: bool
|
||||
|
||||
template failInvalid() =
|
||||
raise newException(PixieError, "Invalid JPG buffer, unable to load")
|
||||
|
||||
proc readSegmentLen(data: seq[uint8], pos: int): int =
|
||||
if pos + 2 > data.len:
|
||||
proc readUint8(state: var DecoderState): uint8 {.inline.} =
  ## Returns the byte under the read cursor and moves the cursor one byte
  ## forward. Raises PixieError (via failInvalid) when the cursor is already
  ## at or past the end of the buffer.
  let cursor = state.pos
  if cursor < state.buffer.len:
    result = state.buffer[cursor]
    state.pos = cursor + 1
  else:
    failInvalid()
|
||||
|
||||
let segmentLen = data.readUint16(pos).swap().int
|
||||
if pos + segmentLen > data.len:
|
||||
proc readUint16be(state: var DecoderState): uint16 =
  ## Reads two consecutive bytes and combines them big-endian: the first byte
  ## read becomes the high 8 bits, the second the low 8 bits.
  let
    highByte = state.readUint8().uint16
    lowByte = state.readUint8().uint16
  (highByte shl 8) or lowByte
|
||||
|
||||
proc skipBytes(state: var DecoderState, n: int) =
  ## Moves the read cursor forward by `n` bytes without inspecting them.
  ## Raises PixieError (via failInvalid) if that would place the cursor past
  ## the end of the buffer; the cursor is left untouched in that case.
  let target = state.pos + n
  if target > state.buffer.len:
    failInvalid()
  state.pos = target
|
||||
|
||||
segmentLen
|
||||
proc seekToMarker(state: var DecoderState): uint8 =
  ## Scans forward to the next 0xFF byte, then returns the first byte after it
  ## that is not itself 0xFF. Running out of buffer while scanning raises
  ## PixieError through readUint8.

  # Discard everything up to and including the first 0xFF byte.
  while state.readUint8() != 0xFF:
    discard

  # Swallow any run of additional 0xFF bytes; the first different byte is
  # what gets returned.
  result = state.readUint8()
  while result == 0xFF:
    result = state.readUint8()
|
||||
|
||||
proc skipSegment(data: seq[uint8], pos: var int) {.inline.} =
|
||||
pos += readSegmentLen(data, pos)
|
||||
|
||||
proc decodeSOF(jpg: var Jpg, data: seq[uint8], pos: var int) =
|
||||
let segmentLen = readSegmentLen(data, pos)
|
||||
pos += 2
|
||||
|
||||
if pos + 6 > data.len:
|
||||
failInvalid()
|
||||
|
||||
let
|
||||
precision = data[pos].int
|
||||
height = data.readUint16(pos + 1).swap().int
|
||||
width = data.readUint16(pos + 3).swap().int
|
||||
components = data[pos + 5].int
|
||||
|
||||
pos += 6
|
||||
|
||||
if width <= 0:
|
||||
raise newException(PixieError, "Invalid JPG width")
|
||||
|
||||
if height <= 0:
|
||||
raise newException(PixieError, "Invalid JPG height")
|
||||
|
||||
if precision != 8:
|
||||
raise newException(PixieError, "Unsupported JPG bit depth")
|
||||
|
||||
if components != 3:
|
||||
raise newException(PixieError, "Unsupported JPG channel count")
|
||||
|
||||
jpg.width = width
|
||||
jpg.height = height
|
||||
|
||||
if 8 + components * 3 != segmentLen:
|
||||
failInvalid()
|
||||
|
||||
for i in 0 ..< 3:
|
||||
jpg.components[i] = Component(
|
||||
id: data[pos],
|
||||
samplingFactors: data[pos + 1],
|
||||
quantizationTable: data[pos + 2]
|
||||
)
|
||||
pos += 3
|
||||
|
||||
proc decodeDHT(data: seq[uint8], pos: var int) =
|
||||
# skipSegment(data, pos)
|
||||
# debugEcho pos
|
||||
let
|
||||
segmentLen = readSegmentLen(data, pos)
|
||||
stop = pos + segmentLen
|
||||
pos += 2
|
||||
|
||||
while stop - pos >= 17:
|
||||
let info = data[pos]
|
||||
if (info and 0b11100000) != 0:
|
||||
failInvalid()
|
||||
|
||||
var counts: array[17, int]
|
||||
for codeLen in 1 .. 16:
|
||||
counts[codeLen] = data[pos + codeLen].int
|
||||
|
||||
debugEcho counts
|
||||
|
||||
pos += 17
|
||||
|
||||
for codeLen in 1 .. 16:
|
||||
discard
|
||||
|
||||
break
|
||||
|
||||
pos = stop
|
||||
|
||||
proc decodeDQT(jpg: var Jpg, data: seq[uint8], pos: var int) =
|
||||
let
|
||||
segmentLen = readSegmentLen(data, pos)
|
||||
stop = pos + segmentLen
|
||||
pos += 2
|
||||
|
||||
while stop - pos >= 65:
|
||||
proc decodeDQT(state: var DecoderState) =
|
||||
var len = state.readUint16be() - 2
|
||||
while len > 0:
|
||||
let
|
||||
info = data[pos]
|
||||
qt = info and 0b00001111
|
||||
precision = info and 0b11110000
|
||||
|
||||
if qt > 3:
|
||||
failInvalid()
|
||||
info = state.readUint8()
|
||||
table = info and 15
|
||||
precision = info shr 4
|
||||
|
||||
if precision != 0:
|
||||
raise newException(
|
||||
PixieError, "Unsuppored JPG qantization table precision"
|
||||
)
|
||||
|
||||
inc pos
|
||||
|
||||
for i in 0 ..< 64:
|
||||
jpg.quantizationTables[qt][i] = data[pos + i]
|
||||
|
||||
pos += 64
|
||||
|
||||
proc decodeSOS(data: seq[uint8], pos: var int) =
|
||||
let segmentLen = readSegmentLen(data, pos)
|
||||
pos += 2
|
||||
|
||||
if segmentLen != 12:
|
||||
failInvalid()
|
||||
|
||||
let components = data[pos]
|
||||
if components != 3:
|
||||
raise newException(PixieError, "Unsupported JPG channel count")
|
||||
|
||||
for i in 0 ..< 3:
|
||||
discard
|
||||
|
||||
pos += 10
|
||||
pos += 3 # Skip 3 more bytes
|
||||
|
||||
while true:
|
||||
if pos >= data.len:
|
||||
if table > 3:
|
||||
failInvalid()
|
||||
|
||||
if data[pos] == 0xFF:
|
||||
if pos + 1 == data.len:
|
||||
failInvalid()
|
||||
if data[pos + 1] == 0xD9: # End of Image:
|
||||
pos += 2
|
||||
for i in 0 ..< 64:
|
||||
state.quantizationTables[table][deZigZag[i]] = state.readUint8()
|
||||
|
||||
len -= 65
|
||||
|
||||
if len != 0:
|
||||
failInvalid()
|
||||
|
||||
proc decodeDHT(state: var DecoderState) =
  ## Parses a Define Huffman Table segment and stores every table it contains
  ## in `state.huffmanTables[class][id]` (class 0 = DC, 1 = AC).
  ##
  ## Raises PixieError (via failInvalid) when a table class/id is out of
  ## range, when the code lengths are inconsistent, when a table declares more
  ## than 256 symbols, or when the tables do not exactly fill the declared
  ## segment length.

  proc buildHuffman(huffman: var Huffman, counts: array[16, uint8]) =
    ## Fills `huffman.deltas` and `huffman.maxCodes` from `counts`, where
    ## `counts[n]` is the number of codes that are `n + 1` bits long.

    # sizes[k] is the bit length of the k-th code in order of increasing
    # length; a 0 entry terminates the list.
    var
      sizes: array[257, uint8]
      numCodes: int
    for lengthIdx in 0 ..< 16:
      for _ in 0 ..< counts[lengthIdx].int:
        if numCodes >= 256:
          failInvalid() # Would overrun `sizes` (and the 256-entry symbols)
        sizes[numCodes] = (lengthIdx + 1).uint8
        inc numCodes
    sizes[numCodes] = 0

    var code, k: int
    for bitLen in 1 .. 16:
      # Offset that turns a `bitLen`-bit code into an index into `symbols`.
      huffman.deltas[bitLen] = k - code
      if sizes[k].int == bitLen:
        while sizes[k].int == bitLen:
          inc code
          inc k
        if code - 1 >= 1 shl bitLen:
          failInvalid() # More codes than fit in `bitLen` bits
      # Upper bound for codes of this length, left-aligned to 16 bits so it
      # can be compared against the top 16 bits of the bit buffer.
      huffman.maxCodes[bitLen] = code shl (16 - bitLen)
      code = code shl 1
    # Sentinel so a linear scan over maxCodes always terminates.
    huffman.maxCodes[17] = int.high

  # Signed arithmetic: a declared length below 2, or tables that overrun the
  # segment, go negative and are rejected instead of wrapping around.
  var remaining = state.readUint16be().int - 2
  while remaining > 0:
    let
      info = state.readUint8()
      table = info and 15
      tableCurrent = info shr 4 # DC or AC

    if tableCurrent > 1 or table > 3:
      failInvalid()

    var
      counts: array[16, uint8]
      numSymbols: int # int, not uint8: the 16 counts can sum past 255
    for i in 0 ..< 16:
      counts[i] = state.readUint8()
      numSymbols += counts[i].int

    remaining -= 17

    if numSymbols > 256 or numSymbols > remaining:
      failInvalid()

    state.huffmanTables[tableCurrent][table].buildHuffman(counts)

    for i in 0 ..< numSymbols:
      state.huffmanTables[tableCurrent][table].symbols[i] = state.readUint8()

    remaining -= numSymbols

  if remaining != 0:
    failInvalid()
|
||||
|
||||
proc decodeSegment(state: var DecoderState, marker: uint8) =
  ## Handles one segment identified by `marker`, with the cursor positioned
  ## just after the marker byte:
  ## * 0xDB is parsed by decodeDQT,
  ## * 0xC4 is parsed by decodeDHT,
  ## * 0xE0..0xEF and 0xFE are skipped using their declared length,
  ## * any other marker raises PixieError.
  case marker:
  of 0xDB: # Define Quantization Table(s)
    state.decodeDQT()
  of 0xC4: # Define Huffman Tables
    state.decodeDHT()
  else:
    if (marker >= 0xE0 and marker <= 0xEF) or marker == 0xFE:
      # The declared length includes its own two bytes. Anything below 2 is
      # malformed; reject it rather than letting `len - 2` wrap around into a
      # huge skip distance.
      let segmentLen = state.readUint16be().int
      if segmentLen < 2:
        failInvalid()
      state.skipBytes(segmentLen - 2)
    else:
      raise newException(
        PixieError, "Unexpected JPG segment marker " & toHex(marker)
      )
|
||||
|
||||
proc decodeSOF(state: var DecoderState) =
  ## Reads the frame header segment: sample precision, image dimensions and
  ## the id, sampling factors and quantization table index of each of the
  ## three components. From those it derives the MCU size, the MCU grid size
  ## and each component's `width`/`height` and `w2`/`h2`.
  ##
  ## Raises PixieError for a precision other than 8, a component count other
  ## than 3, zero dimensions, a segment length that does not match, a
  ## quantization table index above 3, or a sampling factor outside 1..4.
  let segmentLen = state.readUint16be()

  if state.readUint8() != 8:
    raise newException(PixieError, "Unsupported JPG bit depth, must be 8")

  state.imageHeight = state.readUint16be().int
  state.imageWidth = state.readUint16be().int
  if state.imageWidth == 0 or state.imageHeight == 0:
    failInvalid()

  if state.readUint8() != 3:
    raise newException(PixieError, "Unsupported JPG component count, must be 3")

  # 2 length bytes + 6 header bytes + 3 bytes for each of the 3 components.
  if segmentLen != 17:
    failInvalid()

  for component in state.components.mitems:
    component.id = state.readUint8()
    let
      samplingInfo = state.readUint8()
      tableIndex = state.readUint8()
      vertical = (samplingInfo and 15).int # low nibble
      horizontal = (samplingInfo shr 4).int # high nibble

    if tableIndex > 3:
      failInvalid()
    if vertical notin 1 .. 4 or horizontal notin 1 .. 4:
      failInvalid()

    component.verticalSamplingFactor = vertical
    component.horizontalSamplingFactor = horizontal
    component.quantizationTable = tableIndex

  for i in 0 ..< 3:
    if state.components[i].verticalSamplingFactor >
        state.maxVerticalSamplingFactor:
      state.maxVerticalSamplingFactor =
        state.components[i].verticalSamplingFactor
    if state.components[i].horizontalSamplingFactor >
        state.maxHorizontalSamplingFactor:
      state.maxHorizontalSamplingFactor =
        state.components[i].horizontalSamplingFactor

  let
    hMax = state.maxHorizontalSamplingFactor
    vMax = state.maxVerticalSamplingFactor

  # One MCU is 8 pixels per unit of the largest sampling factor; the grid
  # size is the image size divided by that, rounded up.
  state.mcuWidth = hMax * 8
  state.mcuHeight = vMax * 8
  state.numMcuWide = (state.imageWidth + state.mcuWidth - 1) div state.mcuWidth
  state.numMcuHigh =
    (state.imageHeight + state.mcuHeight - 1) div state.mcuHeight

  for component in state.components.mitems:
    let
      h = component.horizontalSamplingFactor
      v = component.verticalSamplingFactor
    # Image size scaled by this component's share of the max factor,
    # rounded up.
    component.width = (state.imageWidth * h + hMax - 1) div hMax
    component.height = (state.imageHeight * v + vMax - 1) div vMax
    # Size in pixels of this component across the whole MCU grid.
    component.w2 = state.numMcuWide * h * 8
    component.h2 = state.numMcuHigh * v * 8
|
||||
|
||||
proc decodeSOS(state: var DecoderState) =
  ## Reads the scan header: for each of the three scan components it looks up
  ## the frame component with the same id, records that component's DC and AC
  ## Huffman table indices, and stores the scan order in
  ## `state.componentOrder`.
  ##
  ## Raises PixieError for a component count other than 3, a Huffman table
  ## index above 3, a component id that no frame component has, or a segment
  ## length that does not match.
  let segmentLen = state.readUint16be().int

  if state.readUint8() != 3:
    raise newException(PixieError, "Unsupported JPG component count, must be 3")

  for i in 0 ..< 3:
    let
      id = state.readUint8()
      info = state.readUint8()
      huffmanAC = (info and 15).int # low nibble
      huffmanDC = (info shr 4).int # high nibble

    if huffmanAC > 3 or huffmanDC > 3:
      failInvalid()

    # Find which frame component this scan entry refers to.
    var component = -1
    for candidate in 0 ..< 3:
      if state.components[candidate].id == id:
        component = candidate
        break
    if component < 0:
      failInvalid() # Not found

    state.components[component].huffmanAC = huffmanAC
    state.components[component].huffmanDC = huffmanDC
    state.componentOrder[i] = component

  # Skip the 3 bytes that follow the component list.
  state.skipBytes(3)

  # 2 length bytes + 1 count byte + 3 * 2 component bytes + 3 skipped bytes.
  if segmentLen != 12:
    failInvalid()
|
||||
|
||||
proc fillBits(state: var DecoderState) =
  ## Tops up the bit buffer one byte at a time until it holds more than 24
  ## bits. New bytes are placed directly below the bits already held, so the
  ## next unread bit is always the most significant bit of `state.bits`.
  ##
  ## A 0xFF byte followed by 0x00 contributes a plain 0xFF data byte. A 0xFF
  ## followed by 0xD9 sets `state.hitEOI`; from then on only zero bits are
  ## appended. A 0xFF followed by anything else raises PixieError.
  while state.bitCount <= 24:
    var b = 0.uint32 # Zero padding once the end of image has been seen
    if not state.hitEOI:
      b = state.readUint8().uint32
      if b == 0xFF:
        let c = state.readUint8()
        if c == 0xD9:
          state.hitEOI = true
          # The 0xFF was the marker prefix, not image data, so it must not
          # end up in the bit buffer.
          b = 0
        elif c != 0:
          failInvalid()
    state.bits = state.bits or (b shl (24 - state.bitCount))
    state.bitCount += 8
|
||||
|
||||
inc pos
|
||||
proc huffmanDecode(state: var DecoderState, tableCurrent, table: int): uint8 =
  ## Decodes one Huffman symbol from the bit buffer using
  ## `state.huffmanTables[tableCurrent][table]` and consumes the bits of its
  ## code.
  ##
  ## Raises PixieError (via failInvalid) when no code length matches, when the
  ## matching code is longer than the bits available, or when the code maps
  ## outside the symbol table.
  if state.bitCount < 16:
    state.fillBits()

  # Find the shortest code length whose upper bound exceeds the next 16 bits.
  let next16 = (state.bits shr 16).int
  var codeLen = 1
  while codeLen < state.huffmanTables[tableCurrent][table].maxCodes.len:
    if next16 < state.huffmanTables[tableCurrent][table].maxCodes[codeLen]:
      break
    inc codeLen

  # 17 is the sentinel slot; 18 is reached only when the table was never
  # filled in (all bounds zero). Both mean no valid code was found, and
  # either would index past `deltas` below.
  if codeLen >= 17 or codeLen > state.bitCount:
    failInvalid()

  let symbolId = (state.bits shr (32 - codeLen)).int +
    state.huffmanTables[tableCurrent][table].deltas[codeLen]
  if symbolId < 0 or symbolId > 255:
    failInvalid()
  result = state.huffmanTables[tableCurrent][table].symbols[symbolId]

  state.bits = state.bits shl codeLen
  state.bitCount -= codeLen
|
||||
|
||||
template lrot(value: uint32, shift: int): uint32 =
  ## Rotates `value` left by `shift` bits: the bits shifted out at the top
  ## come back in at the bottom. `value` is evaluated twice.
  (value shr (32 - shift)) or (value shl shift)
|
||||
|
||||
proc extendReceive(state: var DecoderState, t: int): int =
  ## Consumes the next `t` bits from the bit buffer and returns them as a
  ## signed value: if the first of those bits is 1 the raw value is returned
  ## unchanged, otherwise `biases[t]` is added to make it negative.
  ##
  ## Raises PixieError (via failInvalid) when `t` is outside 0..15, the range
  ## covered by the `biases` table.
  if t < 0 or t > 15:
    failInvalid()

  if state.bitCount < t:
    state.fillBits()

  # Must be sampled before the rotate below moves the top bit away.
  let leadingBitSet = (state.bits shr 31) != 0

  # Rotate so the `t` wanted bits land in the low end, then split them from
  # the bits that stay in the buffer.
  let rotated = lrot(state.bits, t)
  state.bits = rotated and (not bitmasks[t])
  state.bitCount -= t

  result = (rotated and bitmasks[t]).int
  if not leadingBitSet:
    result += biases[t].int
|
||||
|
||||
proc decodeImageBlock(state: var DecoderState, component: int): array[64, int16] =
  ## Decodes the first (DC) coefficient of one block of `component` and
  ## returns it dequantized in `result[0]`; the other 63 entries are left at
  ## zero. Also updates the component's running DC predictor.
  ##
  ## Raises PixieError (via failInvalid) when the decoded bit count is above
  ## 15.
  let t = state.huffmanDecode(0, state.components[component].huffmanDC).int
  # `t` comes from a uint8 so it cannot be negative; the upper bound is what
  # keeps extendReceive inside its lookup tables.
  if t > 15:
    failInvalid()

  let
    diff = if t == 0: 0 else: state.extendReceive(t)
    dc = state.components[component].dcPred + diff
  state.components[component].dcPred = dc

  # NOTE(review): the int16 conversion is range-checked and could raise on
  # corrupt input whose accumulated DC value is very large - confirm whether
  # that should become a PixieError.
  result[0] = (dc * state.quantizationTables[
    state.components[component].quantizationTable
  ][0].int).int16
|
||||
|
||||
proc decodeImageData(state: var DecoderState) =
|
||||
for y in 0 ..< state.numMcuHigh:
|
||||
for x in 0 ..< state.numMcuWide:
|
||||
for component in state.componentOrder:
|
||||
for j in 0 ..< state.components[component].verticalSamplingFactor:
|
||||
for i in 0 ..< state.components[component].horizontalSamplingFactor:
|
||||
let data = state.decodeImageBlock(component)
|
||||
return
|
||||
|
||||
proc decodeJpg*(data: seq[uint8]): Image =
|
||||
## Decodes the JPEG into an Image.
|
||||
|
||||
if data.len < 4:
|
||||
var state = DecoderState()
|
||||
state.buffer = data
|
||||
|
||||
if state.readUint8() != 0xFF or state.readUint8() != 0xD8: # SOI
|
||||
failInvalid()
|
||||
|
||||
if data.readUint16(0) != cast[uint16](jpgStartOfImage):
|
||||
failInvalid()
|
||||
var marker = state.seekToMarker()
|
||||
while marker != 0xC0: # SOF
|
||||
state.decodeSegment(marker)
|
||||
marker = state.seekToMarker()
|
||||
|
||||
var
|
||||
jpg: Jpg
|
||||
pos: int
|
||||
while true:
|
||||
if pos + 2 > data.len:
|
||||
failInvalid()
|
||||
state.decodeSOF()
|
||||
|
||||
let marker = [data[pos], data[pos + 1]]
|
||||
pos += 2
|
||||
marker = state.seekToMarker()
|
||||
while marker != 0xDA: # Start of Scan
|
||||
state.decodeSegment(marker)
|
||||
marker = state.seekToMarker()
|
||||
|
||||
if marker[0] != 0xFF:
|
||||
failInvalid()
|
||||
state.decodeSOS()
|
||||
|
||||
case marker[1]:
|
||||
of 0xD8: # Start of Image
|
||||
discard
|
||||
of 0xC0: # Start of Frame
|
||||
jpg.decodeSOF(data, pos)
|
||||
of 0xC2: # Start of Frame
|
||||
raise newException(PixieError, "Progressive JPG not supported")
|
||||
of 0xC4: # Define Huffman Tables
|
||||
decodeDHT(data, pos)
|
||||
of 0xDB: # Define Quantanization Table(s)
|
||||
jpg.decodeDQT(data, pos)
|
||||
# of 0xDD: # Define Restart Interval
|
||||
of 0xDA: # Start of Scan
|
||||
decodeSOS(data, pos)
|
||||
break
|
||||
of 0xFE: # Comment
|
||||
skipSegment(data, pos)
|
||||
of 0xD9: # End of Image
|
||||
failInvalid() # Not expected here
|
||||
else:
|
||||
if (marker[1] and 0xF0) == 0xE0:
|
||||
# Skip APPn segments
|
||||
skipSegment(data, pos)
|
||||
else:
|
||||
raise newException(PixieError, "Unsupported JPG segment")
|
||||
state.decodeImageData()
|
||||
|
||||
raise newException(PixieError, "Decoding JPG not supported yet")
|
||||
# raise newException(PixieError, "Decoding JPG not supported yet")
|
||||
|
||||
proc decodeJpg*(data: string): Image {.inline.} =
|
||||
decodeJpg(cast[seq[uint8]](data))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pixie/fileformats/png, stb_image/read as stbi, stb_image/write as stbr,
|
||||
fidget/opengl/perf, nimPNG
|
||||
|
||||
let data = readFile("tests/data/lenna.png")
|
||||
let data = readFile("tests/images/lenna.png")
|
||||
|
||||
timeIt "pixie decode":
|
||||
for i in 0 ..< 100:
|
||||
|
|
|
@ -2,4 +2,4 @@ import pixie/fileformats/jpg
|
|||
|
||||
let original = readFile("tests/images/jpg/jpeg420exif.jpg")
|
||||
|
||||
# discard decodeJpg(original)
|
||||
discard decodeJpg(original)
|
||||
|
|
Loading…
Reference in a new issue