/tools/midi_convert.py
#!/usr/bin/env python3
import sys, functools, itertools
import miditoolkit

if len(sys.argv) != 3:
    sys.exit(f'usage: {sys.argv[0]} INPUT.mid OUTPUT')

mobj = miditoolkit.midi.parser.MidiFile(sys.argv[1])
print(mobj)
# Only a single, constant tempo is supported: every tick conversion below uses
# this one value. Keep it as a float -- int() truncated values such as
# 119.9997 down to 119, skewing every timestamp in the output stream.
if mobj.tempo_changes:
    bpm = float(mobj.tempo_changes[0].tempo)
    if len(mobj.tempo_changes) > 1:
        print('warning:', len(mobj.tempo_changes),
              'tempo changes found; only the first is used', file=sys.stderr)
else:
    # A Standard MIDI File without a tempo event plays at 120 BPM.
    bpm = 120.0
print('bpm:', bpm)

# Equal-temperament note frequencies in Hz (A4 = 440 Hz), copied from an
# online table. Indexed as freq_table[pitch_class][octave]: rows are pitch
# classes C..B (0-11), columns are octaves 0-8, so freq_table[0][0] is
# C0 = 16.35 Hz and freq_table[9][4] is A4 = 440 Hz.
# Spot-checked against 440 * 2**((midi - 69) / 12); a few octave 7/8 entries
# (e.g. 2489, 2637, 3951, 4978, 5274) are rounded to whole Hz.
freq_table = [
    [ 16.35, 32.7,  65.41,  130.81, 261.63, 523.25, 1046.5,  2093,    4186    ],  # C
    [ 17.32, 34.65, 69.3,   138.59, 277.18, 554.37, 1108.73, 2217.46, 4434.92 ],  # C#
    [ 18.35, 36.71, 73.42,  146.83, 293.66, 587.33, 1174.66, 2349.32, 4698.63 ],  # D
    [ 19.45, 38.89, 77.78,  155.56, 311.13, 622.25, 1244.51, 2489,    4978    ],  # D#
    [ 20.6,  41.2,  82.41,  164.81, 329.63, 659.25, 1318.51, 2637,    5274    ],  # E
    [ 21.83, 43.65, 87.31,  174.61, 349.23, 698.46, 1396.91, 2793.83, 5587.65 ],  # F
    [ 23.12, 46.25, 92.5,   185,    369.99, 739.99, 1479.98, 2959.96, 5919.91 ],  # F#
    [ 24.5,  49,    98,     196,    392,    783.99, 1567.98, 3135.96, 6271.93 ],  # G
    [ 25.96, 51.91, 103.83, 207.65, 415.3,  830.61, 1661.22, 3322.44, 6644.88 ],  # G#
    [ 27.5,  55,    110,    220,    440,    880,    1760,    3520,    7040    ],  # A
    [ 29.14, 58.27, 116.54, 233.08, 466.16, 932.33, 1864.66, 3729.31, 7458.62 ],  # A#
    [ 30.87, 61.74, 123.47, 246.94, 493.88, 987.77, 1975.53, 3951,    7902.13 ],  # B
]

class SnEvents:
    """One timed group of command bytes destined for the output stream.

    Both constructor arguments are stored as-is; ``bytes`` is kept by
    reference (not copied), so a caller may keep appending to the list after
    constructing the event.
    """

    def __init__(self, t, bytes):
        # t: event time in output frames (the values produced by tick_convert).
        self.t = t
        # bytes: list of command byte values. The name shadows the builtin
        # but is kept because it is part of the constructor/attribute API.
        self.bytes = bytes

    def __repr__(self):
        return f'{type(self).__name__}(t={self.t!r}, bytes={self.bytes!r})'

def tick_convert(t):
    """Map an absolute MIDI tick position to a whole number of 50 Hz frames.

    Relies on the module-level ``mobj.ticks_per_beat`` and the single global
    tempo ``bpm``; the result is rounded to the nearest frame.
    """
    seconds = 60 * (t / mobj.ticks_per_beat) / bpm
    frames = seconds * 50
    return round(frames)

def tick_time(t):
    """Format an absolute MIDI tick position as an '@M:SS.mmms' timestamp.

    Used only in error messages to point at the offending place in the song.
    Relies on the module-level ``mobj.ticks_per_beat`` and tempo ``bpm``.
    """
    beats = t / mobj.ticks_per_beat
    s = 60 * beats / bpm
    minutes, seconds = divmod(s, 60)
    # Fixed 3-decimal seconds; the raw float repr (e.g. 0.30000000000000004)
    # is just noise in a diagnostic message.
    return f'@{int(minutes)}:{seconds:06.3f}s'

def wait(frames):
    """Encode a delay of ``frames`` frames as a list of wait bytes.

    Each wait byte is ``0x80 | n`` with n in 1..126, so 0xFE is the largest
    byte emitted. Full 126-count bytes come first, followed by one remainder
    byte. One frame is subtracted up front because of how the player engine
    counts, so a request of 1 frame (or less) produces no bytes at all.

    NOTE(review): the -1 is applied once per call, not once per emitted byte.
    If the engine adds a frame for every wait byte, delays needing more than
    one byte would be off by one per extra byte -- confirm against the player.
    """
    remaining = frames - 1
    if remaining <= 0:
        return []
    full_chunks, rest = divmod(remaining - 1, 126)
    return [0xFE] * full_chunks + [0x80 | (rest + 1)]

def convert_track(t, target_track):
    """Convert one MIDI instrument track into a list of timed SnEvents.

    The track is made monophonic first: notes are sorted by start tick and
    each note's end is trimmed to the next note's start. Every note then
    produces a "note on" event (a two-byte tone counter write, plus a volume
    byte when the velocity differs from the last one sent). Unless the
    following note starts exactly where this one ends, the note also produces
    a "note off" event that writes 0xF to the volume nibble.

    Args:
        t: instrument with a ``notes`` list whose items expose ``start``,
            ``end``, ``pitch`` and ``velocity``. The list is sorted in place
            and note ``end`` values may be shortened.
        target_track: output channel; only the low two bits are used.

    Returns:
        List of SnEvents in track order, timed in frames via tick_convert.

    Raises:
        Exception: two notes start on the same tick; a note is below octave 0
            or beyond the last octave in freq_table; a note lasts fewer than
            10 ticks; or its counter value does not fit in 10 bits.
    """
    events = []
    current_vel = 0
    # Channel number occupies bits 5-6 of every tone-latch / volume byte.
    tt_val = (target_track & 0b11) << 5

    notes = t.notes
    notes.sort(key=lambda x: x.start)

    # Trim durations to prevent overlap: only one note per track can sound.
    for a, b in itertools.pairwise(notes):
        if a.start == b.start:
            raise Exception('overlapping note! ' + tick_time(a.start))
        if a.end > b.start:
            a.end = b.start

    for i, n in enumerate(notes):
        # note on
        note = n.pitch % 12
        octave = n.pitch // 12 - 1  # MIDI pitch 60 -> octave 4, i.e. C4
        if octave < 0:
            raise Exception('note is too low: ' + str(n))
        if n.end - n.start < 10:
            raise Exception('runt note ' + tick_time(n.start))
        if octave >= len(freq_table[note]):
            # Guard the lookup below so an out-of-range pitch gives a clear
            # error instead of a bare IndexError.
            raise Exception('note is too high: ' + str(n))
        # 62500 is 2MHz divided by the 32x counter
        cv = round(62500 / freq_table[note][octave])
        if cv > 2**10 - 1:
            raise Exception('impossible counter value: ' + str(cv) + ' ' + tick_time(n.start))
        # Low 4 counter bits go in the latch byte, the high 6 bits in the next.
        start_event = SnEvents(tick_convert(n.start), [0x80 | tt_val | (cv & 0xF), (cv & 0x3F0) >> 4])
        events.append(start_event)
        if n.velocity != current_vel:
            # Volume nibble is inverted velocity: louder notes get smaller
            # values; 0xF is the value the note-off command writes.
            start_event.bytes.append(0x90 | tt_val | (0xF - (n.velocity >> 3)))
            current_vel = n.velocity

        if i + 1 < len(notes) and n.end == notes[i + 1].start:
            # continuous notes; don't send an off command
            continue

        # note off
        events.append(SnEvents(tick_convert(n.end), [0x9F | tt_val]))
        current_vel = 0
    return events

def normalize(*il):
    """Rescale note velocities in place so the loudest note overall is 100.

    Args:
        *il: instruments with a ``notes`` list whose items have an integer
            ``velocity``. All tracks share one scale factor so their relative
            loudness is preserved.

    Tracks without notes are skipped. If there are no notes at all, or every
    velocity is 0, velocities are left untouched instead of raising.
    Results are truncated to int and capped at 127.
    """
    max_velocity = max((n.velocity for i in il for n in i.notes), default=0)
    print('max velocity:', max_velocity)
    if max_velocity <= 0:
        # Nothing audible to scale; avoid dividing by zero.
        return
    scale = 100 / max_velocity
    for i in il:
        for n in i.notes:
            n.velocity = min(int(n.velocity * scale), 127)

def merge_tracks(*tl):
    """Merge per-track event lists into one timeline grouped by time.

    Args:
        *tl: any number of event lists; each event needs a ``t`` attribute.

    Returns:
        An ``itertools.groupby`` iterator of ``(t, events)`` pairs in
        ascending ``t``. Within one time step, events keep their original
        track and per-track order, because ``sorted`` is stable. Like any
        groupby result, it is single-use and each group must be consumed
        before advancing.

    The input lists are not modified, and calling with no tracks yields an
    empty timeline.
    """
    # chain + sorted is linear concatenation, unlike repeated list `+`.
    merged = sorted(itertools.chain.from_iterable(tl), key=lambda x: x.t)
    return itertools.groupby(merged, lambda x: x.t)

def convert_to_sn_stream(sns):
    """Serialize a grouped event timeline into the player's byte stream.

    Args:
        sns: iterable of ``(t, events)`` pairs in ascending frame order, as
            produced by merge_tracks.

    For each time step, the stream receives wait bytes (only when time has
    advanced), then one length byte, then the concatenated command bytes of
    every event at that step.

    Raises:
        Exception: a single time step carries more than 10 command bytes.
    """
    out = []
    clock = 0
    for frame, group in sns:
        if frame > clock:
            out += wait(frame - clock)
            clock = frame
        payload = functools.reduce(lambda acc, chunk: acc + chunk,
                                   (ev.bytes for ev in group))
        if len(payload) > 10:
            raise Exception('suspiciously long command ' + str(frame / 50) + 's: ' + str(payload))
        out.append(len(payload))
        out += payload
    return bytes(out)


print(len(mobj.instruments), 'tracks')
for i, inst in enumerate(mobj.instruments):
    print('track', i, 'len', len(inst.notes))

# At most the first three tracks are converted; track i goes to output
# channel i. Files with fewer tracks are converted as far as they go.
instruments = mobj.instruments[:3]
if not instruments:
    raise Exception('no tracks to convert')
normalize(*instruments)
converted = []
for channel, inst in enumerate(instruments):
    events = convert_track(inst, channel)
    print(f't{channel + 1}', len(events))
    converted.append(events)
el = merge_tracks(*converted)
# Build the whole stream before touching the output file, so a conversion
# error does not leave behind a truncated/empty file or an unclosed handle.
stream = convert_to_sn_stream(el)
with open(sys.argv[2], 'wb') as out:
    out.write(stream)