# keyfile.py
import pygame
import imp
import os.path
from bmevent import *

class KeyFile:
	"""In-memory song chart: a list of events plus a time-to-beat mapping.

	Events are stored in ``bmelist``; each one is expected to expose
	``beat``, ``type`` and ``dataref`` attributes.  ``generate_beatfunc()``
	turns the tempo / long-measure / stop events into ``beatfunc``, a list
	of ``(start_time, func)`` segments that ``eval_beatfunc()`` uses to
	convert an elapsed time into a beat position.
	"""

	bmelist = None
	numkeys = 0

	# Loaded from a file
	player = None
	genre = None
	title = None
	artist = None
	stagefile = None
	playlevel = -1
	rank = 3
	volwav = 1.0
	wavs = None
	bmps = None
	offset = 0

	# Allows us to continuously map time to the current beat
	beatfunc = None

	# Most recent tempo seen while building beatfunc; long measures are
	# scaled relative to it.
	lastbpm = None

	def __init__(self):
		"""Give the instance its own event list and resource tables."""
		self.bmelist = []
		self.wavs = {}
		self.bmps = {}

	def add(self,bme):
		"""Append an event to the chart (no re-sorting is done)."""
		self.bmelist.append(bme)

	def remove(self,bme):
		"""Remove the first occurrence of ``bme``; raises ValueError if absent."""
		return self.bmelist.remove(bme)

	def sort(self):
		"""Sort the events in place by ascending beat (stable)."""
		self.bmelist.sort(key=lambda event: event.beat)

	def dump(self):
		"""Print every event, one per line."""
		for b in self.bmelist:
			print(str(b))

	# AWFUL DIRTY NO GOOD HACK
	# (but does find the end of the song within 10ms)
	def length(self):
		"""Return the time at which the last event's beat is reached.

		Starts at the beginning of the final beatfunc segment and steps
		forward 10 time units at a time.  Returns 0 for a chart without
		events.  The search also stops as soon as the beat function is
		undefined for the probed time, so it cannot spin forever when
		the last beat is never reached.
		"""
		if not self.bmelist:
			return 0
		if len(self.beatfunc) >= 2:
			t = self.beatfunc[-2][0]
		else:
			t = 0
		last_beat = self.bmelist[-1].beat
		b = self.eval_beatfunc(t)
		# eval_beatfunc() yields None outside the mapped range; comparing
		# None with a number would never terminate on Python 2.
		while b is not None and b < last_beat:
			t += 10
			b = self.eval_beatfunc(t)
		return t

	# It's time to get func-y
	def generate_beatfunc(self):
		"""Rebuild ``beatfunc`` from the timing events in ``bmelist``.

		Events are taken in list order, so the list is presumably
		expected to be sorted by beat already -- verify against callers.
		"""
		self.beatfunc = []
		mask = BME_TEMPO | BME_LONGMEASURE | BME_STOP
		bpms = [event for event in self.bmelist if event.type & mask]
		self.beatfunc = self.generate_beatfunc_r(bpms)

	def generate_beatfunc_r(self, bpms, ct=0):
		"""Build the segment list for the timing events ``bpms``.

		``bpms`` -- timing events in playback order (not modified).
		``ct``   -- time at which the first segment starts.

		Returns ``[(start_time, func), ..., (3600000.0, 0)]``.  Each
		``func`` maps a time to a beat for its segment; the final entry
		is only an end marker and is never called.  Times appear to be
		milliseconds (a tempo is converted with 60000.0 / value).

		Raises Exception for an unknown event type, or for a long
		measure that occurs before any tempo is known.
		"""
		# Closure factories: each segment must capture its own values.
		def linear(start_time, start_beat, ms_per_beat):
			return lambda t: (t - start_time) / ms_per_beat + start_beat

		def constant(value):
			return lambda t: value

		pending = list(bpms)
		segments = []
		index = 0
		while index < len(pending):
			event = pending[index]
			beat = event.beat
			kind = event.type

			if kind == BME_TEMPO:
				ms_per_beat = 60000.0 / event.dataref
				func = linear(ct, beat, ms_per_beat)
				if index + 1 < len(pending):
					next_beat = pending[index + 1].beat
				else:
					# The final tempo is taken to run up to beat 4000.
					next_beat = 4000
				duration = (next_beat - beat) * ms_per_beat
				self.lastbpm = event.dataref
			elif kind == BME_LONGMEASURE:
				# Long measures only last one measure, so we do
				# the slow measure just like a tempo change,
				# then add a tempo change back at the end.
				if self.lastbpm is None:
					raise Exception("Long measure before any tempo in generate_beatfunc_r")
				ms_per_beat = (60000.0 / self.lastbpm) * event.dataref
				func = linear(ct, beat, ms_per_beat)

				print("Last BPM: %s" % (self.lastbpm,))
				duration = 4 * ms_per_beat
				pending.insert(index + 1, BMEvent(beat+4,BME_TEMPO,0,self.lastbpm))
			elif kind == BME_STOP:
				# The beat stands still; dataref is used directly as
				# the length of the pause.
				func = constant(beat)
				duration = event.dataref
			else:
				raise Exception("WTF? Invalid type in generate_beatfunc_r")

			segments.append((ct, func))
			ct = ct + duration
			index += 1

		segments.append((3600000.0,0))
		return segments

	def eval_beatfunc(self,t):
		"""Return the beat at time ``t``, or None if no segment covers it."""
		segments = self.beatfunc
		for n in range(len(segments) - 1):
			if segments[n][0] <= t < segments[n+1][0]:
				return segments[n][1](t)
		return None

	def show_beatfunc(self,surface):
		"""Plot beat against time on ``surface`` (debug aid).

		One white dot is drawn per 100 time units; a red vertical line
		marks times for which the beat function is undefined.  The plot
		is scaled to the surface's own width and height.
		"""
		if not self.bmelist:
			return
		end = self.length()
		width = surface.get_width()
		height = surface.get_height()
		# Float scales: integer division could yield 0 and then divide
		# by zero below.
		xscale = float(end) / width
		yscale = float(self.bmelist[-1].beat) / height

		for t in range(0,int(end),100):
			beat = self.eval_beatfunc(t)
			x = int(t / xscale)
			# Beat 0 is a valid position; only None means "undefined".
			if beat is not None:
				if yscale:
					y = height - int(beat / yscale)
				else:
					y = height
				pygame.draw.circle(surface,(255,255,255),(x,y),1)
			else:
				pygame.draw.line(surface,(255,0,0), (x,0), (x,height))


class BMEListIter:
	"""Cursor over a beat-ordered event list.

	``b`` is the current beat and ``i`` the index the cursor rests on.
	The list is assumed to be sorted by ascending ``beat``.
	"""

	bmelist = None

	def __init__(self,bmelist):
		self.bmelist = bmelist
		self.b = 0.0
		self.i = 0

	def goto(self,b):
		"""Move the cursor to beat ``b``.

		The index walks backwards while its event lies after ``b`` and
		forwards while its event lies before ``b``; it never leaves the
		range 0 .. len(bmelist) - 1.
		"""
		last_index = len(self.bmelist) - 1
		self.b = b

		position = self.i
		while position > 0 and self.bmelist[position].beat > b:
			position -= 1
		while position < last_index and self.bmelist[position].beat < b:
			position += 1
		self.i = position

	def window(self,db,type=None):
		"""Return the events from the cursor up to beat ``b + db``.

		The end beat is inclusive.  When ``type`` is truthy, only events
		whose ``type`` shares a bit with it are kept.
		"""
		total = len(self.bmelist)
		end_beat = self.b + db
		stop = self.i

		while stop < total and self.bmelist[stop].beat <= end_beat:
			stop += 1

		selected = self.bmelist[self.i:stop]
		if not type:
			return selected

		def matches(event):
			return event.type & type
		return filter(matches, selected)


# Display surface and font used by the loading-status helpers below;
# both are assigned by kf_load().
screen = None
font = None

# Loader modules, imported by name from the "loaders" directory
# (resolved relative to the current working directory).
loaders = []
for x in ["BMloader","SMloader"]:
	f = open("loaders/" + x + ".py")
	try:
		loaders.append(imp.load_module(x,f,x + ".py",(".py",'r',imp.PY_SOURCE)))
	finally:
		# imp.load_module() leaves closing the file to the caller,
		# even when the import fails.
		f.close()

def vmessage(message):
	"""Draw ``message`` in the status strip of the module-level screen.

	The strip is the 640x30 rectangle at (0, 450); it is cleared to black
	and the text is rendered in white on black with the module-level font.
	The display is not flipped here.
	"""
	white = (255,255,255)
	black = (0,0,0)
	rendered = font.render(message, 0, white, black)
	screen.fill(black,(0,450,640,30))
	screen.blit(rendered,(0,450))

def vstatus(type,arg):
	"""Progress callback: show a loading event on screen, then flip.

	``type`` -- one of "STAGEFILE", "WAV", "BMP", "TRACK" or "ERROR";
	            anything else only flips the display.
	``arg``  -- a surface to blit for "STAGEFILE", otherwise the value
	            appended to the status text.
	"""
	if type == "STAGEFILE":
		screen.blit(arg,(0,0))
	elif type == "TRACK":
		# The track argument is converted with str(); the others are
		# concatenated as they are.
		vmessage("Parsing track " + str(arg))
	else:
		labels = (
			("WAV", "Loaded WAV "),
			("BMP", "Loaded BMP "),
			("ERROR", "ERROR: "),
		)
		for tag, prefix in labels:
			if type == tag:
				vmessage(prefix + arg)
				break
	pygame.display.flip()

def likelihood(file):
	"""Return the index in ``loaders`` of the best match for ``file``.

	Every loader's ``detect(file)`` is called first; the earliest loader
	with the highest score above 0.0 wins.  Returns None when no loader
	scores above 0.0.
	"""
	scores = [loader.detect(file) for loader in loaders]
	best_index = None
	best_score = 0.0
	for index, score in enumerate(scores):
		if score > best_score:
			best_score = score
			best_index = index
	return best_index

def kf_load(file):
	"""Load ``file`` with the best-matching loader and return the result.

	Sets the module-level ``screen`` (current display surface) and
	``font`` so that ``vstatus`` can report progress, asks the chosen
	loader to load the file with ``vstatus`` as its callback, then calls
	``generate_beatfunc()`` on the loaded object.

	Returns the loaded object, or None when no loader recognises the file.
	"""
	global screen,font
	screen = pygame.display.get_surface()
	font = pygame.font.SysFont("Helvetica Normal",30)
	gn = likelihood(file)
	if gn is None:
		return None
	print("Load...")
	kf = loaders[gn].load(file,vstatus)
	kf.generate_beatfunc()
	# NOTE(review): the original had a `font = None` after both returns
	# that could never run; it is dropped here, so the font stays set
	# after loading exactly as before.
	return kf

def kf_info(file):
	"""Return the best-matching loader's info for ``file``.

	The mapping returned by the loader's ``info(file)`` is extended with
	'loader.name' and 'loader.version' taken from the loader module.

	Returns None when no loader recognises the file, mirroring kf_load().
	"""
	gn = likelihood(file)
	if gn is None:
		# Without this guard loaders[None] raises a TypeError.
		return None
	loader = loaders[gn]
	d = loader.info(file)
	d['loader.name'] = loader.name
	d['loader.version'] = loader.version

	return d