/tools/pic_compress_8x8.py
#!/usr/bin/env python3
import sys, struct, math, argparse, copy, operator
from enum import Enum
from functools import reduce
from PIL import Image

class Chunk:
    def __init__(self, words):
        if len(words) != 16:
            raise Exception('chunk must be 16 words')
        self.words = words

    def reduce_8(self):
        reduced = []
        for y in range(0, 16, 2):
            x = self.words[y] | self.words[y + 1]
            x = (x & 0xAAAA) >> 1 | (x & 0x5555)
            reduced.append(x)
        return reduced

    def reduce_4(self):
        reduced = []
        words = self.reduce_8()
        for y in range(0, 8, 2):
            x = words[y] | words[y + 1]
            x = (x >> 2) | x
            reduced.append(x)
        return reduced

    def reduce_2(self):
        a = self.words[0] | self.words[1] | self.words[2] | self.words[3] | self.words[4] | self.words[5] | self.words[6] | self.words[7]
        b = self.words[8] | self.words[9] | self.words[10] | self.words[11] | self.words[12] | self.words[13] | self.words[14] | self.words[15]
        return ((a & 0xFF00) == 0) | ((a & 0x00FF) == 0) << 1 | ((b & 0xFF00) == 0) << 2 | ((b & 0x00FF) == 0) << 3

    def pixel_count(self):
        return sum(map(int.bit_count, self.words))

    def pixels_changed(self, other):
        return sum(map(lambda v: (v[0] ^ v[1]).bit_count(), zip(self.words, other.words)))

    def similarity(self, other):
        s = 0
        if self.reduce_8() == other.reduce_8():
            s += 32
        if self.reduce_4() == other.reduce_4():
            s += 16
        if self.reduce_2() == other.reduce_2():
            s += 8
        s -= abs(self.pixel_count() - other.pixel_count()) // 16
        s -= self.pixels_changed(other) // 16
        return s

    def merge(self, other):
        for i in range(0, 16):
            if i % 2 == 0:
                self.words[i] = (self.words[i] & 0xAAAA) | (other.words[i] & 0x5555)
            else:
                self.words[i] = (self.words[i] & 0x5555) | (other.words[i] & 0xAAAA)

    def __eq__(self, other):
        return self.words == other.words

    def __hash__(self):
        return hash((self.words[0], self.words[1], self.words[2], self.words[3], self.words[4], self.words[5], self.words[6], self.words[7], self.words[8], self.words[9], self.words[10], self.words[11], self.words[12], self.words[13], self.words[14], self.words[15]))

    def __str__(self):
        return "\n" + "\n".join(map(lambda x: f'{x:016b}', self.words))

# Chunks that are never stored in a SubImage's chunk list; they are referenced
# by these reserved index values instead. Elsewhere in this file, indexes
# above 0xFF0 are treated as referring to this table.
known_chunks = {
    0xFFF: Chunk([0xFFFF] * 16),  # every pixel set
    0xFFE: Chunk([0x0000] * 16),  # every pixel clear
}
# Reverse lookup (chunk -> reserved index), derived from the table above so
# the two mappings cannot drift apart.
known_chunks_rev = {chunk: index for index, chunk in known_chunks.items()}

class ImgType(Enum):
    """Kind of plane stored in a chunked image file.

    The numeric value is part of the file format: SubImage.write() derives its
    section type bytes from it, so the numbers must stay stable.
    """

    # Regular 1-bit picture data.
    Normal = 1
    # 1-bit mask that ChunkedImage.into_image() uses to make pixels transparent.
    Mask = 2

class SubImage:
    """One 1-bit plane of a ChunkedImage, split into 16x16 chunks.

    ``chunks`` holds the distinct stored chunks. ``indexes`` holds one entry
    per 16x16 cell in row-major order: either a position in ``chunks`` or a
    reserved ``known_chunks`` index (any value above KNOWN_INDEX_FLOOR).
    """

    # Indexes strictly above this value refer to known_chunks, so at most
    # KNOWN_INDEX_FLOOR + 1 chunks can be stored per plane.
    KNOWN_INDEX_FLOOR = 0xFF0

    def __init__(self, im, type):
        """Split ``im`` into 16x16 chunks, storing each distinct chunk once.

        Dimensions are assumed to be multiples of 16 (ChunkedImage checks
        this). Any non-zero pixel counts as set. Multi-band images (e.g. RGB)
        are converted to greyscale first, because their pixels are tuples and
        would otherwise always compare non-zero.

        Raises:
            Exception: if there are more distinct chunks than the index
                format can address without colliding with known chunks.
        """
        self.chunks = []
        self.indexes = []
        self.type = type
        if len(im.getbands()) > 1:
            im = im.convert('L')
        # Content -> index lookup for this pass; avoids an O(n) list scan per
        # cell. Safe because nothing mutates the chunks during construction.
        seen = {}
        for y in range(0, im.height, 16):
            for x in range(0, im.width, 16):
                c = Chunk(self._read_words(im, x, y))
                if c in known_chunks_rev:
                    i = known_chunks_rev[c]
                elif c in seen:
                    i = seen[c]
                else:
                    i = len(self.chunks)
                    if i > self.KNOWN_INDEX_FLOOR:
                        raise Exception('too many distinct chunks for 12-bit indexes')
                    self.chunks.append(c)
                    seen[c] = i
                self.indexes.append(i)

    @staticmethod
    def _read_words(im, x, y):
        """Pack the 16x16 cell at (x, y) into 16 row words, pixel cx -> bit cx."""
        words = []
        for cy in range(16):
            w = 0
            for cx in range(16):
                if im.getpixel((x + cx, y + cy)) != 0:
                    w |= 1 << cx
            words.append(w)
        return words

    def chunk_at(self, i):
        """Return the chunk referenced by index ``i`` (known or stored)."""
        if i in known_chunks:
            return known_chunks[i]
        else:
            return self.chunks[i]

    def chunk_index(self, c):
        """Return the index referencing chunk ``c``.

        Raises:
            ValueError: if ``c`` is neither a known chunk nor stored here.
        """
        if c in known_chunks_rev:
            return known_chunks_rev[c]
        else:
            return self.chunks.index(c)

    def histogram(self):
        """Return ``[(index, count), ...]`` for stored chunks, least used first.

        Known-chunk indexes are excluded; ties keep first-appearance order.
        """
        counts = {}
        for i in self.indexes:
            if i > self.KNOWN_INDEX_FLOOR:  # known chunk
                continue
            counts[i] = counts.get(i, 0) + 1
        return sorted(counts.items(), key=lambda v: v[1])

    def dump_histogram(self):
        """Print each stored chunk index with its usage count."""
        print('Chunk              count')
        for c, n in self.histogram():
            print(f'{c} {n}')

    def replace_index(self, idx1, idx2):
        """Make every cell that references ``idx1`` reference ``idx2`` instead."""
        self.indexes[:] = [idx2 if i == idx1 else i for i in self.indexes]

    def disk_size(self):
        """Return the number of bytes write() emits for this plane.

        3-byte chunk-section header (type byte + uint16 count), 32 bytes per
        chunk, 1-byte index-section header, and 3 bytes per pair of indexes.
        """
        return 3 + len(self.chunks) * 32 + 1 + (len(self.indexes) + 1) // 2 * 3

    def deduplicate(self, similarity_threshold, merge=None):
        """Replace rarely used chunks with similar, more common ones.

        Args:
            similarity_threshold: minimum Chunk.similarity() score for a chunk
                to be replaced.
            merge: if true, blend each replaced chunk into its replacement via
                Chunk.merge(). If None (the default), fall back to the
                module-level ``args.merge`` command-line flag.

        Returns:
            The number of chunks removed.
        """
        if merge is None:
            merge = args.merge
        histogram = self.histogram()
        dead_chunks = set()
        # Start with the least common patterns; the most common one is never
        # replaced.
        for i in range(len(histogram) - 1):
            idx1 = histogram[i][0]
            a = self.chunks[idx1]
            # Score every more common pattern, most common first, so that
            # max() (which keeps the first maximum) breaks ties in favour of
            # the more common chunk.
            candidates = [
                (idx, a.similarity(self.chunks[idx]))
                for idx, _count in reversed(histogram[i + 1:])
            ]
            idx2, score = max(candidates, key=lambda v: v[1])
            if score < similarity_threshold:
                # Bad match, move on
                continue
            if merge:
                self.chunks[idx2].merge(a)
            self.replace_index(idx1, idx2)
            dead_chunks.add(idx1)

        self._drop_chunks(dead_chunks)
        return len(dead_chunks)

    def _drop_chunks(self, dead):
        """Remove the chunks whose indexes are in ``dead`` and renumber cells."""
        remap = {}
        survivors = []
        for old, chunk in enumerate(self.chunks):
            if old not in dead:
                remap[old] = len(survivors)
                survivors.append(chunk)
        for n, idx in enumerate(self.indexes):
            if idx > self.KNOWN_INDEX_FLOOR:  # known chunks keep their index
                continue
            if idx in dead:
                raise Exception('dead index found: ' + str(idx))
            self.indexes[n] = remap[idx]
        self.chunks = survivors

    def into_image(self, width, height):
        """Render this plane into a new mode '1' PIL image of the given size.

        Cells are drawn row-major, 16x16 pixels each, wrapping at ``width``.
        """
        im = Image.new('1', (width, height))
        x = 0
        y = 0
        for idx in self.indexes:
            for cy, w in enumerate(self.chunk_at(idx).words):
                for cx in range(16):
                    im.putpixel((x + cx, y + cy), (w >> cx) & 1)
            x += 16
            if x == width:
                x = 0
                y += 16
        return im

    def write(self, f):
        """Serialise this plane to the binary file object ``f``.

        Chunk section: type byte ``8 * type.value``, uint16 chunk count, then
        16 little-endian uint16 row words per chunk. Index section: type byte
        ``8 * type.value + 1``, then indexes packed two per 3 bytes (low byte
        of the first, high nibbles of both, low byte of the second); an odd
        trailing index is padded with zeros.
        """
        base_type = 8 * self.type.value
        f.write(struct.pack("<BH", base_type + 0, len(self.chunks)))
        for chunk in self.chunks:
            f.write(struct.pack("<16H", *chunk.words))
        # Index section header (type byte only).
        f.write(struct.pack("<B", base_type + 1))
        written = 0
        for i in range(0, len(self.indexes), 2):
            b1 = self.indexes[i] & 0xFF
            b2 = (self.indexes[i] & 0xF00) >> 8
            if i + 1 < len(self.indexes):
                b2 |= (self.indexes[i + 1] & 0xF00) >> 4
                b3 = self.indexes[i + 1] & 0xFF
            else:
                b3 = 0
            f.write(struct.pack("<BBB", b1, b2, b3))
            written += 3
        print(written, "packed index bytes")

class ChunkedImage:
    """A 1-bit image, plus an optional 1-bit mask, stored as chunked planes."""

    def __init__(self, im):
        """Chunk ``im`` as the Normal plane.

        Raises:
            Exception: if width or height is not a multiple of 16.
        """
        if im.width % 16 != 0 or im.height % 16 != 0:
            raise Exception('width and height must be multiples of 16')
        self.width = im.width
        self.height = im.height
        self.subimages = [SubImage(im, ImgType.Normal)]

    def add_mask(self, im):
        """Chunk ``im`` as an additional Mask plane.

        Raises:
            Exception: if ``im`` does not have the same dimensions.
        """
        if im.width != self.width or im.height != self.height:
            raise Exception('mask dimensions do not match image')
        self.subimages.append(SubImage(im, ImgType.Mask))

    def disk_size(self):
        """Return the size in bytes of the file write_to_file() produces.

        6-byte file header (4-byte magic + chunk-width and chunk-height bytes)
        plus each plane's size.
        """
        return 6 + sum(s.disk_size() for s in self.subimages)

    def write_to_file(self, filename):
        """Write the header and every plane to ``filename``."""
        with open(filename, 'wb') as f:
            # magic
            f.write(b"\xA7ci\x00")
            # Dimensions in 16-pixel chunks; exact since __init__ checked them.
            f.write(struct.pack('<BB', self.width // 16, self.height // 16))
            for s in self.subimages:
                s.write(f)

    def into_image(self):
        """Render to a PIL image; with a mask, unmasked pixels become transparent."""
        base = self.subimages[0].into_image(self.width, self.height)
        if len(self.subimages) > 1:
            base = base.convert('RGBA')
            mask = self.subimages[1].into_image(self.width, self.height)
            transparent = Image.new('RGBA', (self.width, self.height), color='#FF000000')
            base = Image.composite(base, transparent, mask)
        return base

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns an ``argparse.Namespace`` with ``input_file``, ``output_file``,
    ``similarity``, ``target_size``, ``merge``, ``no_dedup``, ``mask`` and
    ``png`` attributes.
    """
    cli = argparse.ArgumentParser(prog='chunk-compress')
    add = cli.add_argument
    # Positional input and output paths.
    add('input_file')
    add('output_file')
    # Deduplication tuning.
    add('-s', '--similarity', type=int, default=25, help='similarity rating')
    add('-t', '--target-size', type=int, help='decrease similarity incrementally to hit a target size')
    add('-m', '--merge', action='store_true', help='merge chunks instead of replacing them')
    add('--no-dedup', action='store_true', help='do not run deduplication step (identical and well-known chunks will still be deduplicated)')
    # Optional extra input and proofing output.
    add('--mask', help='Add a mask image')
    add('--png', help='also output a PNG proofing image')
    namespace = cli.parse_args()
    return namespace

args = parse_args()

im = Image.open(args.input_file)
ci = ChunkedImage(im)
if args.mask:
    cim = Image.open(args.mask)
    ci.add_mask(cim)

original_chunk_len = len(ci.subimages[0].chunks)
print(f'{original_chunk_len} distinct chunks')

# Lowest score Chunk.similarity() can return: each of its two penalty terms is
# at most 256 // 16 = 16 for a 256-pixel chunk. At or below this threshold
# every candidate matches, so lowering it further cannot shrink the output.
MIN_SIMILARITY = -2 * (256 // 16)

if not args.no_dedup:
    if args.target_size:
        similarity = args.similarity
        while True:
            print('attempting deduplication with similarity', similarity)
            ci_temp = copy.deepcopy(ci)
            dead_chunk_len = ci_temp.subimages[0].deduplicate(similarity)
            if ci_temp.disk_size() <= args.target_size:
                ci = ci_temp
                break
            if similarity <= MIN_SIMILARITY:
                # Unreachable target: stop instead of looping forever.
                sys.exit(f'cannot reach target size {args.target_size} '
                         f'(best size reached: {ci_temp.disk_size()})')
            similarity -= 1
    else:
        print('deduplicating chunks with similarity', args.similarity)
        dead_chunk_len = ci.subimages[0].deduplicate(args.similarity)
    # An image made only of known chunks stores nothing; avoid dividing by 0.
    if original_chunk_len:
        reduction = int(dead_chunk_len / original_chunk_len * 100)
    else:
        reduction = 0
    print('deduplicated', dead_chunk_len, 'chunks', f'({reduction:d}% reduction)')

print('size', ci.disk_size())

if args.png:
    im = ci.into_image()
    im.save(args.png)

ci.write_to_file(args.output_file)